#===============================================================================
# Copyright (c) 2015, Max Zwiessele
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of GPy.plotting.gpy_plot.kernel_plots nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#===============================================================================
from numbers import Number

import numpy as np
from . import plotting_library as pl
from .. import Tango
from .plot_util import update_not_existing_kwargs, helper_for_plot_data
from ...kern.src.kern import Kern, CombinationKernel


def plot_ARD(kernel, filtering=None, legend=False, canvas=None, **kwargs):
    """
    If an ARD kernel is present, plot a stacked bar representation of its
    input sensitivities using the active plotting library.

    The rows of ``kernel.input_sensitivity(summarize=False)`` are paired, by
    index, with the non-combination kernel parts found while traversing
    ``kernel``. One bar series is drawn per part whose name is in
    ``filtering``; each series is stacked on top of the previously drawn ones.

    :param kernel: the kernel whose ARD parameters are plotted
    :param filtering: list of names, which to use for plotting ARD parameters.
                      Only kernels which match names in the list of names in filtering
                      will be used for plotting. ``None`` (default) uses all parts.
    :type filtering: list of names to use for ARD plot
    :param bool legend: forwarded to the plotting library's ``add_to_canvas``
    :param canvas: canvas to draw on; a new one is created when ``None``
    :param kwargs: valid kwargs for your specific plotting library; missing
                   entries are filled in from the library's ``defaults.ard``
    :returns: whatever the plotting library's ``add_to_canvas`` returns
    """
    Tango.reset()

    ard_params = np.atleast_2d(kernel.input_sensitivity(summarize=False))
    input_dim = kernel._effective_input_dim
    x = np.arange(input_dim)

    # Collect the leaf kernels (everything that is a Kern but not a
    # CombinationKernel) in traversal order.
    parts = []
    def visit(part):
        if isinstance(part, Kern) and not isinstance(part, CombinationKernel):
            parts.append(part)
    kernel.traverse(visit)

    if filtering is None:
        filtering = [k.name for k in parts]

    kwargs = update_not_existing_kwargs(kwargs, pl().defaults.ard)

    if canvas is None:
        canvas, kwargs = pl().new_canvas(xlim=(-.5, input_dim-.5), xlabel='input dimension', ylabel='ard contribution', **kwargs)

    bars = []
    bottom = 0
    for i in range(ard_params.shape[0]):
        name = parts[i].name
        if name in filtering:
            c = Tango.nextMedium()
            bars.append(pl().barplot(canvas, x,
                                     ard_params[i, :], color=c,
                                     label=name,
                                     bottom=bottom, **kwargs))
            # Rebind instead of ``+=``: an in-place add would modify the very
            # array that was just handed to ``barplot`` as its baseline, which
            # corrupts the bar if the backend keeps a reference to it.
            bottom = bottom + ard_params[i, :]
        else:
            print("filtering out {}".format(name))

    return pl().add_to_canvas(canvas, bars, legend=legend)


def plot_covariance(kernel, x=None, label=None,
             plot_limits=None, visible_dims=None, resolution=None,
             projection='2d', levels=20, **kwargs):
    """
    Plot a kernel covariance w.r.t. another x.

    :param kernel: the kernel to evaluate and plot
    :param array-like x: the value to use for the other kernel argument (kernels are a function of two variables!).
                         A plain number is broadcast to every input dimension. When ``None``,
                         a single point of zeros is used for ``Stationary`` kernels and a
                         single point of ones for all other kernels.
    :param label: label handed to the plotting call
    :param plot_limits: the range over which to plot the kernel
    :type plot_limits: Either (xmin, xmax) for 1D or (xmin, xmax, ymin, ymax) / ((xmin, xmax), (ymin, ymax)) for 2D
    :param array-like visible_dims: input dimensions (!) to use for x. Make sure to select 2 or less dimensions to plot.
    :param resolution: the resolution of the lines used in plotting. for 2D this defines the grid for kernel evaluation.
    :param {2d|3d} projection: What projection shall we use to plot the kernel?
    :param int levels: for 2D projection, how many levels for the contour plot to use?
    :param kwargs: valid kwargs for your specific plotting library
    :returns: whatever the plotting library's ``add_to_canvas`` returns
    :raises NotImplementedError: if more than two free input dimensions remain
    :raises ValueError: if two dimensions are plotted and ``projection`` is
                        neither ``'2d'`` nor ``'3d'``
    """
    input_dim = kernel._effective_input_dim
    # Two dummy inputs at -3 and +3 in every dimension, handed to the helper
    # together with plot_limits; presumably they give it a default plotting
    # range -- confirm against helper_for_plot_data.
    X = np.ones((2, input_dim)) * [[-3], [3]]
    _, free_dims, Xgrid, xx, yy, _, _, resolution = helper_for_plot_data(kernel, X, plot_limits, visible_dims, None, resolution)

    # Validate before evaluating the kernel or opening a canvas, so a bad
    # request neither wastes work nor leaves an empty figure behind.
    n_free = len(free_dims)
    if n_free > 2:
        raise NotImplementedError("Cannot plot a kernel with more than two input dimensions")
    if n_free != 1 and projection not in ('2d', '3d'):
        # Previously this fell through and failed with an UnboundLocalError.
        raise ValueError("projection must be '2d' or '3d', got {!r}".format(projection))

    if x is None:
        # Imported locally, as in the original code (NOTE(review): presumably
        # to avoid a circular import -- confirm).
        from ...kern.src.stationary import Stationary
        x = np.ones((1, input_dim)) * (not isinstance(kernel, Stationary))
    elif isinstance(x, Number):
        x = np.ones((1, input_dim))*x
    K = kernel.K(Xgrid, x)

    x_repr = np.asanyarray(x).tolist()
    if projection == '3d':
        xlabel = 'X[:,0]'
        ylabel = 'X[:,1]'
        zlabel = "k(X, {!s})".format(x_repr)
    else:
        xlabel = 'X'
        ylabel = "k(X, {!s})".format(x_repr)
        zlabel = None

    canvas, kwargs = pl().new_canvas(projection=projection, xlabel=xlabel, ylabel=ylabel, zlabel=zlabel, **kwargs)

    # The result of update_not_existing_kwargs is used (as plot_ARD does)
    # rather than relying on it mutating ``kwargs`` in place.
    if n_free == 1:
        # 1D plotting:
        kwargs = update_not_existing_kwargs(kwargs, pl().defaults.meanplot_1d)  # @UndefinedVariable
        covariance = pl().plot(canvas, Xgrid[:, free_dims], K, label=label, **kwargs)
    elif projection == '2d':
        kwargs = update_not_existing_kwargs(kwargs, pl().defaults.meanplot_2d)  # @UndefinedVariable
        covariance = pl().contour(canvas, xx[:, 0], yy[0, :],
                                  K.reshape(resolution, resolution),
                                  levels=levels, label=label, **kwargs)
    else:
        kwargs = update_not_existing_kwargs(kwargs, pl().defaults.meanplot_3d)  # @UndefinedVariable
        covariance = pl().surface(canvas, xx, yy,
                                  K.reshape(resolution, resolution),
                                  label=label,
                                  **kwargs)

    return pl().add_to_canvas(canvas, dict(covariance=[covariance]))