1"""
2The Slicer classes.
3
The main purpose of these classes is to automatically adjust the size of
the axes to the data for the different layouts of cuts.
6"""
7
8import collections.abc
9import numbers
10from distutils.version import LooseVersion
11
12import matplotlib
13import matplotlib.pyplot as plt
14import numpy as np
15import warnings
16from matplotlib import cm as mpl_cm
17from matplotlib import (colors,
18                        lines,
19                        transforms,
20                        )
21from matplotlib.colorbar import ColorbarBase
22from matplotlib.font_manager import FontProperties
23from matplotlib.patches import FancyArrow
24from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
25from scipy import sparse, stats
26
27from . import cm, glass_brain
28from .edge_detect import _edge_map
29from .find_cuts import find_xyz_cut_coords, find_cut_slices
30from .. import _utils
31from ..image import new_img_like
32from ..image.resampling import (get_bounds, reorder_img, coord_transform,
33                                get_mask_bounds)
34from nilearn.image import get_data
35
36
37###############################################################################
38# class BaseAxes
39###############################################################################
40
class BaseAxes(object):
    """ An MPL axis-like object that displays a 2D view of 3D volumes
    """

    def __init__(self, ax, direction, coord):
        """ An MPL axis-like object that displays a cut of 3D volumes

        Parameters
        ----------
        ax : A MPL axes instance
            The axes in which the plots will be drawn.

        direction : {'x', 'y', 'z'}
            The directions of the view.

        coord : float
            The coordinate along the direction of the cut.

        """
        self.ax = ax
        self.direction = direction
        self.coord = coord
        # List of (xmin, xmax, ymin, ymax) tuples, one per drawn object.
        self._object_bounds = list()
        self.shape = None

    def transform_to_2d(self, data, affine):
        """Reduce a 3D volume to the 2D array shown on this axes.

        Must be implemented by subclasses.

        Raises
        ------
        NotImplementedError
            Always, in this base class.

        """
        raise NotImplementedError("'transform_to_2d' needs to be implemented "
                                  "in derived classes")

    def add_object_bounds(self, bounds):
        """Register the bounds of a new object and rescale the axes if needed.

        Parameters
        ----------
        bounds : tuple (xmin, xmax, ymin, ymax)
            Bounds of the object, in data coordinates.

        """
        old_object_bounds = self.get_object_bounds()
        self._object_bounds.append(bounds)
        new_object_bounds = self.get_object_bounds()

        # Only touch the MPL axis limits when the union of bounds changed.
        if new_object_bounds != old_object_bounds:
            self.ax.axis(new_object_bounds)

    def draw_2d(self, data_2d, data_bounds, bounding_box,
                type='imshow', **kwargs):
        """Draw a 2D array on the axes with a matplotlib plotting method.

        Parameters
        ----------
        data_2d : 2D ndarray
            The array to draw. A copy is drawn so later changes to the
            caller's array do not alter the figure.

        data_bounds : ((xmin, xmax), (ymin, ymax), (zmin, zmax))
            World-space extent of the data; the two ranges orthogonal to
            ``self.direction`` give the image extent.

        bounding_box : ((xmin, xmax), (ymin, ymax), (zmin, zmax))
            World-space box registered as the object bounds of this axes.

        type : str, optional
            Name of the axes method used to draw (e.g. 'imshow',
            'contour', 'contourf'). Default='imshow'.

        kwargs :
            Extra keyword arguments passed to the drawing method. ``origin``
            is always forced to 'upper'.

        Returns
        -------
        The object returned by the matplotlib drawing method.

        Raises
        ------
        ValueError
            If ``self.direction`` is not one of 'x', 'l', 'r', 'y', 'z'.

        """
        # kwargs messaging
        kwargs['origin'] = 'upper'

        if self.direction == 'y':
            (xmin, xmax), (_, _), (zmin, zmax) = data_bounds
            (xmin_, xmax_), (_, _), (zmin_, zmax_) = bounding_box
        elif self.direction in ('x', 'l', 'r'):
            # 'l' and 'r' are hemispheric views along the x axis.
            (_, _), (xmin, xmax), (zmin, zmax) = data_bounds
            (_, _), (xmin_, xmax_), (zmin_, zmax_) = bounding_box
        elif self.direction == 'z':
            (xmin, xmax), (zmin, zmax), (_, _) = data_bounds
            (xmin_, xmax_), (zmin_, zmax_), (_, _) = bounding_box
        else:
            raise ValueError('Invalid value for direction %s' %
                             self.direction)
        ax = self.ax
        # Here we need to do a copy to avoid having the image changing as
        # we change the data
        im = getattr(ax, type)(data_2d.copy(),
                               extent=(xmin, xmax, zmin, zmax),
                               **kwargs)

        self.add_object_bounds((xmin_, xmax_, zmin_, zmax_))
        self.shape = data_2d.T.shape

        # The bounds of the object do not take into account a possible
        # inversion of the axis. As such, we check that the axis is properly
        # inverted when direction is left
        if self.direction == 'l' and not (ax.get_xlim()[0] > ax.get_xlim()[1]):
            ax.invert_xaxis()

        return im

    def get_object_bounds(self):
        """ Return the bounds of the objects on this axes.

        Returns
        -------
        tuple (xmin, xmax, ymin, ymax)
            Union of all registered object bounds, or a tiny box around the
            origin when nothing has been registered yet. Each min/max is
            taken over both the min and max entries, so reversed bounds
            (e.g. min > max) are handled.

        """
        if len(self._object_bounds) == 0:
            # Nothing plotted yet
            return -.01, .01, -.01, .01
        xmins, xmaxs, ymins, ymaxs = np.array(self._object_bounds).T
        xmax = max(xmaxs.max(), xmins.max())
        xmin = min(xmins.min(), xmaxs.min())
        ymax = max(ymaxs.max(), ymins.max())
        ymin = min(ymins.min(), ymaxs.min())

        return xmin, xmax, ymin, ymax

    def draw_left_right(self, size, bg_color, **kwargs):
        """Annotate the axes with 'L' and 'R' labels.

        Nothing is drawn for sagittal/hemispheric views ('x', 'l', 'r').

        Parameters
        ----------
        size : float or str
            Font size of the labels.

        bg_color : matplotlib color
            Color of the (opaque) box behind each label.

        kwargs :
            Extra keyword arguments passed to ``ax.text``.

        """
        if self.direction in ('x', 'l', 'r'):
            return
        ax = self.ax
        ax.text(.1, .95, 'L',
                transform=ax.transAxes,
                horizontalalignment='left',
                verticalalignment='top',
                size=size,
                bbox=dict(boxstyle="square,pad=0",
                          ec=bg_color, fc=bg_color, alpha=1),
                **kwargs)

        # Same opaque box as the 'L' label so both labels render alike.
        ax.text(.9, .95, 'R',
                transform=ax.transAxes,
                horizontalalignment='right',
                verticalalignment='top',
                size=size,
                bbox=dict(boxstyle="square,pad=0",
                          ec=bg_color, fc=bg_color, alpha=1),
                **kwargs)

    def draw_scale_bar(self, bg_color, size=5.0, units='cm',
                       fontproperties=None, frameon=False, loc=4, pad=.1,
                       borderpad=.5, sep=5, size_vertical=0, label_top=False,
                       color='black', fontsize=None, **kwargs):
        """ Adds a scale bar annotation to the display

        Parameters
        ----------
        bg_color : matplotlib color: str or (r, g, b) value
            The background color of the scale bar annotation (only used
            when ``frameon`` is True).

        size : float, optional
            Horizontal length of the scale bar, given in `units`.
            Default=5.0.

        units : str, optional
            Physical units of the scale bar (`'cm'` or `'mm'`).
            Default='cm'.

        fontproperties : ``matplotlib.font_manager.FontProperties`` or dict, optional
            Font properties for the label text. A dict is converted with
            ``FontProperties(**fontproperties)``. The caller's object is
            never modified.

        frameon : Boolean, optional
            Whether the scale bar is plotted with a border. Default=False.

        loc : int, optional
            Location of this scale bar. Valid location codes are documented
            `here <https://matplotlib.org/mpl_toolkits/axes_grid/\
            api/anchored_artists_api.html#mpl_toolkits.axes_grid1.\
            anchored_artists.AnchoredSizeBar>`__.
            Default=4.

        pad : int of float, optional
            Padding around the label and scale bar, in fraction of the font
            size. Default=0.1.

        borderpad : int or float, optional
            Border padding, in fraction of the font size. Default=0.5.

        sep : int or float, optional
            Separation between the label and the scale bar, in points.
            Default=5.

        size_vertical : int or float, optional
            Vertical length of the size bar, given in `units`. Default=0.

        label_top : bool, optional
            If True, the label will be over the scale bar. Default=False.

        color : str, optional
            Color for the scale bar and label. Default='black'.

        fontsize : int, optional
            Label font size (overwrites the size passed in through the
            ``fontproperties`` argument).

        **kwargs :
            Keyworded arguments to pass to
            ``matplotlib.offsetbox.AnchoredOffsetbox``.

        """
        axis = self.ax
        if not fontproperties:
            fontproperties = FontProperties()
        elif isinstance(fontproperties, dict):
            # The documented dict form must be turned into a real
            # FontProperties before calling methods on it.
            fontproperties = FontProperties(**fontproperties)
        if fontsize:
            # Resize a copy so the caller's FontProperties is not mutated.
            fontproperties = fontproperties.copy()
            fontproperties.set_size(fontsize)
        # The axes data coordinates are in millimeters.
        width_mm = size
        if units == 'cm':
            width_mm *= 10

        anchor_size_bar = AnchoredSizeBar(
            axis.transData,
            width_mm,
            '%g%s' % (size, units),
            fontproperties=fontproperties,
            frameon=frameon,
            loc=loc,
            pad=pad,
            borderpad=borderpad,
            sep=sep,
            size_vertical=size_vertical,
            label_top=label_top,
            color=color,
            **kwargs)

        if frameon:
            anchor_size_bar.patch.set_facecolor(bg_color)
            anchor_size_bar.patch.set_edgecolor('none')
        axis.add_artist(anchor_size_bar)

    def draw_position(self, size, bg_color, **kwargs):
        """Annotate the axes with the cut position; implemented by subclasses.

        Raises
        ------
        NotImplementedError
            Always, in this base class.

        """
        raise NotImplementedError("'draw_position' should be implemented "
                                  "in derived classes")
243
244
245###############################################################################
246# class CutAxes
247###############################################################################
248
class CutAxes(BaseAxes):
    """An MPL axis-like object showing a planar cut through a 3D volume."""

    def transform_to_2d(self, data, affine):
        """Extract the plane of ``data`` located at ``self.coord``.

        The world position of the cut is mapped to voxel indices with the
        inverse of ``affine`` (rounded to the nearest integer), and the
        selected plane is rotated by 90 degrees for display.

        Parameters
        ----------
        data : 3D ndarray
            The 3D volume to cut.

        affine : 4x4 ndarray
            The affine of the volume.

        Returns
        -------
        2D ndarray
            The rotated cut.

        Raises
        ------
        ValueError
            If ``self.direction`` does not select a single axis.

        """
        world_point = [0, 0, 0]
        world_point['xyz'.index(self.direction)] = self.coord
        voxel_point = coord_transform(world_point[0], world_point[1],
                                      world_point[2], np.linalg.inv(affine))
        i, j, k = [int(np.round(value)) for value in voxel_point]

        if self.direction == 'x':
            plane = data[i, :, :]
        elif self.direction == 'y':
            plane = data[:, j, :]
        elif self.direction == 'z':
            plane = data[:, :, k]
        else:
            raise ValueError('Invalid value for direction %s' %
                             self.direction)
        return np.rot90(plane)

    def draw_position(self, size, bg_color, decimals=False, **kwargs):
        """Write "<direction>=<coord>" in the lower-left corner of the axes.

        Parameters
        ----------
        size : float or str
            Font size of the annotation.

        bg_color : matplotlib color
            Color of the opaque box behind the text.

        decimals : int or False, optional
            Number of decimals to show; a falsy value prints the coordinate
            as an integer. Default=False.

        kwargs :
            Extra keyword arguments passed to ``ax.text``.

        """
        if decimals:
            template, value = '%s=%.{}f'.format(decimals), float(self.coord)
        else:
            template, value = '%s=%i', self.coord
        axes = self.ax
        box_style = dict(boxstyle="square,pad=0",
                         ec=bg_color, fc=bg_color, alpha=1)
        axes.text(0, 0, template % (self.direction, value),
                  transform=axes.transAxes,
                  horizontalalignment='left',
                  verticalalignment='bottom',
                  size=size,
                  bbox=box_style,
                  **kwargs)
