1#===============================================================================
2# Copyright (c) 2015, Max Zwiessele
3# All rights reserved.
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are met:
7#
8# * Redistributions of source code must retain the above copyright notice, this
9#   list of conditions and the following disclaimer.
10#
11# * Redistributions in binary form must reproduce the above copyright notice,
12#   this list of conditions and the following disclaimer in the documentation
13#   and/or other materials provided with the distribution.
14#
15# * Neither the name of GPy.plotting.matplot_dep.plot_definitions nor the names of its
16#   contributors may be used to endorse or promote products derived from
17#   this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29#===============================================================================
30import numpy as np
31from matplotlib import pyplot as plt
32from ..abstract_plotting_library import AbstractPlottingLibrary
33from .. import Tango
34from . import defaults
35from matplotlib.colors import LinearSegmentedColormap
36from .controllers import ImshowController, ImAnnotateController
37import itertools
38from .util import legend_ontop
39
class MatplotlibPlots(AbstractPlottingLibrary):
    """Matplotlib implementation of the ``AbstractPlottingLibrary`` interface.

    A "canvas" here is a matplotlib ``Axes``. The drawing methods draw onto it
    and return the artist(s) matplotlib created, so callers can tweak them.
    """

    def __init__(self):
        super(MatplotlibPlots, self).__init__()
        # Default styling is taken from the sibling ``defaults`` module.
        self._defaults = defaults.__dict__

    def figure(self, rows=1, cols=1, gridspec_kwargs=None, tight_layout=True, **kwargs):
        """Create a figure with a ``rows`` x ``cols`` GridSpec attached.

        :param rows: number of grid rows (stored as ``fig.rows``).
        :param cols: number of grid columns (stored as ``fig.cols``).
        :param gridspec_kwargs: optional extra keyword arguments for ``plt.GridSpec``.
        :param tight_layout: forwarded to ``plt.figure``.
        :param kwargs: forwarded to ``plt.figure`` (e.g. ``num``, ``figsize``).
        :returns: the new figure; its ``gridspec`` attribute is what
            :meth:`new_canvas` places subplots into.
        """
        # None instead of a shared mutable ``{}`` default.
        if gridspec_kwargs is None:
            gridspec_kwargs = {}
        fig = plt.figure(tight_layout=tight_layout, **kwargs)
        fig.rows = rows
        fig.cols = cols
        fig.gridspec = plt.GridSpec(rows, cols, **gridspec_kwargs)
        return fig

    def new_canvas(self, figure=None, row=1, col=1, projection='2d', xlabel=None, ylabel=None, zlabel=None, title=None, xlim=None, ylim=None, zlim=None, **kwargs):
        """Return ``(ax, kwargs)``: an axes to draw on and the unconsumed kwargs.

        If ``ax`` is passed in ``kwargs`` it is reused. Otherwise a subplot is
        added at the 1-based grid position (``row``, ``col``) of ``figure``'s
        ``gridspec``, or of a new figure made by :meth:`figure` (which then
        consumes ``num`` / ``figsize`` from ``kwargs`` if present).
        Limits, axis labels and the axes title are set when given;
        ``zlim`` / ``zlabel`` are only applied for ``projection='3d'``.
        """
        if projection == '3d':
            # Importing Axes3D registers the '3d' projection on older matplotlib.
            from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
        elif projection == '2d':
            # matplotlib names the default 2D projection ``None``.
            projection = None

        if 'ax' in kwargs:
            ax = kwargs.pop('ax')
        else:
            if figure is not None:
                fig = figure
            else:
                # Only the figure-creation keywords are consumed; the rest is
                # handed back to the caller.
                fig_kwargs = {}
                for key in ('num', 'figsize'):
                    if key in kwargs:
                        fig_kwargs[key] = kwargs.pop(key)
                fig = self.figure(**fig_kwargs)
            ax = fig.add_subplot(fig.gridspec[row - 1, col - 1], projection=projection)

        if xlim is not None: ax.set_xlim(xlim)
        if ylim is not None: ax.set_ylim(ylim)
        if xlabel is not None: ax.set_xlabel(xlabel)
        if ylabel is not None: ax.set_ylabel(ylabel)
        if title is not None: ax.set_title(title)
        if projection == '3d':
            if zlim is not None: ax.set_zlim(zlim)
            if zlabel is not None: ax.set_zlabel(zlabel)
        return ax, kwargs

    def add_to_canvas(self, ax, plots, legend=False, title=None, **kwargs):
        """Finish ``ax`` by optionally adding a legend and a figure title.

        :param legend: ``True`` for a standard legend; an integer >= 1 for a
            legend placed by ``legend_ontop`` with that many columns; a falsy
            value (``False``, ``0``, ``None``) for no legend.
        :param title: if given, set as the figure's suptitle.
        :returns: ``plots``, unchanged.
        """
        if legend is True:
            ax.legend(*ax.get_legend_handles_labels())
        elif legend and legend >= 1:
            # The ``legend`` check first means ``None`` no longer hits ``None >= 1``.
            fontdict = dict(family='sans-serif', weight='light', size=9)
            legend_ontop(ax, ncol=legend, fontdict=fontdict)
        if title is not None:
            ax.figure.suptitle(title)
        return plots

    def show_canvas(self, ax, **kwargs):
        """Redraw the figure that owns ``ax`` and return that figure."""
        ax.figure.canvas.draw()
        return ax.figure

    def scatter(self, ax, X, Y, Z=None, color=Tango.colorsHex['mediumBlue'], label=None, marker='o', **kwargs):
        """Scatter-plot ``X`` against ``Y`` (and ``Z`` as ``zs`` when given)."""
        if Z is not None:
            return ax.scatter(X, Y, c=color, zs=Z, label=label, marker=marker, **kwargs)
        return ax.scatter(X, Y, c=color, label=label, marker=marker, **kwargs)

    def plot(self, ax, X, Y, Z=None, color=None, label=None, **kwargs):
        """Line-plot ``X`` against ``Y`` (and ``Z`` as ``zs`` when given)."""
        if Z is not None:
            return ax.plot(X, Y, color=color, zs=Z, label=label, **kwargs)
        return ax.plot(X, Y, color=color, label=label, **kwargs)

    def plot_axis_lines(self, ax, X, color=Tango.colorsHex['darkRed'], label=None, **kwargs):
        """Mark the locations in ``X`` with upward-pointing arrow markers.

        The default marker is a closed, upward-pointing pentagon path. For a
        one-column ``X`` without an explicit ``transform``, x is in data
        coordinates and y in axes coordinates, so the markers sit on the
        bottom edge of the axes. For a two-column ``X`` the points are drawn
        at the lower z-limit (this branch calls ``ax.get_zlim`` and so
        expects 3D axes).
        """
        from matplotlib import transforms
        from matplotlib.path import Path
        if 'marker' not in kwargs:
            kwargs['marker'] = Path([[-.2, 0.], [-.2, .5], [0., 1.], [.2, .5], [.2, 0.], [-.2, 0.]],
                                    [Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY])
        if 'transform' not in kwargs and X.shape[1] == 1:
            kwargs['transform'] = transforms.blended_transform_factory(ax.transData, ax.transAxes)
        if X.shape[1] == 2:
            return ax.scatter(X[:, 0], X[:, 1], ax.get_zlim()[0], c=color, label=label, **kwargs)
        return ax.scatter(X, np.zeros_like(X), c=color, label=label, **kwargs)

    def barplot(self, ax, x, height, width=0.8, bottom=0, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs):
        """Draw bars at ``x`` (centre-aligned unless ``align`` is given)."""
        kwargs.setdefault('align', 'center')
        return ax.bar(x=x, height=height, width=width,
                      bottom=bottom, label=label, color=color,
                      **kwargs)

    def xerrorbar(self, ax, X, Y, error, color=Tango.colorsHex['darkRed'], label=None, **kwargs):
        """Draw horizontal error bars of size ``error`` at (``X``, ``Y``).

        No connecting line is drawn unless ``linestyle`` / ``ls`` is given.
        """
        if not ('linestyle' in kwargs or 'ls' in kwargs):
            kwargs['ls'] = 'none'
        return ax.errorbar(X, Y, xerr=error, ecolor=color, label=label, **kwargs)

    def yerrorbar(self, ax, X, Y, error, color=Tango.colorsHex['darkRed'], label=None, **kwargs):
        """Draw vertical error bars of size ``error`` at (``X``, ``Y``).

        No connecting line is drawn unless ``linestyle`` / ``ls`` is given.
        """
        if not ('linestyle' in kwargs or 'ls' in kwargs):
            kwargs['ls'] = 'none'
        return ax.errorbar(X, Y, yerr=error, ecolor=color, label=label, **kwargs)

    def imshow(self, ax, X, extent=None, label=None, vmin=None, vmax=None, **imshow_kwargs):
        """Show ``X`` as an image; ``origin`` defaults to ``'lower'``."""
        imshow_kwargs.setdefault('origin', 'lower')
        return ax.imshow(X, label=label, extent=extent, vmin=vmin, vmax=vmax, **imshow_kwargs)

    def imshow_interact(self, ax, plot_function, extent, label=None, resolution=None, vmin=None, vmax=None, **imshow_kwargs):
        """Wrap ``plot_function`` in an ``ImshowController`` over ``extent``.

        ``origin`` defaults to ``'lower'``. ``label`` is accepted for interface
        compatibility but not used here. The controller's behaviour is defined
        in ``.controllers``.
        """
        imshow_kwargs.setdefault('origin', 'lower')
        return ImshowController(ax, plot_function, extent, resolution=resolution, vmin=vmin, vmax=vmax, **imshow_kwargs)

    def annotation_heatmap(self, ax, X, annotation, extent=None, label=None, imshow_kwargs=None, **annotation_kwargs):
        """Draw ``X`` as an image and write ``annotation`` values onto its cells.

        Text is centred in each cell unless an alignment is given in
        ``annotation_kwargs``. Without ``extent`` the cells span
        ``(0, X.shape[0], 0, X.shape[1])``.

        NOTE(review): the cell grid puts ``X.shape[0]`` along x and
        ``X.shape[1]`` along y and reads ``annotation[j, i]``. This is
        consistent for square ``X``; confirm the intended orientation for
        non-square input.

        :returns: ``(image, list_of_text_artists)``.
        """
        # Work on a copy so the caller's dict is not modified.
        imshow_kwargs = dict(imshow_kwargs or {})
        imshow_kwargs.setdefault('origin', 'lower')
        if ('ha' not in annotation_kwargs) and ('horizontalalignment' not in annotation_kwargs):
            annotation_kwargs['ha'] = 'center'
        if ('va' not in annotation_kwargs) and ('verticalalignment' not in annotation_kwargs):
            annotation_kwargs['va'] = 'center'
        imshow = self.imshow(ax, X, extent, label, **imshow_kwargs)
        if extent is None:
            extent = (0, X.shape[0], 0, X.shape[1])
        xmin, xmax, ymin, ymax = extent
        # Half a cell, so the text lands in the middle of each cell.
        xoffset, yoffset = (xmax - xmin) / (2. * X.shape[0]), (ymax - ymin) / (2. * X.shape[1])
        xlin = np.linspace(xmin, xmax, X.shape[0], endpoint=False)
        ylin = np.linspace(ymin, ymax, X.shape[1], endpoint=False)
        annotations = [
            ax.text(x + xoffset, y + yoffset, "{}".format(annotation[j, i]), **annotation_kwargs)
            for [i, x], [j, y] in itertools.product(enumerate(xlin), enumerate(ylin))
        ]
        return imshow, annotations

    def annotation_heatmap_interact(self, ax, plot_function, extent, label=None, resolution=15, imshow_kwargs=None, **annotation_kwargs):
        """Wrap ``plot_function`` in an ``ImAnnotateController`` over ``extent``.

        ``origin`` defaults to ``'lower'`` in the image kwargs. ``label`` is
        accepted for interface compatibility but not used here.
        """
        # Work on a copy so the caller's dict is not modified.
        imshow_kwargs = dict(imshow_kwargs or {})
        imshow_kwargs.setdefault('origin', 'lower')
        return ImAnnotateController(ax, plot_function, extent, resolution=resolution, imshow_kwargs=imshow_kwargs, **annotation_kwargs)

    def contour(self, ax, X, Y, C, levels=20, label=None, **kwargs):
        """Draw ``levels`` evenly spaced contour lines from ``C.min()`` to ``C.max()``."""
        return ax.contour(X, Y, C, levels=np.linspace(C.min(), C.max(), levels), label=label, **kwargs)

    def surface(self, ax, X, Y, Z, color=None, label=None, **kwargs):
        """Draw a 3D surface; ``color`` is forwarded to ``plot_surface``.

        ``color=None`` leaves the choice to matplotlib, as before.
        """
        return ax.plot_surface(X, Y, Z, color=color, label=label, **kwargs)

    def fill_between(self, ax, X, lower, upper, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs):
        """Fill the region between ``lower`` and ``upper`` over ``X``."""
        return ax.fill_between(X, lower, upper, facecolor=color, label=label, **kwargs)

    def fill_gradient(self, canvas, X, percentiles, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs):
        """Shade the bands between consecutive ``percentiles`` with graded opacity.

        Each pair of neighbouring curves in ``percentiles`` is filled with
        polygons over the (valid, ``where``-selected) stretches of ``X``. All
        polygons go into one ``PolyCollection``. Band colours are sampled from
        ``cmap``. By default this is ``color`` with alpha ``alpha * array``, so
        the default ``array`` makes the middle bands most opaque and the outer
        bands fade out.

        :param canvas: the matplotlib axes to draw on.
        :param X: x-coordinates shared by all curves.
        :param percentiles: sequence of y-curves, each compatible with ``X``.
        :param color: base colour (overridden by a ``facecolors`` kwarg).
        :param label: legend label of the collection.
        :param kwargs: ``array``, ``alpha``, ``cmap`` and ``where`` are consumed
            here. ``interpolate`` is accepted and ignored. Everything else is
            passed to ``PolyCollection`` (with ``edgecolors='none'`` and
            ``zorder=0`` as defaults).
        :returns: a one-element list holding the ``PolyCollection``.
        :raises ValueError: if ``X``, a curve and ``where`` differ in shape.
        """
        import inspect
        from functools import reduce
        from numpy import ma
        from matplotlib.collections import PolyCollection
        try:
            from matplotlib.cbook import contiguous_regions
        except ImportError:
            from matplotlib.mlab import contiguous_regions

        ax = canvas

        kwargs.setdefault('edgecolors', 'none')
        if 'facecolors' in kwargs:
            color = kwargs.pop('facecolors')

        if 'array' in kwargs:
            array = kwargs.pop('array')
        else:
            array = 1. - np.abs(np.linspace(-.97, .97, len(percentiles) - 1))

        alpha = kwargs.pop('alpha', .8)

        if 'cmap' in kwargs:
            cmap = kwargs.pop('cmap')
        else:
            cmap = LinearSegmentedColormap.from_list('WhToColor', (color, color), N=array.size)
        cmap._init()
        # matplotlib reserves the last three LUT rows for the under/over/bad
        # colours; scale the alpha channel of the regular entries only.
        cmap._lut[:-3, -1] = alpha * array

        kwargs['facecolors'] = [cmap(i) for i in np.linspace(0, 1, cmap.N)]

        where = kwargs.pop('where', None)
        # 'interpolate' is accepted for fill_between compatibility but not implemented.
        kwargs.pop('interpolate', None)

        # matplotlib < 3.4 takes ``xdata=`` / ``ydata=`` keywords; later releases
        # take a list of (axis_name, data) pairs and reject the keywords.
        legacy_units = 'xdata' in inspect.signature(ax._process_unit_info).parameters

        def pairwise(iterable):
            "s -> (s0,s1), (s1,s2), (s2, s3), ..."
            first, second = itertools.tee(iterable)
            next(second, None)
            return zip(first, second)

        polycol = []
        for y1, y2 in pairwise(percentiles):
            # Handle united data, such as dates
            if legacy_units:
                ax._process_unit_info(xdata=X, ydata=y1)
                ax._process_unit_info(ydata=y2)
            else:
                ax._process_unit_info([('x', X), ('y', y1), ('y', y2)])
            # Convert the arrays so we can work with them
            x = ma.masked_invalid(ax.convert_xunits(X))
            y1 = ma.masked_invalid(ax.convert_yunits(y1))
            y2 = ma.masked_invalid(ax.convert_yunits(y2))

            # Broadcast scalar curves to the length of x.
            if y1.ndim == 0:
                y1 = np.ones_like(x) * y1
            if y2.ndim == 0:
                y2 = np.ones_like(x) * y2

            # Fresh per-band selection: never modify the caller's ``where`` in
            # place, and do not let one band's invalid points hide the next band.
            if where is None:
                region = np.ones(len(x), dtype=bool)
            else:
                region = np.array(where, dtype=bool)

            if not (x.shape == y1.shape == y2.shape == region.shape):
                raise ValueError("Argument dimensions are incompatible")

            mask = reduce(ma.mask_or, [ma.getmask(a) for a in (x, y1, y2)])
            if mask is not ma.nomask:
                region &= ~mask

            for ind0, ind1 in contiguous_regions(region):
                xslice = x[ind0:ind1]
                y1slice = y1[ind0:ind1]
                y2slice = y2[ind0:ind1]

                if not len(xslice):
                    continue

                N = len(xslice)
                # Closed outline: start on y2, forward along y1, back along y2.
                p = np.zeros((2 * N + 2, 2), float)

                # the purpose of the next two lines is for when y2 is a
                # scalar like 0 and we want the fill to go all the way
                # down to 0 even if none of the y1 sample points do
                p[0] = xslice[0], y2slice[0]
                p[N + 1] = xslice[-1], y2slice[-1]

                p[1:N + 1, 0] = xslice
                p[1:N + 1, 1] = y1slice
                p[N + 2:, 0] = xslice[::-1]
                p[N + 2:, 1] = y2slice[::-1]

                polycol.append(p)

        kwargs.setdefault('zorder', 0)
        collection = PolyCollection(polycol, label=label, **kwargs)
        ax.add_collection(collection, autolim=True)
        ax.autoscale_view()
        return [collection]
305