1#===============================================================================
2# Copyright (c) 2015, Max Zwiessele
3# All rights reserved.
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are met:
7#
8# * Redistributions of source code must retain the above copyright notice, this
9#   list of conditions and the following disclaimer.
10#
11# * Redistributions in binary form must reproduce the above copyright notice,
12#   this list of conditions and the following disclaimer in the documentation
13#   and/or other materials provided with the distribution.
14#
15# * Neither the name of GPy.plotting.gpy_plot.kernel_plots nor the names of its
16#   contributors may be used to endorse or promote products derived from
17#   this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29#===============================================================================
30import numpy as np
31from . import plotting_library as pl
32from .. import Tango
33from .plot_util import update_not_existing_kwargs, helper_for_plot_data
34from ...kern.src.kern import Kern, CombinationKernel
35
def plot_ARD(kernel, filtering=None, legend=False, canvas=None, **kwargs):
    """
    If an ARD kernel is present, plot a stacked bar representation of its
    input sensitivities using the plotting library returned by ``pl()``.

    One bar series is drawn per non-combination kernel part found while
    traversing ``kernel``; each series is stacked on top of the ones drawn
    before it.

    :param kernel: the kernel to plot; its ``input_sensitivity(summarize=False)``
                   provides one row of bar heights per kernel part.
    :param filtering: list of names, which to use for plotting ARD parameters.
                      Only kernels which match names in the list of names in filtering
                      will be used for plotting. If None, all parts are plotted.
    :type filtering: list of names to use for ARD plot
    :param bool legend: forwarded to the plotting library's ``add_to_canvas``.
    :param canvas: canvas to draw on. If None, a new canvas is created with
                   ``new_canvas`` (which also consumes/returns ``kwargs``).
    :param kwargs: valid kwargs for your specific plotting library; missing
                   entries are filled from ``pl().defaults.ard``.
    :returns: whatever ``pl().add_to_canvas(canvas, bars, legend=legend)`` returns.
    """
    Tango.reset()

    ard_params = np.atleast_2d(kernel.input_sensitivity(summarize=False))
    n_dims = kernel._effective_input_dim
    x = np.arange(n_dims)

    # Collect the leaf kernels (everything that is a Kern but not a
    # CombinationKernel), in traversal order.
    parts = []
    def visit(part):
        if isinstance(part, Kern) and not isinstance(part, CombinationKernel):
            parts.append(part)
    kernel.traverse(visit)

    if filtering is None:
        filtering = [k.name for k in parts]

    bars = []
    kwargs = update_not_existing_kwargs(kwargs, pl().defaults.ard)

    if canvas is None:
        canvas, kwargs = pl().new_canvas(xlim=(-.5, n_dims-.5), xlabel='input dimension', ylabel='ard contribution', **kwargs)

    bottom = 0
    # Index into ``parts`` (rather than zip) so that a mismatch between the
    # number of sensitivity rows and the number of parts still fails loudly.
    for i, sensitivity in enumerate(ard_params):
        name = parts[i].name
        if name in filtering:
            c = Tango.nextMedium()
            bars.append(pl().barplot(canvas, x,
                                     sensitivity, color=c,
                                     label=name,
                                     bottom=bottom, **kwargs))
            # Rebind instead of ``+=``: an in-place add would modify the very
            # array object that was just handed to ``barplot`` as ``bottom``.
            bottom = bottom + sensitivity
        else:
            print("filtering out {}".format(name))

    return pl().add_to_canvas(canvas, bars, legend=legend)