298
299
300def _get_index_from_direction(direction):
301    """Returns numerical index from direction
302    """
303    directions = ['x', 'y', 'z']
304    try:
305        # l and r are subcases of x
306        if direction in 'lr':
307            index = 0
308        else:
309            index = directions.index(direction)
310    except ValueError:
311        message = (
312            '{0} is not a valid direction. '
313            "Allowed values are 'l', 'r', 'x', 'y' and 'z'").format(direction)
314        raise ValueError(message)
315    return index
316
317
def _coords_3d_to_2d(coords_3d, direction, return_direction=False):
    """Project an (n, 3) array of 3D coordinates onto the plane of a cut.

    The column matching ``direction`` is dropped and the two remaining
    columns are kept in their original order.

    Parameters
    ----------
    coords_3d : ndarray, shape (n, 3)
        3D coordinates to project.

    direction : {'x', 'y', 'z', 'l', 'r'}
        Direction of the cut ('l' and 'r' behave like 'x').

    return_direction : bool, optional
        If True, also return the dropped column. Default=False.

    Returns
    -------
    ndarray of shape (n, 2), or a tuple of it and the dropped (n,) column.

    """
    dropped_axis = _get_index_from_direction(direction)
    kept_axes = [axis for axis in range(3) if axis != dropped_axis]
    projected = coords_3d[:, kept_axes]
    if not return_direction:
        return projected
    return projected, coords_3d[:, dropped_axis]
329
330
331###############################################################################
332# class GlassBrainAxes
333###############################################################################
334
class GlassBrainAxes(BaseAxes):
    """An MPL axis-like object that displays a 2D projection of 3D
    volumes with a schematic view of the brain.

    """
    def __init__(self, ax, direction, coord, plot_abs=True, **kwargs):
        """Create the axes and draw the brain schematics on ``ax``.

        Parameters
        ----------
        ax : A MPL axes instance or None
            The axes in which the plots will be drawn. When None, no
            schematics are drawn.

        direction : {'x', 'y', 'z', 'l', 'r'}
            The direction of the projection.

        coord : float
            The coordinate along the direction of the projection.

        plot_abs : bool, optional
            If True, project the maximum of the absolute values; otherwise
            keep the signed value with the largest magnitude. Default=True.

        kwargs :
            Extra keyword arguments passed to
            ``glass_brain.plot_brain_schematics``.

        """
        super(GlassBrainAxes, self).__init__(ax, direction, coord)
        self._plot_abs = plot_abs
        if ax is not None:
            object_bounds = glass_brain.plot_brain_schematics(ax,
                                                              direction,
                                                              **kwargs)
            self.add_object_bounds(object_bounds)

    def transform_to_2d(self, data, affine):
        """ Project the 3D volume along one axis (maximum intensity).

        With ``plot_abs=True`` the maximum of the absolute values is taken;
        otherwise the signed value with the largest magnitude is kept. For
        'l'/'r' directions only the voxels on the corresponding side of the
        world origin (x = 0) are used.

        Parameters
        ----------
        data : 3D ndarray
            The 3D volume.

        affine : 4x4 ndarray
            The affine of the volume.

        Returns
        -------
        2D ndarray
            The projection, rotated by 90 degrees for display.

        """
        if self.direction in ('x', 'l', 'r'):
            max_axis = 0
        else:
            max_axis = '.yz'.index(self.direction)

        # set unselected brain hemisphere activations to 0
        if self.direction in ('l', 'r'):
            # Voxel x-index of the world origin.
            x_center, _, _, _ = np.dot(np.linalg.inv(affine),
                                       np.array([0, 0, 0, 1]))
            # Clamp at 0: a negative index would wrap around and select
            # voxels from the wrong end of the volume.
            x_center = max(int(x_center), 0)
            if self.direction == 'l':
                data_selection = data[:x_center, :, :]
            else:
                data_selection = data[x_center:, :, :]
        else:
            data_selection = data

        # We need to make sure data_selection is not empty in the x axis
        # This should be the case since we expect images in MNI space
        if data_selection.shape[0] == 0:
            data_selection = data

        if not self._plot_abs:
            # get the shape of the array we are projecting to
            new_shape = list(data_selection.shape)
            del new_shape[max_axis]

            # generate a 3D indexing array that points to max abs value in the
            # current projection
            a1, a2 = np.indices(new_shape)
            inds = [a1, a2]
            inds.insert(max_axis, np.abs(data_selection).argmax(axis=max_axis))

            # take the values where the absolute value of the projection
            # is the highest
            maximum_intensity_data = data_selection[tuple(inds)]
        else:
            maximum_intensity_data = np.abs(data_selection).max(axis=max_axis)

        # This work around can be removed bumping matplotlib > 2.1.0. See #1815
        # in nilearn for the invention of this work around

        if self.direction == 'l' and data_selection.min() is np.ma.masked and \
                not (self.ax.get_xlim()[0] > self.ax.get_xlim()[1]):
            self.ax.invert_xaxis()

        return np.rot90(maximum_intensity_data)

    def draw_position(self, size, bg_color, **kwargs):
        """No-op: projections have no single cut position to annotate."""
        # It does not make sense to draw crosses for the position of
        # the cuts since we are taking the max along one axis
        pass

    def _add_markers(self, marker_coords, marker_color, marker_size, **kwargs):
        """Plot markers

        In the case of 'l' and 'r' directions (for hemispheric projections),
        markers in the coordinate x == 0 are included in both hemispheres.

        Parameters
        ----------
        marker_coords : ndarray, shape (n, 3)
            3D coordinates of the markers.

        marker_color : color or sequence of colors
            A single matplotlib color (string) is used for every marker;
            otherwise it is filtered alongside the markers for 'l'/'r'.

        marker_size : number or sequence of numbers
            Marker sizes, filtered alongside the markers when not a scalar.

        kwargs :
            Extra keyword arguments passed to ``ax.scatter`` (defaults:
            marker='o', zorder=1000).

        """
        marker_coords_2d = _coords_3d_to_2d(marker_coords, self.direction)
        xdata, ydata = marker_coords_2d.T

        # Allow markers only in their respective hemisphere when appropriate
        if self.direction in ('l', 'r'):
            if not isinstance(marker_color, str) and \
                    not isinstance(marker_color, np.ndarray):
                marker_color = np.asarray(marker_color)
            xcoords = marker_coords[:, 0]
            if self.direction == 'r':
                relevant_coords = np.flatnonzero(xcoords >= 0)
            else:
                relevant_coords = np.flatnonzero(xcoords <= 0)
            xdata = xdata[relevant_coords]
            ydata = ydata[relevant_coords]
            # if marker_color is string for example 'red' or 'blue', then
            # we pass marker_color as it is to matplotlib scatter without
            # making any selection in 'l' or 'r' color.
            # More likely that user wants to display all nodes to be in
            # same color.
            if not isinstance(marker_color, str) and \
                    len(marker_color) != 1:
                marker_color = marker_color[relevant_coords]

            if not isinstance(marker_size, numbers.Number):
                marker_size = np.asarray(marker_size)[relevant_coords]

        defaults = {'marker': 'o',
                    'zorder': 1000}
        for k, v in defaults.items():
            kwargs.setdefault(k, v)

        self.ax.scatter(xdata, ydata, s=marker_size,
                        c=marker_color, **kwargs)

    def _add_lines(self, line_coords, line_values, cmap,
                   vmin=None, vmax=None, directed=False, **kwargs):
        """Plot lines

        In the case of 'l' and 'r' directions, a line is drawn when both of
        its end points lie in the corresponding hemisphere; end points at
        x == 0 count for both hemispheres (as for markers).

        Parameters
        ----------
        line_coords : list of numpy arrays of shape (2, 3)
            3d coordinates of lines start points and end points.

        line_values : array_like
            Values of the lines.

        cmap : colormap
            Colormap used to map line_values to a color.

        vmin, vmax : float, optional
            If not None, either or both of these values will be used to
            as the minimum and maximum values to color lines. If None are
            supplied the maximum absolute value within the given threshold
            will be used as minimum (multiplied by -1) and maximum
            coloring levels.

        directed : boolean, optional
            Add arrows instead of lines if set to True. Use this when plotting
            directed graphs for example. Default=False.

        kwargs : dict
            Additional arguments to pass to matplotlib Line2D.

        Raises
        ------
        ValueError
            If only vmax is given and is non-positive, or only vmin is given
            and is non-negative.

        """
        # Accept any array_like, as documented (lists included).
        line_values = np.asarray(line_values)
        # colormap for colorbar
        self.cmap = cmap
        if vmin is None and vmax is None:
            abs_line_values_max = np.abs(line_values).max()
            vmin = -abs_line_values_max
            vmax = abs_line_values_max
        elif vmin is None:
            if vmax > 0:
                vmin = -vmax
            else:
                raise ValueError(
                    "If vmax is set to a non-positive number "
                    "then vmin needs to be specified"
                )
        elif vmax is None:
            if vmin < 0:
                vmax = -vmin
            else:
                raise ValueError(
                    "If vmin is set to a non-negative number "
                    "then vmax needs to be specified"
                )
        norm = colors.Normalize(vmin=vmin,
                                vmax=vmax)
        # normalization useful for colorbar
        self.norm = norm
        abs_norm = colors.Normalize(vmin=0,
                                    vmax=vmax)
        value_to_color = plt.cm.ScalarMappable(norm=norm, cmap=cmap).to_rgba

        # Allow lines only in their respective hemisphere when appropriate.
        # x == 0 belongs to both sides; with a strict '< 0' for 'l', a line
        # from the midline to the left would be dropped from both views.
        if self.direction in ('l', 'r'):
            relevant_lines = []
            for lidx, line in enumerate(line_coords):
                if self.direction == 'r':
                    if line[0, 0] >= 0 and line[1, 0] >= 0:
                        relevant_lines.append(lidx)
                elif self.direction == 'l':
                    if line[0, 0] <= 0 and line[1, 0] <= 0:
                        relevant_lines.append(lidx)
            line_coords = np.array(line_coords)[relevant_lines]
            line_values = line_values[relevant_lines]

        for start_end_point_3d, line_value in zip(
                line_coords, line_values):
            start_end_point_2d = _coords_3d_to_2d(start_end_point_3d,
                                                  self.direction)

            color = value_to_color(line_value)
            abs_line_value = abs(line_value)
            linewidth = 1 + 2 * abs_norm(abs_line_value)
            # Hacky way to put the strongest connections on top of the weakest
            # note sign does not matter hence using 'abs'
            zorder = 10 + 10 * abs_norm(abs_line_value)
            this_kwargs = {'color': color, 'linewidth': linewidth,
                           'zorder': zorder}
            # kwargs should have priority over this_kwargs so that the
            # user can override the default logic
            this_kwargs.update(kwargs)
            xdata, ydata = start_end_point_2d.T
            # If directed is True, add an arrow
            if directed:
                dx = xdata[1] - xdata[0]
                dy = ydata[1] - ydata[0]
                # Hack to avoid empty arrows to crash with
                # matplotlib versions older than 3.1
                # This can be removed once support for
                # matplotlib pre 3.1 has been dropped.
                if dx == 0 and dy == 0:
                    arrow = FancyArrow(xdata[0], ydata[0],
                                       dx, dy)
                else:
                    arrow = FancyArrow(xdata[0], ydata[0],
                                       dx, dy,
                                       length_includes_head=True,
                                       width=linewidth,
                                       head_width=3*linewidth,
                                       **this_kwargs)
                self.ax.add_patch(arrow)
            # Otherwise a line
            else:
                line = lines.Line2D(xdata, ydata, **this_kwargs)
                self.ax.add_line(line)
