#===============================================================================
# Copyright (c) 2015, Max Zwiessele
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of GPy.plotting.matplot_dep.plot_definitions nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#===============================================================================
import numpy as np
from matplotlib import pyplot as plt
from ..abstract_plotting_library import AbstractPlottingLibrary
from .. import Tango
from . import defaults
from matplotlib.colors import LinearSegmentedColormap
from .controllers import ImshowController, ImAnnotateController
import itertools
from .util import legend_ontop


def _pairwise(iterable):
    """Yield consecutive overlapping pairs: s -> (s0, s1), (s1, s2), ..."""
    first, second = itertools.tee(iterable)
    next(second, None)
    return zip(first, second)


class MatplotlibPlots(AbstractPlottingLibrary):
    """Matplotlib backend implementing the abstract plotting library API.

    Every drawing method takes the matplotlib axes to draw on as its first
    argument and returns whatever the underlying matplotlib call returns.
    """

    def __init__(self):
        super(MatplotlibPlots, self).__init__()
        # Expose the module-level names of `defaults` as the default settings.
        self._defaults = defaults.__dict__

    def figure(self, rows=1, cols=1, gridspec_kwargs=None, tight_layout=True, **kwargs):
        """Create a figure carrying a ``rows`` x ``cols`` grid specification.

        :param rows: number of subplot rows of the attached grid.
        :param cols: number of subplot columns of the attached grid.
        :param gridspec_kwargs: keyword arguments for ``plt.GridSpec``
            (``None`` means no extra arguments).
        :param tight_layout: forwarded to ``plt.figure``.
        :param kwargs: further keyword arguments for ``plt.figure``.
        :returns: the figure, with ``rows``, ``cols`` and ``gridspec``
            attributes attached.
        """
        if gridspec_kwargs is None:
            gridspec_kwargs = {}
        fig = plt.figure(tight_layout=tight_layout, **kwargs)
        fig.rows = rows
        fig.cols = cols
        fig.gridspec = plt.GridSpec(rows, cols, **gridspec_kwargs)
        return fig

    def new_canvas(self, figure=None, row=1, col=1, projection='2d', xlabel=None, ylabel=None, zlabel=None, title=None, xlim=None, ylim=None, zlim=None, **kwargs):
        """Return an axes to plot on, together with the unconsumed kwargs.

        If ``ax`` is given in ``kwargs`` it is used directly.  Otherwise a
        subplot is added at the (1-based) ``row``/``col`` cell of
        ``figure.gridspec``; when no ``figure`` is given a new one is created
        via :meth:`figure`, forwarding ``num`` and ``figsize`` from ``kwargs``
        if present.

        :param projection: ``'2d'`` or ``'3d'``; other values are passed to
            ``add_subplot`` unchanged.
        :returns: tuple ``(ax, kwargs)`` where ``kwargs`` holds the keyword
            arguments that were not consumed here.
        """
        if projection == '3d':
            # Imported for its side effect of making the 3d projection available.
            from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
        elif projection == '2d':
            projection = None

        if 'ax' in kwargs:
            ax = kwargs.pop('ax')
        else:
            if figure is not None:
                fig = figure
            else:
                # Only 'num' and 'figsize' are forwarded to the new figure.
                fig_kwargs = dict((key, kwargs.pop(key))
                                  for key in ('num', 'figsize') if key in kwargs)
                fig = self.figure(**fig_kwargs)
            ax = fig.add_subplot(fig.gridspec[row-1, col-1], projection=projection)

        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)
        if xlabel is not None:
            ax.set_xlabel(xlabel)
        if ylabel is not None:
            ax.set_ylabel(ylabel)
        if title is not None:
            ax.set_title(title)
        if projection == '3d':
            if zlim is not None:
                ax.set_zlim(zlim)
            if zlabel is not None:
                ax.set_zlabel(zlabel)
        return ax, kwargs

    def add_to_canvas(self, ax, plots, legend=False, title=None, **kwargs):
        """Finalize the axes (legend and figure title) and return ``plots``.

        :param legend: ``True`` draws a standard legend; a number >= 1 draws
            the legend on top of the axes with that many columns.
        :param title: if given, set as the figure's suptitle.
        :param kwargs: accepted for interface compatibility; unused here.
        """
        fontdict = dict(family='sans-serif', weight='light', size=9)
        if legend is True:
            ax.legend(*ax.get_legend_handles_labels())
        elif legend >= 1:
            legend_ontop(ax, ncol=legend, fontdict=fontdict)
        if title is not None:
            ax.figure.suptitle(title)
        return plots

    def show_canvas(self, ax, **kwargs):
        """Redraw the canvas of the figure owning ``ax`` and return that figure."""
        ax.figure.canvas.draw()
        return ax.figure

    def scatter(self, ax, X, Y, Z=None, color=Tango.colorsHex['mediumBlue'], label=None, marker='o', **kwargs):
        """Scatter ``X`` against ``Y``; ``Z`` is passed as ``zs`` when given."""
        if Z is not None:
            return ax.scatter(X, Y, c=color, zs=Z, label=label, marker=marker, **kwargs)
        return ax.scatter(X, Y, c=color, label=label, marker=marker, **kwargs)

    def plot(self, ax, X, Y, Z=None, color=None, label=None, **kwargs):
        """Line plot of ``X`` against ``Y``; ``Z`` is passed as ``zs`` when given."""
        if Z is not None:
            return ax.plot(X, Y, color=color, zs=Z, label=label, **kwargs)
        return ax.plot(X, Y, color=color, label=label, **kwargs)

    def plot_axis_lines(self, ax, X, color=Tango.colorsHex['darkRed'], label=None, **kwargs):
        """Mark the locations ``X`` with markers placed at the bottom of the axes.

        ``X`` is indexed as a 2-d array.  With one column the markers are put
        at y=0 using a blended transform (data coordinates in x, axes
        coordinates in y); with two columns they are put at the lower z-limit
        of ``ax``.  Defaults for ``marker`` and ``transform`` are only applied
        when the caller did not supply them.
        """
        from matplotlib import transforms
        from matplotlib.path import Path
        if 'marker' not in kwargs:
            # Upward pointing house-shaped marker.
            kwargs['marker'] = Path([[-.2,0.], [-.2,.5], [0.,1.], [.2,.5], [.2,0.], [-.2,0.]],
                                    [Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY])
        if 'transform' not in kwargs:
            if X.shape[1] == 1:
                kwargs['transform'] = transforms.blended_transform_factory(ax.transData, ax.transAxes)
        if X.shape[1] == 2:
            return ax.scatter(X[:,0], X[:,1], ax.get_zlim()[0], c=color, label=label, **kwargs)
        return ax.scatter(X, np.zeros_like(X), c=color, label=label, **kwargs)

    def barplot(self, ax, x, height, width=0.8, bottom=0, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs):
        """Bar plot; bars are center-aligned unless ``align`` is given."""
        if 'align' not in kwargs:
            kwargs['align'] = 'center'
        return ax.bar(x=x, height=height, width=width,
                      bottom=bottom, label=label, color=color,
                      **kwargs)

    def xerrorbar(self, ax, X, Y, error, color=Tango.colorsHex['darkRed'], label=None, **kwargs):
        """Horizontal error bars; no connecting line unless a linestyle is given."""
        if not('linestyle' in kwargs or 'ls' in kwargs):
            kwargs['ls'] = 'none'
        return ax.errorbar(X, Y, xerr=error, ecolor=color, label=label, **kwargs)

    def yerrorbar(self, ax, X, Y, error, color=Tango.colorsHex['darkRed'], label=None, **kwargs):
        """Vertical error bars; no connecting line unless a linestyle is given."""
        if not('linestyle' in kwargs or 'ls' in kwargs):
            kwargs['ls'] = 'none'
        return ax.errorbar(X, Y, yerr=error, ecolor=color, label=label, **kwargs)

    def imshow(self, ax, X, extent=None, label=None, vmin=None, vmax=None, **imshow_kwargs):
        """Show ``X`` as an image; ``origin`` defaults to ``'lower'``."""
        if 'origin' not in imshow_kwargs:
            imshow_kwargs['origin'] = 'lower'
        return ax.imshow(X, label=label, extent=extent, vmin=vmin, vmax=vmax, **imshow_kwargs)

    def imshow_interact(self, ax, plot_function, extent, label=None, resolution=None, vmin=None, vmax=None, **imshow_kwargs):
        """Create an :class:`ImshowController` for ``plot_function`` on ``ax``.

        ``origin`` defaults to ``'lower'``.  ``label`` is accepted for
        interface compatibility but not forwarded.
        """
        if 'origin' not in imshow_kwargs:
            imshow_kwargs['origin'] = 'lower'
        return ImshowController(ax, plot_function, extent, resolution=resolution, vmin=vmin, vmax=vmax, **imshow_kwargs)

    def annotation_heatmap(self, ax, X, annotation, extent=None, label=None, imshow_kwargs=None, **annotation_kwargs):
        """Show ``X`` as an image and write one text annotation per cell.

        :param annotation: 2-d indexable; entry ``[j, i]`` is formatted with
            ``"{}"`` and drawn at the centre of grid cell ``(i, j)``.
        :param extent: ``(xmin, xmax, ymin, ymax)``; defaults to
            ``(0, X.shape[0], 0, X.shape[1])`` for placing the annotations.
        :param imshow_kwargs: keyword arguments for :meth:`imshow`.  The
            caller's dictionary is not modified.
        :param annotation_kwargs: keyword arguments for ``ax.text``; text is
            centred horizontally and vertically unless stated otherwise.
        :returns: tuple ``(image, list_of_text_artists)``.
        """
        # Copy so the 'origin' default does not leak into the caller's dict.
        imshow_kwargs = dict(imshow_kwargs or {})
        if 'origin' not in imshow_kwargs:
            imshow_kwargs['origin'] = 'lower'
        if ('ha' not in annotation_kwargs) and ('horizontalalignment' not in annotation_kwargs):
            annotation_kwargs['ha'] = 'center'
        if ('va' not in annotation_kwargs) and ('verticalalignment' not in annotation_kwargs):
            annotation_kwargs['va'] = 'center'
        imshow = self.imshow(ax, X, extent, label, **imshow_kwargs)
        if extent is None:
            extent = (0, X.shape[0], 0, X.shape[1])
        xmin, xmax, ymin, ymax = extent
        # Half a cell in each direction: shifts from the cell corner to its centre.
        xoffset, yoffset = (xmax - xmin) / (2. * X.shape[0]), (ymax - ymin) / (2. * X.shape[1])
        xlin = np.linspace(xmin, xmax, X.shape[0], endpoint=False)
        ylin = np.linspace(ymin, ymax, X.shape[1], endpoint=False)
        annotations = [
            ax.text(x+xoffset, y+yoffset, "{}".format(annotation[j, i]), **annotation_kwargs)
            for (i, x), (j, y) in itertools.product(enumerate(xlin), enumerate(ylin))
        ]
        return imshow, annotations

    def annotation_heatmap_interact(self, ax, plot_function, extent, label=None, resolution=15, imshow_kwargs=None, **annotation_kwargs):
        """Create an :class:`ImAnnotateController` for ``plot_function`` on ``ax``.

        ``origin`` defaults to ``'lower'`` in ``imshow_kwargs``; the caller's
        dictionary is not modified.  ``label`` is accepted for interface
        compatibility but not forwarded.
        """
        imshow_kwargs = dict(imshow_kwargs or {})
        if 'origin' not in imshow_kwargs:
            imshow_kwargs['origin'] = 'lower'
        return ImAnnotateController(ax, plot_function, extent, resolution=resolution, imshow_kwargs=imshow_kwargs, **annotation_kwargs)

    def contour(self, ax, X, Y, C, levels=20, label=None, **kwargs):
        """Contour plot of ``C`` with ``levels`` evenly spaced contour levels
        between ``C.min()`` and ``C.max()``."""
        return ax.contour(X, Y, C, levels=np.linspace(C.min(), C.max(), levels), label=label, **kwargs)

    def surface(self, ax, X, Y, Z, color=None, label=None, **kwargs):
        """Surface plot of ``Z`` over ``X``/``Y``.

        ``color`` is accepted for interface compatibility but is not
        forwarded to ``plot_surface``.
        """
        return ax.plot_surface(X, Y, Z, label=label, **kwargs)

    def fill_between(self, ax, X, lower, upper, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs):
        """Fill the area between ``lower`` and ``upper`` along ``X``."""
        return ax.fill_between(X, lower, upper, facecolor=color, label=label, **kwargs)

    def fill_gradient(self, canvas, X, percentiles, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs):
        """Fill the bands between consecutive ``percentiles`` as one collection.

        One polygon is built per contiguous valid region of each pair of
        consecutive percentile curves; all polygons are added to ``canvas``
        as a single ``PolyCollection``.

        Recognised entries of ``kwargs`` (all removed before the remaining
        ones are handed to ``PolyCollection``):

        * ``facecolors``: overrides ``color``.
        * ``array``: per-band weights scaling the alpha channel; defaults to
          a profile peaking at the central band.
        * ``alpha``: maximal opacity (default ``.8``).
        * ``cmap``: colormap to take the face colours from; when absent, a
          single-colour map with per-band alpha is generated.
        * ``where``: boolean selection of the samples to fill.
        * ``interpolate``: accepted and discarded (no interpolation is done).

        ``edgecolors`` defaults to ``'none'`` and ``zorder`` to ``0``.

        :returns: list containing the created ``PolyCollection``.
        :raises ValueError: if ``X``, a percentile curve and ``where`` have
            incompatible shapes.
        """
        from functools import reduce
        from numpy import ma
        from matplotlib.collections import PolyCollection
        try:
            from matplotlib.cbook import contiguous_regions
        except ImportError:
            from matplotlib.mlab import contiguous_regions

        ax = canvas
        plots = []

        if 'edgecolors' not in kwargs:
            kwargs['edgecolors'] = 'none'

        if 'facecolors' in kwargs:
            color = kwargs.pop('facecolors')

        if 'array' in kwargs:
            array = np.asarray(kwargs.pop('array'))
        else:
            array = 1.-np.abs(np.linspace(-.97, .97, len(percentiles)-1))

        alpha = kwargs.pop('alpha', .8)

        if 'cmap' in kwargs:
            cmap = kwargs.pop('cmap')
        else:
            cmap = LinearSegmentedColormap.from_list('WhToColor', (color, color), N=array.size)
            # Force creation of the lookup table, then write the per-band
            # opacity into its alpha column (the last 3 rows are the
            # under/over/bad entries).
            cmap._init()
            cmap._lut[:-3, -1] = alpha*array

        kwargs['facecolors'] = [cmap(i) for i in np.linspace(0,1,cmap.N)]

        where = kwargs.pop('where', None)
        # 'interpolate' is accepted for fill_between compatibility but unused.
        kwargs.pop('interpolate', None)

        polycol = []
        for y1, y2 in _pairwise(percentiles):
            # Handle united data, such as dates.  The private helper changed
            # its signature between matplotlib releases; try the newer form
            # first and fall back to the older keyword form.
            try:
                ax._process_unit_info([("x", X), ("y", y1), ("y", y2)], convert=False)
            except TypeError:
                ax._process_unit_info(xdata=X, ydata=y1)
                ax._process_unit_info(ydata=y2)

            # Convert the arrays so we can work with them
            x = ma.masked_invalid(ax.convert_xunits(X))
            y1 = ma.masked_invalid(ax.convert_yunits(y1))
            y2 = ma.masked_invalid(ax.convert_yunits(y2))

            if y1.ndim == 0:
                y1 = np.ones_like(x) * y1
            if y2.ndim == 0:
                y2 = np.ones_like(x) * y2

            # Fresh selection per band: masking below must neither carry over
            # to the next band nor modify the caller's `where` array.
            if where is None:
                region = np.ones(len(x), bool)
            else:
                region = np.array(where, dtype=bool)

            if not (x.shape == y1.shape == y2.shape == region.shape):
                raise ValueError("Argument dimensions are incompatible")

            mask = reduce(ma.mask_or, [ma.getmask(a) for a in (x, y1, y2)])
            if mask is not ma.nomask:
                region &= ~mask

            for ind0, ind1 in contiguous_regions(region):
                xslice = x[ind0:ind1]
                y1slice = y1[ind0:ind1]
                y2slice = y2[ind0:ind1]

                if not len(xslice):
                    continue

                N = len(xslice)
                p = np.zeros((2 * N + 2, 2), float)

                # the purpose of the next two lines is for when y2 is a
                # scalar like 0 and we want the fill to go all the way
                # down to 0 even if none of the y1 sample points do
                start = xslice[0], y2slice[0]
                end = xslice[-1], y2slice[-1]

                p[0] = start
                p[N + 1] = end

                # Forward along y1, then back along y2 to close the polygon.
                p[1:N + 1, 0] = xslice
                p[1:N + 1, 1] = y1slice
                p[N + 2:, 0] = xslice[::-1]
                p[N + 2:, 1] = y2slice[::-1]

                polycol.append(p)

        if 'zorder' not in kwargs:
            kwargs['zorder'] = 0
        plots.append(PolyCollection(polycol, label=label, **kwargs))
        ax.add_collection(plots[-1], autolim=True)
        ax.autoscale_view()
        return plots