85
def plot_covariance(kernel, x=None, label=None,
             plot_limits=None, visible_dims=None, resolution=None,
             projection='2d', levels=20, **kwargs):
    """
    Plot a kernel covariance w.r.t. another x.

    :param kernel: the kernel to plot; ``kernel.K(Xgrid, x)`` is evaluated on the plotting grid.
    :param array-like x: the value to use for the other kernel argument (kernels are a function of two variables!).
                         A plain number is broadcast to all input dimensions. If None, zeros are used
                         for Stationary kernels and ones for all other kernels.
    :param label: label forwarded to the plotting call.
    :param plot_limits: the range over which to plot the kernel
    :type plot_limits: Either (xmin, xmax) for 1D or (xmin, xmax, ymin, ymax) / ((xmin, xmax), (ymin, ymax)) for 2D
    :param array-like visible_dims: input dimensions (!) to use for x. Make sure to select 2 or less dimensions to plot.
    :param resolution: the resolution of the lines used in plotting. for 2D this defines the grid for kernel evaluation.
    :param {2d|3d} projection: What projection shall we use to plot the kernel?
    :param int levels: for 2D projection, how many levels for the contour plot to use?
    :param kwargs:  valid kwargs for your specific plotting library
    :returns: whatever ``pl().add_to_canvas(canvas, plots)`` returns.
    :raises NotImplementedError: if more than two free input dimensions remain.
    :raises ValueError: if a non-1D plot is requested with a projection other than '2d' or '3d'.
    """
    from numbers import Number

    # Two-row dummy input spanning -3..3 in every dimension, handed to
    # helper_for_plot_data (presumably to derive default limits -- confirm
    # against plot_util).
    X = np.ones((2, kernel._effective_input_dim)) * [[-3], [3]]
    _, free_dims, Xgrid, xx, yy, _, _, resolution = helper_for_plot_data(kernel, X, plot_limits, visible_dims, None, resolution)

    # Validate before evaluating the kernel or creating a canvas, so that a
    # failing call does not leave an empty canvas behind.
    n_free = len(free_dims)
    if n_free > 2:
        raise NotImplementedError("Cannot plot a kernel with more than two input dimensions")
    if n_free != 1 and projection not in ('2d', '3d'):
        # Previously this fell through both branches and failed with an
        # UnboundLocalError on ``plots``.
        raise ValueError("projection must be '2d' or '3d' for a two dimensional kernel plot, got {!r}".format(projection))

    if x is None:
        from ...kern.src.stationary import Stationary
        # ones * False -> zeros for stationary kernels, ones * True -> ones otherwise
        x = np.ones((1, kernel._effective_input_dim)) * (not isinstance(kernel, Stationary))
    elif isinstance(x, Number):
        x = np.ones((1, kernel._effective_input_dim))*x
    K = kernel.K(Xgrid, x)

    k_label = "k(X, {!s})".format(np.asanyarray(x).tolist())
    if projection == '3d':
        xlabel = 'X[:,0]'
        ylabel = 'X[:,1]'
        zlabel = k_label
    else:
        xlabel = 'X'
        ylabel = k_label
        zlabel = None

    canvas, kwargs = pl().new_canvas(projection=projection, xlabel=xlabel, ylabel=ylabel, zlabel=zlabel, **kwargs)

    if n_free == 1:
        # 1D plotting:
        kwargs = update_not_existing_kwargs(kwargs, pl().defaults.meanplot_1d)  # @UndefinedVariable
        plots = dict(covariance=[pl().plot(canvas, Xgrid[:, free_dims], K, label=label, **kwargs)])
    elif projection == '2d':
        kwargs = update_not_existing_kwargs(kwargs, pl().defaults.meanplot_2d)  # @UndefinedVariable
        plots = dict(covariance=[pl().contour(canvas, xx[:, 0], yy[0, :],
                                       K.reshape(resolution, resolution),
                                       levels=levels, label=label, **kwargs)])
    else:
        # projection == '3d' (guaranteed by the validation above)
        kwargs = update_not_existing_kwargs(kwargs, pl().defaults.meanplot_3d)  # @UndefinedVariable
        plots = dict(covariance=[pl().surface(canvas, xx, yy,
                                       K.reshape(resolution, resolution),
                                       label=label,
                                       **kwargs)])
    return pl().add_to_canvas(canvas, plots)
144