573
574
575###############################################################################
576# class BaseSlicer
577###############################################################################
578
579class BaseSlicer(object):
580    """ The main purpose of these class is to have auto adjust of axes size
581        to the data with different layout of cuts.
582
583    """
584    # This actually encodes the figsize for only one axe
585    _default_figsize = [2.2, 2.6]
586    _axes_class = CutAxes
587
588    def __init__(self, cut_coords, axes=None, black_bg=False,
589                 brain_color=(0.5, 0.5, 0.5), **kwargs):
590        """ Create 3 linked axes for plotting orthogonal cuts.
591
592        Parameters
593        ----------
594        cut_coords : 3 tuple of ints
595            The cut position, in world space.
596
597        axes : matplotlib axes object, optional
598            The axes that will be subdivided in 3.
599
600        black_bg : boolean, optional
601            If True, the background of the figure will be put to
602            black. If you wish to save figures with a black background,
603            you will need to pass "facecolor='k', edgecolor='k'"
604            to matplotlib.pyplot.savefig. Default=False.
605
606        brain_color : tuple, optional
607            The brain color to use as the background color (e.g., for
608            transparent colorbars).
609            Default=(0.5, 0.5, 0.5)
610
611        """
612        self.cut_coords = cut_coords
613        if axes is None:
614            axes = plt.axes((0., 0., 1., 1.))
615            axes.axis('off')
616        self.frame_axes = axes
617        axes.set_zorder(1)
618        bb = axes.get_position()
619        self.rect = (bb.x0, bb.y0, bb.x1, bb.y1)
620        self._black_bg = black_bg
621        self._brain_color = brain_color
622        self._colorbar = False
623        self._colorbar_width = 0.05 * bb.width
624        self._colorbar_margin = dict(left=0.25 * bb.width,
625                                     right=0.02 * bb.width,
626                                     top=0.05 * bb.height,
627                                     bottom=0.05 * bb.height)
628        self._init_axes(**kwargs)
629
630    @staticmethod
631    def find_cut_coords(img=None, threshold=None, cut_coords=None):
632        # Implement this as a staticmethod or a classmethod when
633        # subclassing
634        raise NotImplementedError
635
636    @classmethod
637    def init_with_figure(cls, img, threshold=None,
638                         cut_coords=None, figure=None, axes=None,
639                         black_bg=False, leave_space=False, colorbar=False,
640                         brain_color=(0.5, 0.5, 0.5), **kwargs):
641        "Initialize the slicer with an image"
642        # deal with "fake" 4D images
643        if img is not None and img is not False:
644            img = _utils.check_niimg_3d(img)
645
646        cut_coords = cls.find_cut_coords(img, threshold, cut_coords)
647
648        if isinstance(axes, plt.Axes) and figure is None:
649            figure = axes.figure
650
651        if not isinstance(figure, plt.Figure):
652            # Make sure that we have a figure
653            figsize = cls._default_figsize[:]
654
655            # Adjust for the number of axes
656            figsize[0] *= len(cut_coords)
657
658            # Make space for the colorbar
659            if colorbar:
660                figsize[0] += .7
661
662            facecolor = 'k' if black_bg else 'w'
663
664            if leave_space:
665                figsize[0] += 3.4
666            figure = plt.figure(figure, figsize=figsize,
667                                facecolor=facecolor)
668        if isinstance(axes, plt.Axes):
669            assert axes.figure is figure, ("The axes passed are not "
670                                           "in the figure")
671
672        if axes is None:
673            axes = [0., 0., 1., 1.]
674            if leave_space:
675                axes = [0.3, 0, .7, 1.]
676        if isinstance(axes, collections.abc.Sequence):
677            axes = figure.add_axes(axes)
678        # People forget to turn their axis off, or to set the zorder, and
679        # then they cannot see their slicer
680        axes.axis('off')
681        return cls(cut_coords, axes, black_bg, brain_color, **kwargs)
682
683    def title(self, text, x=0.01, y=0.99, size=15, color=None, bgcolor=None,
684              alpha=1, **kwargs):
685        """ Write a title to the view.
686
687        Parameters
688        ----------
689        text : string
690            The text of the title.
691
692        x : float, optional
693            The horizontal position of the title on the frame in
694            fraction of the frame width. Default=0.01.
695
696        y : float, optional
697            The vertical position of the title on the frame in
698            fraction of the frame height. Default=0.99.
699
700        size : integer, optional
701            The size of the title text. Default=15.
702
703        color : matplotlib color specifier, optional
704            The color of the font of the title.
705
706        bgcolor : matplotlib color specifier, optional
707            The color of the background of the title.
708
709        alpha : float, optional
710            The alpha value for the background. Default=1.
711
712        kwargs :
713            Extra keyword arguments are passed to matplotlib's text
714            function.
715
716        """
717        if color is None:
718            color = 'k' if self._black_bg else 'w'
719        if bgcolor is None:
720            bgcolor = 'w' if self._black_bg else 'k'
721        if hasattr(self, '_cut_displayed'):
722            # Adapt to the case of mosaic plotting
723            if isinstance(self.cut_coords, dict):
724                first_axe = self._cut_displayed[-1]
725                first_axe = (first_axe, self.cut_coords[first_axe][0])
726            else:
727                first_axe = self._cut_displayed[0]
728        else:
729            first_axe = self.cut_coords[0]
730        ax = self.axes[first_axe].ax
731        ax.text(x, y, text,
732                transform=self.frame_axes.transAxes,
733                horizontalalignment='left',
734                verticalalignment='top',
735                size=size, color=color,
736                bbox=dict(boxstyle="square,pad=.3",
737                          ec=bgcolor, fc=bgcolor, alpha=alpha),
738                zorder=1000,
739                **kwargs)
740        ax.set_zorder(1000)
741
742    def add_overlay(self, img, threshold=1e-6, colorbar=False, **kwargs):
743        """ Plot a 3D map in all the views.
744
745        Parameters
746        -----------
747        img : Niimg-like object
748            See http://nilearn.github.io/manipulating_images/input_output.html
749            If it is a masked array, only the non-masked part will be plotted.
750
751        threshold : Int or Float or None, optional
752            If None is given, the maps are not thresholded.
753            If a number is given, it is used to threshold the maps:
754            values below the threshold (in absolute value) are
755            plotted as transparent. Default=1e-6.
756
757        colorbar : boolean, optional
758            If True, display a colorbar on the right of the plots.
759            Default=False.
760
761        kwargs :
762            Extra keyword arguments are passed to imshow.
763
764        """
765        if colorbar and self._colorbar:
766            raise ValueError("This figure already has an overlay with a "
767                             "colorbar.")
768        else:
769            self._colorbar = colorbar
770
771        img = _utils.check_niimg_3d(img)
772
773        # Make sure that add_overlay shows consistent default behavior
774        # with plot_stat_map
775        kwargs.setdefault('interpolation', 'nearest')
776        ims = self._map_show(img, type='imshow', threshold=threshold, **kwargs)
777
778        # `ims` can be empty in some corner cases, look at test_img_plotting.test_outlier_cut_coords.
779        if colorbar and ims:
780            self._show_colorbar(ims[0].cmap, ims[0].norm, threshold)
781
782        plt.draw_if_interactive()
783
784    def add_contours(self, img, threshold=1e-6, filled=False, **kwargs):
785        """ Contour a 3D map in all the views.
786
787        Parameters
788        -----------
789        img : Niimg-like object
790            See http://nilearn.github.io/manipulating_images/input_output.html
791            Provides image to plot.
792
793        threshold : Int or Float or None, optional
794            If None is given, the maps are not thresholded.
795            If a number is given, it is used to threshold the maps,
796            values below the threshold (in absolute value) are plotted
797            as transparent. Default=1e-6.
798
799        filled : boolean, optional
800            If filled=True, contours are displayed with color fillings.
801            Default=False.
802
803        kwargs :
804            Extra keyword arguments are passed to contour, see the
805            documentation of pylab.contour and see pylab.contourf documentation
806            for arguments related to contours with fillings.
807            Useful, arguments are typical "levels", which is a
808            list of values to use for plotting a contour or contour
809            fillings (if filled=True), and
810            "colors", which is one color or a list of colors for
811            these contours.
812
813        Notes
814        -----
815        If colors are not specified, default coloring choices
816        (from matplotlib) for contours and contour_fillings can be
817        different.
818
819        """
820        if not filled:
821            threshold = None
822        self._map_show(img, type='contour', threshold=threshold, **kwargs)
823        if filled:
824            if 'levels' in kwargs:
825                levels = kwargs['levels']
826                if len(levels) <= 1:
827                    # contour fillings levels should be given as (lower, upper).
828                    levels.append(np.inf)
829
830            self._map_show(img, type='contourf', threshold=threshold, **kwargs)
831
832        plt.draw_if_interactive()
833
834    def _map_show(self, img, type='imshow',
835                  resampling_interpolation='continuous',
836                  threshold=None, **kwargs):
837        # In the special case where the affine of img is not diagonal,
838        # the function `reorder_img` will trigger a resampling
839        # of the provided image with a continuous interpolation
840        # since this is the default value here. In the special
841        # case where this image is binary, such as when this function
842        # is called from `add_contours`, continuous interpolation
843        # does not make sense and we turn to nearest interpolation instead.
844        if _utils.niimg._is_binary_niimg(img):
845            img = reorder_img(img, resample='nearest')
846        else:
847            img = reorder_img(img, resample=resampling_interpolation)
848        threshold = float(threshold) if threshold is not None else None
849
850        if threshold is not None:
851            data = _utils.niimg._safe_get_data(img, ensure_finite=True)
852            if threshold == 0:
853                data = np.ma.masked_equal(data, 0, copy=False)
854            else:
855                data = np.ma.masked_inside(data, -threshold, threshold,
856                                           copy=False)
857            img = new_img_like(img, data, img.affine)
858
859        affine = img.affine
860        data = _utils.niimg._safe_get_data(img, ensure_finite=True)
861        data_bounds = get_bounds(data.shape, affine)
862        (xmin, xmax), (ymin, ymax), (zmin, zmax) = data_bounds
863
864        xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = \
865            xmin, xmax, ymin, ymax, zmin, zmax
866
867        # Compute tight bounds
868        if type in ('contour', 'contourf'):
869            # Define a pseudo threshold to have a tight bounding box
870            if 'levels' in kwargs:
871                thr = 0.9 * np.min(np.abs(kwargs['levels']))
872            else:
873                thr = 1e-6
874            not_mask = np.logical_or(data > thr, data < -thr)
875            xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = \
876                get_mask_bounds(new_img_like(img, not_mask, affine))
877        elif hasattr(data, 'mask') and isinstance(data.mask, np.ndarray):
878            not_mask = np.logical_not(data.mask)
879            xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = \
880                get_mask_bounds(new_img_like(img, not_mask, affine))
881
882        data_2d_list = []
883        for display_ax in self.axes.values():
884            try:
885                data_2d = display_ax.transform_to_2d(data, affine)
886            except IndexError:
887                # We are cutting outside the indices of the data
888                data_2d = None
889
890            data_2d_list.append(data_2d)
891
892        if kwargs.get('vmin') is None:
893            kwargs['vmin'] = np.ma.min([d.min() for d in data_2d_list
894                                        if d is not None])
895        if kwargs.get('vmax') is None:
896            kwargs['vmax'] = np.ma.max([d.max() for d in data_2d_list
897                                        if d is not None])
898
899        bounding_box = (xmin_, xmax_), (ymin_, ymax_), (zmin_, zmax_)
900        ims = []
901        to_iterate_over = zip(self.axes.values(), data_2d_list)
902        for display_ax, data_2d in to_iterate_over:
903            if data_2d is not None and data_2d.min() is not np.ma.masked:
904                # If data_2d is completely masked, then there is nothing to
905                # plot. Hence, no point to do imshow(). Moreover, we see
906                # problem came up with matplotlib 2.1.0 (issue #9280) when
907                # data is completely masked or with numpy < 1.14
908                # (issue #4595). This work around can be removed when bumping
909                # matplotlib version above 2.1.0
910                im = display_ax.draw_2d(data_2d, data_bounds, bounding_box,
911                                        type=type, **kwargs)
912                ims.append(im)
913        return ims
914
915    def _show_colorbar(self, cmap, norm, threshold=None):
916        """Displays the colorbar.
917
918        Parameters
919        ----------
920        cmap : a matplotlib colormap
921            The colormap used.
922
923        norm : a matplotlib.colors.Normalize object
924            This object is typically found as the 'norm' attribute of an
925            matplotlib.image.AxesImage.
926
927        threshold : float or None, optional
928            The absolute value at which the colorbar is thresholded.
929
930        """
931        if threshold is None:
932            offset = 0
933        else:
934            offset = threshold
935        if offset > norm.vmax:
936            offset = norm.vmax
937
938        # create new  axis for the colorbar
939        figure = self.frame_axes.figure
940        _, y0, x1, y1 = self.rect
941        height = y1 - y0
942        x_adjusted_width = self._colorbar_width / len(self.axes)
943        x_adjusted_margin = self._colorbar_margin['right'] / len(self.axes)
944        lt_wid_top_ht = [x1 - (x_adjusted_width + x_adjusted_margin),
945                         y0 + self._colorbar_margin['top'],
946                         x_adjusted_width,
947                         height - (self._colorbar_margin['top'] +
948                                   self._colorbar_margin['bottom'])]
949        self._colorbar_ax = figure.add_axes(lt_wid_top_ht)
950        if LooseVersion(matplotlib.__version__) >= LooseVersion("1.6"):
951            self._colorbar_ax.set_facecolor('w')
952        else:
953            self._colorbar_ax.set_axis_bgcolor('w')
954
955        our_cmap = mpl_cm.get_cmap(cmap)
956        # edge case where the data has a single value
957        # yields a cryptic matplotlib error message
958        # when trying to plot the color bar
959        nb_ticks = 5 if norm.vmin != norm.vmax else 1
960        ticks = np.linspace(norm.vmin, norm.vmax, nb_ticks)
961        bounds = np.linspace(norm.vmin, norm.vmax, our_cmap.N)
962
963        # some colormap hacking
964        cmaplist = [our_cmap(i) for i in range(our_cmap.N)]
965        transparent_start = int(norm(-offset, clip=True) * (our_cmap.N - 1))
966        transparent_stop = int(norm(offset, clip=True) * (our_cmap.N - 1))
967        for i in range(transparent_start, transparent_stop):
968            cmaplist[i] = self._brain_color + (0.,)  # transparent
969        if norm.vmin == norm.vmax:  # len(np.unique(data)) == 1 ?
970            return
971        else:
972            our_cmap = colors.LinearSegmentedColormap.from_list(
973                'Custom cmap', cmaplist, our_cmap.N)
974
975        self._cbar = ColorbarBase(
976            self._colorbar_ax, ticks=ticks, norm=norm,
977            orientation='vertical', cmap=our_cmap, boundaries=bounds,
978            spacing='proportional', format='%.2g')
979        self._cbar.ax.set_facecolor(self._brain_color)
980
981        self._colorbar_ax.yaxis.tick_left()
982        tick_color = 'w' if self._black_bg else 'k'
983        outline_color = 'w' if self._black_bg else 'k'
984
985        for tick in self._colorbar_ax.yaxis.get_ticklabels():
986            tick.set_color(tick_color)
987        self._colorbar_ax.yaxis.set_tick_params(width=0)
988        self._cbar.outline.set_edgecolor(outline_color)
989
990    def add_edges(self, img, color='r'):
991        """ Plot the edges of a 3D map in all the views.
992
993        Parameters
994        ----------
995        img : Niimg-like object
996            See http://nilearn.github.io/manipulating_images/input_output.html
997            The 3D map to be plotted.
998            If it is a masked array, only the non-masked part will be plotted.
999
1000        color : matplotlib color: string or (r, g, b) value
1001            The color used to display the edge map.
1002            Default='r'.
1003
1004        """
1005        img = reorder_img(img, resample='continuous')
1006        data = get_data(img)
1007        affine = img.affine
1008        single_color_cmap = colors.ListedColormap([color])
1009        data_bounds = get_bounds(data.shape, img.affine)
1010
1011        # For each ax, cut the data and plot it
1012        for display_ax in self.axes.values():
1013            try:
1014                data_2d = display_ax.transform_to_2d(data, affine)
1015                edge_mask = _edge_map(data_2d)
1016            except IndexError:
1017                # We are cutting outside the indices of the data
1018                continue
1019            display_ax.draw_2d(edge_mask, data_bounds, data_bounds,
1020                               type='imshow', cmap=single_color_cmap)
1021
1022        plt.draw_if_interactive()
1023
1024    def add_markers(self, marker_coords, marker_color='r', marker_size=30,
1025                    **kwargs):
1026        """Add markers to the plot.
1027
1028        Parameters
1029        ----------
1030        marker_coords : array of size (n_markers, 3)
1031            Coordinates of the markers to plot. For each slice, only markers
1032            that are 2 millimeters away from the slice are plotted.
1033
1034        marker_color : pyplot compatible color or list of shape (n_markers,), optional
1035            List of colors for each marker that can be string or matplotlib colors.
1036            Default='r'.
1037
1038        marker_size : single float or list of shape (n_markers,), optional
1039            Size in pixel for each marker. Default=30.
1040
1041        """
1042        defaults = {'marker': 'o',
1043                    'zorder': 1000}
1044        marker_coords = np.asanyarray(marker_coords)
1045        for k, v in defaults.items():
1046            kwargs.setdefault(k, v)
1047
1048        for display_ax in self.axes.values():
1049            direction = display_ax.direction
1050            coord = display_ax.coord
1051            marker_coords_2d, third_d = _coords_3d_to_2d(
1052                marker_coords, direction, return_direction=True)
1053            xdata, ydata = marker_coords_2d.T
1054	        # Allow markers only in their respective hemisphere when appropriate
1055            marker_color_ = marker_color
1056            if direction in ('lr'):
1057                if (not isinstance(marker_color, str) and
1058	            not isinstance(marker_color, np.ndarray)):
1059                    marker_color_ = np.asarray(marker_color)
1060                xcoords, ycoords, zcoords = marker_coords.T
1061                if direction == 'r':
1062                    relevant_coords = (xcoords >= 0)
1063                elif direction == 'l':
1064                    relevant_coords = (xcoords <= 0)
1065                xdata = xdata[relevant_coords]
1066                ydata = ydata[relevant_coords]
1067                if (not isinstance(marker_color, str) and
1068                        len(marker_color) != 1):
1069                    marker_color_ = marker_color_[relevant_coords]
1070            # Check if coord has integer represents a cut in direction
1071            # to follow the heuristic. If no foreground image is given
1072            # coordinate is empty or None. This case is valid for plotting
1073            # markers on glass brain without any foreground image.
1074            if isinstance(coord, numbers.Number):
1075                # Heuristic that plots only markers that are 2mm away
1076                # from the current slice.
1077                # XXX: should we keep this heuristic?
1078                mask = np.abs(third_d - coord) <= 2.
1079                xdata = xdata[mask]
1080                ydata = ydata[mask]
1081            display_ax.ax.scatter(xdata, ydata, s=marker_size,
1082                                  c=marker_color_, **kwargs)
1083
1084    def annotate(self, left_right=True, positions=True, scalebar=False,
1085                 size=12, scale_size=5.0, scale_units='cm', scale_loc=4,
1086                 decimals=0, **kwargs):
1087        """Add annotations to the plot.
1088
1089        Parameters
1090        ----------
1091        left_right : boolean, optional
1092            If left_right is True, annotations indicating which side
1093            is left and which side is right are drawn. Default=True.
1094
1095        positions : boolean, optional
1096            If positions is True, annotations indicating the
1097            positions of the cuts are drawn. Default=True.
1098
1099        scalebar : boolean, optional
1100            If ``True``, cuts are annotated with a reference scale bar.
1101            For finer control of the scale bar, please check out
1102            the draw_scale_bar method on the axes in "axes" attribute of
1103            this object. Default=False.
1104
1105        size : integer, optional
1106            The size of the text used. Default=12.
1107
1108        scale_size : number, optional
1109            The length of the scalebar, in units of scale_units.
1110            Default=5.0.
1111
1112        scale_units : {'cm', 'mm'}, optional
1113            The units for the scalebar. Default='cm'.
1114
1115        scale_loc : integer, optional
1116            The positioning for the scalebar. Default=4.
1117            Valid location codes are:
1118
1119            - 'upper right'  : 1
1120            - 'upper left'   : 2
1121            - 'lower left'   : 3
1122            - 'lower right'  : 4
1123            - 'right'        : 5
1124            - 'center left'  : 6
1125            - 'center right' : 7
1126            - 'lower center' : 8
1127            - 'upper center' : 9
1128            - 'center'       : 10
1129
1130        decimals : integer, optional
1131            Number of decimal places on slice position annotation. If zero,
1132            the slice position is integer without decimal point.
1133            Default=0.
1134
1135        kwargs :
1136            Extra keyword arguments are passed to matplotlib's text
1137            function.
1138
1139        """
1140        kwargs = kwargs.copy()
1141        if 'color' not in kwargs:
1142            if self._black_bg:
1143                kwargs['color'] = 'w'
1144            else:
1145                kwargs['color'] = 'k'
1146
1147        bg_color = ('k' if self._black_bg else 'w')
1148
1149        if left_right:
1150            for display_axis in self.axes.values():
1151                display_axis.draw_left_right(size=size, bg_color=bg_color,
1152                                             **kwargs)
1153
1154        if positions:
1155            for display_axis in self.axes.values():
1156                display_axis.draw_position(size=size, bg_color=bg_color,
1157                                           decimals=decimals,
1158                                           **kwargs)
1159
1160        if scalebar:
1161            axes = self.axes.values()
1162            for display_axis in axes:
1163                display_axis.draw_scale_bar(bg_color=bg_color,
1164                                            fontsize=size,
1165                                            size=scale_size,
1166                                            units=scale_units,
1167                                            loc=scale_loc,
1168                                            **kwargs)
1169
1170    def close(self):
1171        """ Close the figure. This is necessary to avoid leaking memory.
1172        """
1173        plt.close(self.frame_axes.figure.number)
1174
1175    def savefig(self, filename, dpi=None):
1176        """ Save the figure to a file
1177
1178        Parameters
1179        ----------
1180        filename : string
1181            The file name to save to. Its extension determines the
1182            file type, typically '.png', '.svg' or '.pdf'.
1183
1184        dpi : None or scalar, optional
1185            The resolution in dots per inch.
1186
1187        """
1188        facecolor = edgecolor = 'k' if self._black_bg else 'w'
1189        self.frame_axes.figure.savefig(filename, dpi=dpi,
1190                                       facecolor=facecolor,
1191                                       edgecolor=edgecolor)
1192
1193
1194###############################################################################
1195# class OrthoSlicer
1196###############################################################################
1197
class OrthoSlicer(BaseSlicer):
    """ A class to create 3 linked axes for plotting orthogonal
    cuts of 3D maps.

    Attributes
    ----------
    axes : dictionary of axes
        The 3 axes used to plot each view, keyed by direction
        ('x', 'y' or 'z').

    frame_axes : axes
        The axes framing the whole set of views.

    Notes
    -----
    The extent of the different axes are adjusted to fit the data
    best in the viewing area.

    """
    _cut_displayed = 'yxz'
    _axes_class = CutAxes

    @classmethod
    def find_cut_coords(cls, img=None, threshold=None, cut_coords=None):
        """Instantiate the slicer and find cut coordinates.

        Parameters
        ----------
        img : 3D Nifti1Image, optional
            The brain map. When None (or False), every cut goes through 0.

        threshold : float, optional
            Forwarded to ``find_xyz_cut_coords`` as
            ``activation_threshold``.

        cut_coords : list of float, optional
            When given, it is returned unchanged.

        Returns
        -------
        cut_coords : list of float
            One coordinate per displayed direction, ordered as
            ``sorted(cls._cut_displayed)``.

        """
        if cut_coords is not None:
            return cut_coords
        if img is None or img is False:
            xyz = (0, 0, 0)
        else:
            xyz = find_xyz_cut_coords(img, activation_threshold=threshold)
        return [xyz['xyz'.find(direction)]
                for direction in sorted(cls._cut_displayed)]

    def _init_axes(self, **kwargs):
        """Create one axes per displayed direction, laid out side by side.

        Parameters
        ----------
        kwargs :
            Additional arguments passed to ``self._axes_class``.

        Raises
        ------
        ValueError
            If the number of cut coordinates differs from the number of
            displayed directions.

        """
        if len(self.cut_coords) != len(self._cut_displayed):
            raise ValueError('The number cut_coords passed does not'
                             ' match the display_mode')
        x0, y0, x1, y1 = self.rect
        span_x = x1 - x0
        facecolor = 'k' if self._black_bg else 'w'
        has_set_facecolor = (LooseVersion(matplotlib.__version__)
                             >= LooseVersion("1.6"))
        figure = self.frame_axes.get_figure()
        sorted_directions = sorted(self._cut_displayed)

        # Create our axes:
        self.axes = dict()
        for index, direction in enumerate(self._cut_displayed):
            ax = figure.add_axes([0.3 * index * span_x + x0, y0,
                                  .3 * span_x, y1 - y0], aspect='equal')
            if has_set_facecolor:
                ax.set_facecolor(facecolor)
            else:
                ax.set_axis_bgcolor(facecolor)
            ax.axis('off')

            coord = self.cut_coords[sorted_directions.index(direction)]
            self.axes[direction] = self._axes_class(ax, direction, coord,
                                                    **kwargs)
            ax.set_axes_locator(self._locator)

        if not self._black_bg:
            return

        for display_ax in self.axes.values():
            display_ax.ax.imshow(np.zeros((2, 2, 3)),
                                 extent=[-5000, 5000, -5000, 5000],
                                 zorder=-500, aspect='equal')

        # To have a black background in PDF, we need to create a
        # patch in black for the background
        self.frame_axes.imshow(np.zeros((2, 2, 3)),
                               extent=[-5000, 5000, -5000, 5000],
                               zorder=-500, aspect='auto')
        self.frame_axes.set_zorder(-1000)

    def _locator(self, axes, renderer):
        """ The locator function used by matplotlib to position axes.

        Every displayed axis receives a share of the horizontal space
        proportional to the width of the objects it shows, and the axes
        are placed left to right in the order of ``_cut_displayed``.

        Parameters
        ----------
        axes : matplotlib axes
            The axes being positioned.

        renderer :
            Not used by this locator.

        Returns
        -------
        bbox : matplotlib.transforms.Bbox
            The figure-coordinate box allotted to ``axes``.

        """
        x0, y0, x1, y1 = self.rect
        # A dummy axes, for the situation in which we are not plotting
        # all three (x, y, z) cuts: it takes no width at all.
        dummy_ax = self._axes_class(None, None, None)
        widths = {dummy_ax.ax: 0}

        if self._colorbar:
            # Reserve room on the right for the colorbar and its margins.
            n_axes = len(self.axes)
            x1 = x1 - (self._colorbar_width / n_axes
                       + self._colorbar_margin['left'] / n_axes
                       + self._colorbar_margin['right'] / n_axes)

        for display_ax in self.axes.values():
            bounds = display_ax.get_object_bounds()
            if not bounds:
                # This happens if the call to _map_show was not
                # successful. As it happens asynchronously (during a
                # refresh of the figure) we capture the problem and
                # ignore it: it only adds a non informative traceback
                bounds = [0, 1, 0, 1]
            xmin, xmax, _, _ = bounds
            widths[display_ax.ax] = xmax - xmin

        total_width = float(sum(widths.values()))
        available_width = x1 - x0
        widths = {ax: width / total_width * available_width
                  for ax, width in widths.items()}

        lefts = dict()
        left = x0
        for direction in self._cut_displayed:
            ax = self.axes.get(direction, dummy_ax).ax
            lefts[ax] = left
            left += widths[ax]

        return transforms.Bbox([[lefts[axes], y0],
                                [lefts[axes] + widths[axes], y1]])

    def draw_cross(self, cut_coords=None, **kwargs):
        """ Draw a crossbar on the plot to show where the cut is
        performed.

        Parameters
        ----------
        cut_coords : 3-tuple of floats, optional
            The position of the cross to draw. If none is passed, the
            ortho_slicer's cut coordinates are used.

        kwargs :
            Extra keyword arguments are passed to axhline and axvline.
            ``color`` defaults to '.8' on a black background and 'k'
            otherwise.

        """
        if cut_coords is None:
            cut_coords = self.cut_coords
        sorted_directions = sorted(self._cut_displayed)
        x, y, z = (cut_coords[sorted_directions.index(direction)]
                   if direction in self._cut_displayed else None
                   for direction in 'xyz')

        line_kwargs = dict(kwargs)
        line_kwargs.setdefault('color', '.8' if self._black_bg else 'k')

        # For each view: (direction, vertical line position, horizontal
        # line position, extra axhline keyword arguments).
        layout = (('y', x, z, {}),
                  ('x', y, z, {'xmax': .95}),
                  ('z', x, y, {}))
        for direction, vline_pos, hline_pos, hline_extra in layout:
            if direction not in self.axes:
                continue
            ax = self.axes[direction].ax
            if vline_pos is not None:
                ax.axvline(vline_pos, ymin=.05, ymax=.95, **line_kwargs)
            if hline_pos is not None:
                ax.axhline(hline_pos, **hline_extra, **line_kwargs)
1368
1369
1370###############################################################################
1371# class TiledSlicer
1372###############################################################################
1373
1374class TiledSlicer(BaseSlicer):
1375    """ A class to create 3 axes for plotting orthogonal
1376    cuts of 3D maps, organized in a 2x2 grid.
1377
1378    Attributes
1379    ----------
1380    axes : dictionary of axes
1381        The 3 axes used to plot each view.
1382
1383    frame_axes : axes
1384        The axes framing the whole set of views.
1385
1386    Notes
1387    -----
1388    The extent of the different axes are adjusted to fit the data
1389    best in the viewing area.
1390
1391    """
1392    _cut_displayed = 'yxz'
1393    _axes_class = CutAxes
1394    _default_figsize = [2.0, 6.0]
1395
1396    @classmethod
1397    def find_cut_coords(cls, img=None, threshold=None, cut_coords=None):
1398        """Instantiate the slicer and find cut coordinates.
1399
1400        Parameters
1401        ----------
1402        img : 3D Nifti1Image
1403            The brain map.
1404
1405        threshold : float, optional
1406            The lower threshold to the positive activation. If None, the
1407            activation threshold is computed using the 80% percentile of
1408            the absolute value of the map.
1409
1410        cut_coords : list of float, optional
1411            xyz world coordinates of cuts.
1412
1413        Returns
1414        -------
1415        cut_coords : list of float
1416            xyz world coordinates of cuts.
1417
1418        """
1419        if cut_coords is None:
1420            if img is None or img is False:
1421                cut_coords = (0, 0, 0)
1422            else:
1423                cut_coords = find_xyz_cut_coords(
1424                    img, activation_threshold=threshold)
1425            cut_coords = [cut_coords['xyz'.find(c)]
1426                          for c in sorted(cls._cut_displayed)]
1427
1428        return cut_coords
1429
1430    def _find_initial_axes_coord(self, index):
1431        """Find coordinates for initial axes placement for xyz cuts.
1432
1433        Parameters
1434        ----------
1435        index : int
1436            Index corresponding to current cut 'x', 'y' or 'z'.
1437
1438        Returns
1439        -------
1440        [coord1, coord2, coord3, coord4] : list of int
1441            x0, y0, x1, y1 coordinates used by matplotlib
1442            to position axes in figure.
1443
1444        """
1445        rect_x0, rect_y0, rect_x1, rect_y1 = self.rect
1446
1447        if index == 0:
1448                coord1 = rect_x1 - rect_x0
1449                coord2 = 0.5 * (rect_y1 - rect_y0) + rect_y0
1450                coord3 = 0.5 * (rect_x1 - rect_x0) + rect_x0
1451                coord4 = rect_y1 - rect_y0
1452        elif index == 1:
1453                coord1 = 0.5 * (rect_x1 - rect_x0) + rect_x0
1454                coord2 = 0.5 * (rect_y1 - rect_y0) + rect_y0
1455                coord3 = rect_x1 - rect_x0
1456                coord4 = rect_y1 - rect_y0
1457        elif index == 2:
1458                coord1 = rect_x1 - rect_x0
1459                coord2 = rect_y1 - rect_y0
1460                coord3 = 0.5 * (rect_x1 - rect_x0) + rect_x0
1461                coord4 = 0.5 * (rect_y1 - rect_y0) + rect_y0
1462        return [coord1, coord2, coord3, coord4]
1463
1464    def _init_axes(self, **kwargs):
1465        """Initializes and places axes for display of 'xyz' cuts.
1466
1467        Parameters
1468        ----------
1469        kwargs :
1470            additional arguments to pass to self._axes_class
1471
1472        """
1473        cut_coords = self.cut_coords
1474        if len(cut_coords) != len(self._cut_displayed):
1475            raise ValueError('The number cut_coords passed does not'
1476                             ' match the display_mode')
1477
1478        facecolor = 'k' if self._black_bg else 'w'
1479
1480        self.axes = dict()
1481        for index, direction in enumerate(self._cut_displayed):
1482            fh = self.frame_axes.get_figure()
1483            axes_coords = self._find_initial_axes_coord(index)
1484            ax = fh.add_axes(axes_coords, aspect='equal')
1485
1486            if LooseVersion(matplotlib.__version__) >= LooseVersion("1.6"):
1487                ax.set_facecolor(facecolor)
1488            else:
1489                ax.set_axis_bgcolor(facecolor)
1490
1491            ax.axis('off')
1492            coord = self.cut_coords[
1493                sorted(self._cut_displayed).index(direction)]
1494            display_ax = self._axes_class(ax, direction, coord, **kwargs)
1495            self.axes[direction] = display_ax
1496            ax.set_axes_locator(self._locator)
1497
1498    def _adjust_width_height(self, width_dict, height_dict,
1499                             rect_x0, rect_y0, rect_x1, rect_y1):
1500        """Adjusts absolute image width and height to ratios.
1501
1502        Parameters
1503        ----------
1504        width_dict : dict
1505            Width of image cuts displayed in axes.
1506
1507        height_dict : dict
1508            Height of image cuts displayed in axes.
1509
1510        rect_x0, rect_y0, rect_x1, rect_y1 : float
1511            Matplotlib figure boundaries.
1512
1513        Returns
1514        -------
1515        width_dict : dict
1516            Width ratios of image cuts for optimal positioning of axes.
1517
1518        height_dict : dict
1519            Height ratios of image cuts for optimal positioning of axes.
1520
1521        """
1522        total_height = 0
1523        total_width = 0
1524
1525        if 'y' in self.axes:
1526            ax = self.axes['y'].ax
1527            total_height = total_height + height_dict[ax]
1528            total_width = total_width + width_dict[ax]
1529
1530        if 'x' in self.axes:
1531            ax = self.axes['x'].ax
1532            total_width = total_width + width_dict[ax]
1533
1534        if 'z' in self.axes:
1535            ax = self.axes['z'].ax
1536            total_height = total_height + height_dict[ax]
1537
1538        for ax, width in width_dict.items():
1539            width_dict[ax] = width / total_width * (rect_x1 - rect_x0)
1540
1541        for ax, height in height_dict.items():
1542            height_dict[ax] = height / total_height * (rect_y1 - rect_y0)
1543
1544        return (width_dict, height_dict)
1545
1546    def _find_axes_coord(self, rel_width_dict, rel_height_dict,
1547                         rect_x0, rect_y0, rect_x1, rect_y1):
1548        """"Find coordinates for initial axes placement for xyz cuts.
1549
1550        Parameters
1551        ----------
1552        rel_width_dict : dict
1553            Width ratios of image cuts for optimal positioning of axes.
1554
1555        rel_height_dict : dict
1556            Height ratios of image cuts for optimal positioning of axes.
1557
1558        rect_x0, rect_y0, rect_x1, rect_y1 : float
1559            Matplotlib figure boundaries.
1560
1561        Returns
1562        -------
1563        coord1, coord2, coord3, coord4 : dict
1564            x0, y0, x1, y1 coordinates per axes used by matplotlib
1565            to position axes in figure.
1566
1567        """
1568        coord1 = dict()
1569        coord2 = dict()
1570        coord3 = dict()
1571        coord4 = dict()
1572
1573        if 'y' in self.axes:
1574            ax = self.axes['y'].ax
1575            coord1[ax] = rect_x0
1576            coord2[ax] = (rect_y1) - rel_height_dict[ax]
1577            coord3[ax] = rect_x0 + rel_width_dict[ax]
1578            coord4[ax] = rect_y1
1579
1580        if 'x' in self.axes:
1581            ax = self.axes['x'].ax
1582            coord1[ax] = (rect_x1) - rel_width_dict[ax]
1583            coord2[ax] = (rect_y1) - rel_height_dict[ax]
1584            coord3[ax] = rect_x1
1585            coord4[ax] = rect_y1
1586
1587        if 'z' in self.axes:
1588            ax = self.axes['z'].ax
1589            coord1[ax] = rect_x0
1590            coord2[ax] = rect_y0
1591            coord3[ax] = rect_x0 + rel_width_dict[ax]
1592            coord4[ax] = rect_y0 + rel_height_dict[ax]
1593
1594        return(coord1, coord2, coord3, coord4)
1595
1596    def _locator(self, axes, renderer):
1597        """ The locator function used by matplotlib to position axes.
1598        Here we put the logic used to adjust the size of the axes.
1599
1600        """
1601        rect_x0, rect_y0, rect_x1, rect_y1 = self.rect
1602
1603        # image width and height
1604        width_dict = dict()
1605        height_dict = dict()
1606
1607        # A dummy axes, for the situation in which we are not plotting
1608        # all three (x, y, z) cuts
1609        dummy_ax = self._axes_class(None, None, None)
1610        width_dict[dummy_ax.ax] = 0
1611        height_dict[dummy_ax.ax] = 0
1612        display_ax_dict = self.axes
1613
1614        if self._colorbar:
1615            adjusted_width = self._colorbar_width / len(self.axes)
1616            right_margin = self._colorbar_margin['right'] / len(self.axes)
1617            ticks_margin = self._colorbar_margin['left'] / len(self.axes)
1618            rect_x1 = rect_x1 - (adjusted_width + ticks_margin + right_margin)
1619
1620        for display_ax in display_ax_dict.values():
1621            bounds = display_ax.get_object_bounds()
1622            if not bounds:
1623                # This happens if the call to _map_show was not
1624                # successful. As it happens asynchronously (during a
1625                # refresh of the figure) we capture the problem and
1626                # ignore it: it only adds a non informative traceback
1627                bounds = [0, 1, 0, 1]
1628            xmin, xmax, ymin, ymax = bounds
1629            width_dict[display_ax.ax] = (xmax - xmin)
1630            height_dict[display_ax.ax] = (ymax - ymin)
1631
1632        # relative image height and width
1633        rel_width_dict, rel_height_dict = self._adjust_width_height(
1634                width_dict, height_dict,
1635                rect_x0, rect_y0, rect_x1, rect_y1)
1636
1637        direction_ax = []
1638        for d in self._cut_displayed:
1639            direction_ax.append(display_ax_dict.get(d, dummy_ax).ax)
1640
1641        coord1, coord2, coord3, coord4 = self._find_axes_coord(
1642                rel_width_dict, rel_height_dict,
1643                rect_x0, rect_y0, rect_x1, rect_y1)
1644
1645        return transforms.Bbox([[coord1[axes], coord2[axes]],
1646                               [coord3[axes], coord4[axes]]])
1647
1648    def draw_cross(self, cut_coords=None, **kwargs):
1649        """Draw a crossbar on the plot to show where the cut is performed.
1650
1651        Parameters
1652        ----------
1653        cut_coords : 3-tuple of floats, optional
1654            The position of the cross to draw. If none is passed, the
1655            ortho_slicer's cut coordinates are used.
1656
1657        kwargs :
1658            Extra keyword arguments are passed to axhline
1659
1660        """
1661        if cut_coords is None:
1662            cut_coords = self.cut_coords
1663        coords = dict()
1664        for direction in 'xyz':
1665            coord_ = None
1666            if direction in self._cut_displayed:
1667                sorted_cuts = sorted(self._cut_displayed)
1668                index = sorted_cuts.index(direction)
1669                coord_ = cut_coords[index]
1670            coords[direction] = coord_
1671        x, y, z = coords['x'], coords['y'], coords['z']
1672
1673        kwargs = kwargs.copy()
1674        if 'color' not in kwargs:
1675            try:
1676                kwargs['color'] = '.8' if self._black_bg else 'k'
1677            except KeyError:
1678                pass
1679
1680        if 'y' in self.axes:
1681            ax = self.axes['y'].ax
1682            if x is not None:
1683                ax.axvline(x, **kwargs)
1684            if z is not None:
1685                ax.axhline(z, **kwargs)
1686
1687        if 'x' in self.axes:
1688            ax = self.axes['x'].ax
1689            if y is not None:
1690                ax.axvline(y, **kwargs)
1691            if z is not None:
1692                ax.axhline(z, **kwargs)
1693
1694        if 'z' in self.axes:
1695            ax = self.axes['z'].ax
1696            if x is not None:
1697                ax.axvline(x, **kwargs)
1698            if y is not None:
1699                ax.axhline(y, **kwargs)
1700
1701###############################################################################
1702# class BaseStackedSlicer
1703###############################################################################
1704
class BaseStackedSlicer(BaseSlicer):
    """ A class to create linked axes for plotting stacked
    cuts of 2D maps.

    Attributes
    ----------
    axes : dictionary of axes
        The axes used to plot each view, keyed by cut coordinate
        (converted to float).

    frame_axes : axes
        The axes framing the whole set of views.

    Notes
    -----
    The extent of the different axes is adjusted to fit the data
    best in the viewing area.

    """
    @classmethod
    def find_cut_coords(cls, img=None, threshold=None, cut_coords=None):
        """Instantiate the slicer and find cut coordinates.

        Parameters
        ----------
        img : image, None or False, optional
            Image in which to look for cuts. If None or False, cuts are
            spread evenly over fixed default bounds for ``cls._direction``.

        threshold : optional
            Unused; accepted for interface compatibility.

        cut_coords : int or sequence of floats, optional
            Number of cuts to compute (7 if None), or explicit cut
            coordinates, which are returned unchanged.

        Returns
        -------
        cut_coords : sequence of floats
            The cut coordinates along ``cls._direction``.

        """
        if cut_coords is None:
            cut_coords = 7

        if img is None or img is False:
            # Only a number of cuts needs computing; explicit coordinates
            # are kept as given (np.linspace would reject a sequence).
            if isinstance(cut_coords, numbers.Number):
                bounds = ((-40, 40), (-30, 30), (-30, 75))
                lower, upper = bounds['xyz'.index(cls._direction)]
                cut_coords = np.linspace(lower, upper, cut_coords).tolist()
        elif (not isinstance(cut_coords, collections.abc.Sequence) and
                isinstance(cut_coords, numbers.Number)):
            cut_coords = find_cut_slices(img,
                                         direction=cls._direction,
                                         n_cuts=cut_coords)

        return cut_coords

    def _init_axes(self, **kwargs):
        """Create one axes per cut coordinate, side by side in self.rect.

        Parameters
        ----------
        kwargs :
            Extra keyword arguments passed to ``self._axes_class``.

        """
        x0, y0, x1, y1 = self.rect
        fh = self.frame_axes.get_figure()
        # Create our axes: equal-width columns as an initial layout; the
        # final extents are computed by self._locator at draw time.
        self.axes = dict()
        fraction = 1. / len(self.cut_coords)
        for index, coord in enumerate(self.cut_coords):
            coord = float(coord)
            ax = fh.add_axes([fraction * index * (x1 - x0) + x0, y0,
                              fraction * (x1 - x0), y1 - y0])
            ax.axis('off')
            display_ax = self._axes_class(ax, self._direction,
                                          coord, **kwargs)
            self.axes[coord] = display_ax
            ax.set_axes_locator(self._locator)

        if self._black_bg:
            for ax in self.axes.values():
                ax.ax.imshow(np.zeros((2, 2, 3)),
                             extent=[-5000, 5000, -5000, 5000],
                             zorder=-500, aspect='equal')

            # To have a black background in PDF, we need to create a
            # patch in black for the background
            self.frame_axes.imshow(np.zeros((2, 2, 3)),
                                   extent=[-5000, 5000, -5000, 5000],
                                   zorder=-500, aspect='auto')
            self.frame_axes.set_zorder(-1000)

    def _locator(self, axes, renderer):
        """ The locator function used by matplotlib to position axes.

        The available width is shared between the axes in proportion to
        the width of the objects each one displays; every axes spans the
        full height of ``self.rect``.

        Parameters
        ----------
        axes : matplotlib axes
            The axes to position; one of the axes held in ``self.axes``.

        renderer : matplotlib renderer
            Unused; part of the axes-locator call signature.

        Returns
        -------
        matplotlib.transforms.Bbox
            The bounding box assigned to ``axes``.

        """
        x0, y0, x1, y1 = self.rect

        if self._colorbar:
            # Reserve room on the right for the colorbar and its margins.
            adjusted_width = self._colorbar_width / len(self.axes)
            right_margin = self._colorbar_margin['right'] / len(self.axes)
            ticks_margin = self._colorbar_margin['left'] / len(self.axes)
            x1 = x1 - (adjusted_width + right_margin + ticks_margin)

        width_dict = dict()
        for display_ax in self.axes.values():
            bounds = display_ax.get_object_bounds()
            if not bounds:
                # This happens if the call to _map_show was not
                # successful. As it happens asynchronously (during a
                # refresh of the figure) we capture the problem and
                # ignore it: it only adds a non informative traceback
                bounds = [0, 1, 0, 1]
            xmin, xmax, ymin, ymax = bounds
            width_dict[display_ax.ax] = (xmax - xmin)
        total_width = float(sum(width_dict.values()))
        for ax, width in width_dict.items():
            width_dict[ax] = width / total_width * (x1 - x0)

        # Lay the axes out left to right in the order of self.axes.
        left_dict = dict()
        left = float(x0)
        for display_ax in self.axes.values():
            left_dict[display_ax.ax] = left
            left += width_dict[display_ax.ax]
        return transforms.Bbox([[left_dict[axes], y0],
                                [left_dict[axes] + width_dict[axes], y1]])

    def draw_cross(self, cut_coords=None, **kwargs):
        """ Crosses are not drawn for stacked cuts: this does nothing.

        Parameters
        ----------
        cut_coords : 3-tuple of floats, optional
            Ignored.

        kwargs :
            Ignored.

        """
        return
1823
1824
class XSlicer(BaseStackedSlicer):
    """Stacked slicer showing a row of cuts along the 'x' direction."""
    _direction = 'x'
    _default_figsize = [2.6, 2.3]


class YSlicer(BaseStackedSlicer):
    """Stacked slicer showing a row of cuts along the 'y' direction."""
    _direction = 'y'
    _default_figsize = [2.2, 2.3]


class ZSlicer(BaseStackedSlicer):
    """Stacked slicer showing a row of cuts along the 'z' direction."""
    _direction = 'z'
    _default_figsize = [2.2, 2.3]


class XZSlicer(OrthoSlicer):
    """Orthogonal slicer restricted to the 'x' and 'z' cuts."""
    _cut_displayed = 'xz'


class YXSlicer(OrthoSlicer):
    """Orthogonal slicer restricted to the 'y' and 'x' cuts."""
    _cut_displayed = 'yx'


class YZSlicer(OrthoSlicer):
    """Orthogonal slicer restricted to the 'y' and 'z' cuts."""
    _cut_displayed = 'yz'
1850
1851
class MosaicSlicer(BaseSlicer):
    """ A class to create axes for plotting cuts of 3D maps,
    in multiple rows and columns.

    One row is created per direction in ``_cut_displayed`` and, within
    a row, one column per cut coordinate along that direction.

    Attributes
    ----------
    axes : dictionary of axes
        The axes used to plot multiple views, keyed by
        ``(direction, coord)``.

    frame_axes : axes
        The axes framing the whole set of views.

    """
    _cut_displayed = 'yxz'
    _axes_class = CutAxes
    _default_figsize = [11.1, 7.2]

    @classmethod
    def find_cut_coords(cls, img=None, threshold=None, cut_coords=None):
        """Instantiate the slicer and find cut coordinates for mosaic plotting.

        Parameters
        ----------
        img : 3D Nifti1Image, optional
            The brain image. If None or False, cuts are spread evenly over
            fixed default bounds.

        threshold : float, optional
            Unused; accepted for interface compatibility with the other
            slicers.

        cut_coords : int or sequence of 3 ints, optional
            Number of cuts per direction: either a single integer used
            for every direction, or one integer per direction in
            (x, y, z) order. If None, 7 cuts are computed along each
            direction.

        Returns
        -------
        cut_coords : dict
            xyz world coordinates of cuts in a direction. Each key
            denotes the direction.

        Raises
        ------
        ValueError
            If ``cut_coords`` is a sequence whose length differs from the
            number of displayed directions.

        """
        if cut_coords is None:
            cut_coords = 7

        if (not isinstance(cut_coords, collections.abc.Sequence) and
                isinstance(cut_coords, numbers.Number)):
            # Same number of cuts along every direction.
            n_cuts = [cut_coords] * 3
        else:
            if len(cut_coords) != len(cls._cut_displayed):
                raise ValueError('The number cut_coords passed does not'
                                 ' match the display_mode. Mosaic plotting '
                                 'expects tuple of length 3.')
            n_cuts = [cut_coords['xyz'.find(c)]
                      for c in sorted(cls._cut_displayed)]
        return cls._find_cut_coords(img, n_cuts, cls._cut_displayed)

    @staticmethod
    def _find_cut_coords(img, cut_coords, cut_displayed):
        """ Find slicing positions along each displayed direction.

            Helper function to find_cut_coords.

        Parameters
        ----------
        img : 3D Nifti1Image, None or False
            The brain image. If None or False, positions are spread
            evenly over fixed default bounds.

        cut_coords : sequence of int
            Number of cuts per direction, aligned with
            ``sorted(cut_displayed)``.

        cut_displayed : str
            Sectional directions 'yxz'

        Returns
        -------
        cut_coords : dict
            Mapping from each direction to its computed cut coordinates.
        """
        coords = dict()
        if img is None or img is False:
            bounds = ((-40, 40), (-30, 30), (-30, 75))
            for direction, n_cuts in zip(sorted(cut_displayed),
                                         cut_coords):
                lower, upper = bounds['xyz'.index(direction)]
                coords[direction] = np.linspace(lower, upper,
                                                n_cuts).tolist()
        else:
            for direction, n_cuts in zip(sorted(cut_displayed),
                                         cut_coords):
                coords[direction] = find_cut_slices(img, direction=direction,
                                                    n_cuts=n_cuts)
        return coords

    def _init_axes(self, **kwargs):
        """Initializes and places axes for display of 'xyz' multiple cuts.

        Parameters
        ----------
        kwargs:
            additional arguments to pass to self._axes_class

        Raises
        ------
        ValueError
            If ``self.cut_coords`` does not hold one entry per displayed
            direction.

        """
        if not isinstance(self.cut_coords, dict):
            self.cut_coords = self.find_cut_coords(cut_coords=self.cut_coords)

        if len(self.cut_coords) != len(self._cut_displayed):
            raise ValueError('The number cut_coords passed does not'
                             ' match the mosaic mode')
        x0, y0, x1, y1 = self.rect
        fh = self.frame_axes.get_figure()

        # Create our axes. These rectangles ([left, bottom, width, height])
        # are only the initial layout: the final positions are computed by
        # self._locator at draw time.
        self.axes = dict()
        # Each direction gets a row of equal height within self.rect.
        row_height = (y1 - y0) / len(self._cut_displayed)
        for index, direction in enumerate(self._cut_displayed):
            coords = self.cut_coords[direction]
            row_bottom = y0 + index * row_height
            # Hidden axes spanning the whole row.
            row_ax = fh.add_axes([x0, row_bottom, x1 - x0, row_height])
            row_ax.axis('off')
            # Each cut of this direction gets an equal share of the row.
            cut_width = (x1 - x0) / len(coords)
            for index_c, coord in enumerate(coords):
                coord = float(coord)
                ax = fh.add_axes([x0 + index_c * cut_width, row_bottom,
                                  cut_width, row_height])
                ax.axis('off')
                display_ax = self._axes_class(ax, direction,
                                              coord, **kwargs)
                self.axes[(direction, coord)] = display_ax
                ax.set_axes_locator(self._locator)

    def _locator(self, axes, renderer):
        """ The locator function used by matplotlib to position axes.

            Rows (one per direction, bottom to top in ``_cut_displayed``
            order) share the height of ``self.rect`` equally; within a row,
            the width is shared in proportion to the width of the objects
            each axes displays.

        Parameters
        ----------
        axes : matplotlib axes
            The axes to position; one of the axes held in ``self.axes``.

        renderer : matplotlib renderer
            Unused; part of the axes-locator call signature.

        Returns
        -------
        matplotlib.transforms.Bbox
            The bounding box assigned to ``axes``.
        """
        x0, y0, x1, y1 = self.rect
        display_ax_dict = self.axes

        if self._colorbar:
            # Reserve room on the right for the colorbar and its margins.
            adjusted_width = self._colorbar_width / len(self.axes)
            right_margin = self._colorbar_margin['right'] / len(self.axes)
            ticks_margin = self._colorbar_margin['left'] / len(self.axes)
            x1 = x1 - (adjusted_width + right_margin + ticks_margin)

        # capture widths for each axes for anchoring Bbox
        width_dict = dict()
        for direction in self._cut_displayed:
            this_width = dict()
            for display_ax in display_ax_dict.values():
                if direction == display_ax.direction:
                    bounds = display_ax.get_object_bounds()
                    if not bounds:
                        # This happens if the call to _map_show was not
                        # successful. As it happens asynchronously (during a
                        # refresh of the figure) we capture the problem and
                        # ignore it: it only adds a non informative traceback
                        bounds = [0, 1, 0, 1]
                    xmin, xmax, ymin, ymax = bounds
                    this_width[display_ax.ax] = (xmax - xmin)
            total_width = float(sum(this_width.values()))
            for ax, w in this_width.items():
                width_dict[ax] = w / total_width * (x1 - x0)

        # Rows share the rect height equally; offsets start at y0 so the
        # layout is correct for any rect, not only (0, 0, 1, 1).
        row_height = (y1 - y0) / len(self._cut_displayed)
        left_dict = dict()
        bottom_dict = dict()
        top_dict = dict()
        for index, direction in enumerate(self._cut_displayed):
            left = float(x0)
            bottom = y0 + index * row_height
            top = bottom + row_height
            for display_ax in display_ax_dict.values():
                if direction == display_ax.direction:
                    left_dict[display_ax.ax] = left
                    left += width_dict[display_ax.ax]
                    bottom_dict[display_ax.ax] = bottom
                    top_dict[display_ax.ax] = top
        return transforms.Bbox([[left_dict[axes], bottom_dict[axes]],
                                [left_dict[axes] + width_dict[axes],
                                 top_dict[axes]]])

    def draw_cross(self, cut_coords=None, **kwargs):
        """ Crosses are not drawn for mosaic cuts: this does nothing.

        Parameters
        ----------
        cut_coords: 3-tuple of floats, optional
            Ignored.
        kwargs:
            Ignored.
        """
        return
2063
2064
# Registry of slicer classes, keyed by the ``display_mode`` name accepted
# by get_slicer.
SLICERS = {
    'ortho': OrthoSlicer,
    'tiled': TiledSlicer,
    'mosaic': MosaicSlicer,
    'xz': XZSlicer,
    'yz': YZSlicer,
    'yx': YXSlicer,
    'x': XSlicer,
    'y': YSlicer,
    'z': ZSlicer,
}
2074
2075
class OrthoProjector(OrthoSlicer):
    """A class to create linked axes for plotting orthogonal projections
    of 3D maps.

    """
    _axes_class = GlassBrainAxes

    @classmethod
    def find_cut_coords(cls, img=None, threshold=None, cut_coords=None):
        """Return one ``None`` placeholder per displayed direction.

        Projections do not correspond to a single cut position, so the
        arguments are ignored.
        """
        return (None, ) * len(cls._cut_displayed)

    def draw_cross(self, cut_coords=None, **kwargs):
        # It does not make sense to draw crosses for the position of
        # the cuts since we are taking the max along one axis
        pass

    def add_graph(self, adjacency_matrix, node_coords,
                  node_color='auto', node_size=50,
                  edge_cmap=cm.bwr,
                  edge_vmin=None, edge_vmax=None,
                  edge_threshold=None,
                  edge_kwargs=None, node_kwargs=None, colorbar=False,
                  ):
        """Plot a graph on each of the axes.

        The graph is undirected when ``adjacency_matrix`` is symmetric
        (and, for masked arrays, its mask is symmetric too); otherwise a
        warning is issued and a directed graph is plotted.

        Parameters
        ----------
        adjacency_matrix : numpy array of shape (n, n)
            Represents the edges strengths of the graph.
            The matrix can be symmetric which will result in
            an undirected graph, or not symmetric which will
            result in a directed graph. Sparse matrices are densified
            and NaNs are replaced via ``np.nan_to_num``.

        node_coords : numpy array_like of shape (n, 3)
            3d coordinates of the graph nodes in world space.

        node_color : color or sequence of colors, optional
            Color(s) of the nodes. Default='auto'.

        node_size : scalar or array_like, optional
            Size(s) of the nodes in points^2. Default=50.

        edge_cmap : colormap, optional
            Colormap used for representing the strength of the edges.
            Default=cm.bwr.

        edge_vmin, edge_vmax : float, optional
            If not None, either or both of these values will be used
            as the minimum and maximum values to color edges. If None are
            supplied the maximum absolute value within the given threshold
            will be used as minimum (multiplied by -1) and maximum
            coloring levels.

        edge_threshold : str or number, optional
            If it is a number, only the edges whose absolute value is not
            below edge_threshold will be shown.
            If it is a string it must finish with a percent sign,
            e.g. "25.3%", and only the edges with a abs(value) above
            the given percentile will be shown.

        edge_kwargs : dict, optional
            Will be passed as kwargs for each edge matplotlib Line2D.

        node_kwargs : dict, optional
            Will be passed as kwargs to the plt.scatter call that plots all
            the nodes in one go.

        colorbar : bool, optional
            If True, display a colorbar for the edge colors.
            Default=False.

        Raises
        ------
        ValueError
            If ``node_kwargs`` contains 's' or 'c', or if the shapes of
            ``adjacency_matrix``, ``node_coords`` and ``node_color`` are
            inconsistent.

        """
        # set defaults
        if edge_kwargs is None:
            edge_kwargs = {}
        if node_kwargs is None:
            node_kwargs = {}
        if isinstance(node_color, str) and node_color == 'auto':
            nb_nodes = len(node_coords)
            node_color = mpl_cm.Set2(np.linspace(0, 1, nb_nodes))
        node_coords = np.asarray(node_coords)

        # decompress input matrix if sparse
        if sparse.issparse(adjacency_matrix):
            adjacency_matrix = adjacency_matrix.toarray()

        # make the lines below well-behaved
        adjacency_matrix = np.nan_to_num(adjacency_matrix)

        # safety checks
        if 's' in node_kwargs:
            raise ValueError("Please use 'node_size' and not 'node_kwargs' "
                             "to specify node sizes")
        if 'c' in node_kwargs:
            raise ValueError("Please use 'node_color' and not 'node_kwargs' "
                             "to specify node colors")

        adjacency_matrix_shape = adjacency_matrix.shape
        if (len(adjacency_matrix_shape) != 2 or
                adjacency_matrix_shape[0] != adjacency_matrix_shape[1]):
            raise ValueError(
                "'adjacency_matrix' is supposed to have shape (n, n)."
                ' Its shape was {0}'.format(adjacency_matrix_shape))

        node_coords_shape = node_coords.shape
        if len(node_coords_shape) != 2 or node_coords_shape[1] != 3:
            message = (
                "Invalid shape for 'node_coords'. You passed an "
                "'adjacency_matrix' of shape {0} therefore "
                "'node_coords' should be a array with shape ({0[0]}, 3) "
                'while its shape was {1}').format(adjacency_matrix_shape,
                                                  node_coords_shape)

            raise ValueError(message)

        if isinstance(node_color, (list, np.ndarray)) and len(node_color) != 1:
            if len(node_color) != node_coords_shape[0]:
                raise ValueError(
                    "Mismatch between the number of nodes ({0}) "
                    "and the number of node colors ({1})."
                    .format(node_coords_shape[0], len(node_color)))

        if node_coords_shape[0] != adjacency_matrix_shape[0]:
            raise ValueError(
                "Shape mismatch between 'adjacency_matrix' "
                "and 'node_coords'. "
                "'adjacency_matrix' shape is {0}, 'node_coords' shape is {1}"
                .format(adjacency_matrix_shape, node_coords_shape))

        # If the adjacency matrix is not symmetric, give a warning
        symmetric = True
        if not np.allclose(adjacency_matrix, adjacency_matrix.T, rtol=1e-3):
            symmetric = False
            warnings.warn("'adjacency_matrix' is not symmetric. "
                          "A directed graph will be plotted.")

        # For a masked array, masked values are replaced with zeros
        if hasattr(adjacency_matrix, 'mask'):
            if not (adjacency_matrix.mask == adjacency_matrix.mask.T).all():
                symmetric = False
                warnings.warn("'adjacency_matrix' was masked with "
                              "a non symmetric mask. A directed "
                              "graph will be plotted.")
            adjacency_matrix = adjacency_matrix.filled(0)

        if edge_threshold is not None:
            if symmetric:
                # Keep a percentile of edges with the highest absolute
                # values, so only need to look at the covariance
                # coefficients below the diagonal
                lower_diagonal_indices = np.tril_indices_from(adjacency_matrix,
                                                              k=-1)
                lower_diagonal_values = adjacency_matrix[
                    lower_diagonal_indices]
                edge_threshold = _utils.param_validation.check_threshold(
                    edge_threshold, np.abs(lower_diagonal_values),
                    stats.scoreatpercentile, 'edge_threshold')
            else:
                edge_threshold = _utils.param_validation.check_threshold(
                    edge_threshold, np.abs(adjacency_matrix.ravel()),
                    stats.scoreatpercentile, 'edge_threshold')

            adjacency_matrix = adjacency_matrix.copy()
            threshold_mask = np.abs(adjacency_matrix) < edge_threshold
            adjacency_matrix[threshold_mask] = 0

        if symmetric:
            # Each undirected edge appears twice; keep only the lower
            # triangle (diagonal excluded) so it is drawn once.
            lower_triangular_adjacency_matrix = np.tril(adjacency_matrix, k=-1)
            non_zero_indices = lower_triangular_adjacency_matrix.nonzero()
        else:
            non_zero_indices = adjacency_matrix.nonzero()

        line_coords = [node_coords[list(index)]
                       for index in zip(*non_zero_indices)]

        adjacency_matrix_values = adjacency_matrix[non_zero_indices]
        for ax in self.axes.values():
            ax._add_markers(node_coords, node_color, node_size, **node_kwargs)
            if line_coords:
                ax._add_lines(line_coords, adjacency_matrix_values, edge_cmap,
                              vmin=edge_vmin, vmax=edge_vmax,
                              directed=(not symmetric),
                              **edge_kwargs)
            # To obtain the brain left view, we simply invert the x axis
            if (ax.direction == 'l' and
                    not (ax.ax.get_xlim()[0] > ax.ax.get_xlim()[1])):
                ax.ax.invert_xaxis()

        if colorbar:
            self._colorbar = colorbar
            # NOTE(review): uses the cmap/norm of the last axes from the
            # loop above; presumably set by _add_lines, so this likely
            # requires at least one drawn edge — confirm.
            self._show_colorbar(ax.cmap, ax.norm, threshold=edge_threshold)

        plt.draw_if_interactive()
2263
2264
class XProjector(OrthoProjector):
    """Projector showing only the 'x' projection."""
    _cut_displayed = 'x'
    _default_figsize = [2.6, 2.3]


class YProjector(OrthoProjector):
    """Projector showing only the 'y' projection."""
    _cut_displayed = 'y'
    _default_figsize = [2.2, 2.3]


class ZProjector(OrthoProjector):
    """Projector showing only the 'z' projection."""
    _cut_displayed = 'z'
    _default_figsize = [2.2, 2.3]


class XZProjector(OrthoProjector):
    """Projector showing the 'x' and 'z' projections."""
    _cut_displayed = 'xz'


class YXProjector(OrthoProjector):
    """Projector showing the 'y' and 'x' projections."""
    _cut_displayed = 'yx'


class YZProjector(OrthoProjector):
    """Projector showing the 'y' and 'z' projections."""
    _cut_displayed = 'yz'


class LYRZProjector(OrthoProjector):
    """Projector showing the 'l', 'y', 'r' and 'z' views."""
    _cut_displayed = 'lyrz'


class LZRYProjector(OrthoProjector):
    """Projector showing the 'l', 'z', 'r' and 'y' views."""
    _cut_displayed = 'lzry'


class LZRProjector(OrthoProjector):
    """Projector showing the 'l', 'z' and 'r' views."""
    _cut_displayed = 'lzr'


class LYRProjector(OrthoProjector):
    """Projector showing the 'l', 'y' and 'r' views."""
    _cut_displayed = 'lyr'


class LRProjector(OrthoProjector):
    """Projector showing the 'l' and 'r' views."""
    _cut_displayed = 'lr'


class LProjector(OrthoProjector):
    """Projector showing only the left ('l') view."""
    _cut_displayed = 'l'
    _default_figsize = [2.6, 2.3]


class RProjector(OrthoProjector):
    """Projector showing only the 'r' view."""
    _cut_displayed = 'r'
    _default_figsize = [2.6, 2.3]
2320
2321
# Registry of projector classes, keyed by the ``display_mode`` name
# accepted by get_projector.
PROJECTORS = {
    'ortho': OrthoProjector,
    'xz': XZProjector,
    'yz': YZProjector,
    'yx': YXProjector,
    'x': XProjector,
    'y': YProjector,
    'z': ZProjector,
    'lzry': LZRYProjector,
    'lyrz': LYRZProjector,
    'lyr': LYRProjector,
    'lzr': LZRProjector,
    'lr': LRProjector,
    'l': LProjector,
    'r': RProjector,
}
2336
2337
def get_create_display_fun(display_mode, class_dict):
    """Return the ``init_with_figure`` constructor registered for a mode.

    Parameters
    ----------
    display_mode : str
        Key to look up in ``class_dict``.

    class_dict : dict
        Mapping from display mode to a class exposing ``init_with_figure``
        (for instance ``SLICERS`` or ``PROJECTORS``).

    Returns
    -------
    callable
        ``class_dict[display_mode].init_with_figure``.

    Raises
    ------
    ValueError
        If ``display_mode`` is not a key of ``class_dict``; the message
        lists the valid options.

    """
    try:
        display_class = class_dict[display_mode]
    except KeyError:
        message = ('{0} is not a valid display_mode. '
                   'Valid options are {1}').format(
                       display_mode, sorted(class_dict.keys()))
        # The KeyError adds nothing to this message, so do not chain it.
        raise ValueError(message) from None
    return display_class.init_with_figure
2346
2347
def get_slicer(display_mode):
    """Internal function to retrieve a slicer.

    Returns the ``init_with_figure`` constructor of the class registered
    under ``display_mode`` in ``SLICERS``; an unknown mode raises
    ValueError.
    """
    slicer_factory = get_create_display_fun(display_mode, SLICERS)
    return slicer_factory
2351
2352
def get_projector(display_mode):
    """Internal function to retrieve a projector.

    Returns the ``init_with_figure`` constructor of the class registered
    under ``display_mode`` in ``PROJECTORS``; an unknown mode raises
    ValueError.
    """
    projector_factory = get_create_display_fun(display_mode, PROJECTORS)
    return projector_factory
2356