1#===============================================================================
2# Copyright (c) 2012-2015, GPy authors (see AUTHORS.txt).
3# All rights reserved.
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are met:
7#
8# * Redistributions of source code must retain the above copyright notice, this
9#   list of conditions and the following disclaimer.
10#
11# * Redistributions in binary form must reproduce the above copyright notice,
12#   this list of conditions and the following disclaimer in the documentation
13#   and/or other materials provided with the distribution.
14#
15# * Neither the name of GPy nor the names of its
16#   contributors may be used to endorse or promote products derived from
17#   this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29#===============================================================================
30
31import numpy as np
32from scipy import sparse
33import itertools
34from ...models import WarpedGP
35
def in_ipynb():
    """
    Tell whether the code is running inside an IPython kernel.

    The ``get_ipython`` name is looked up; when it is not defined the lookup
    raises ``NameError`` and False is returned. Otherwise the result is
    whether ``'IPKernelApp'`` is an entry of the IPython config.

    :returns: bool
    """
    try:
        ipython_config = get_ipython().config
    except NameError:
        # get_ipython is not defined at all
        return False
    else:
        return 'IPKernelApp' in ipython_config
42
def find_best_layout_for_subplots(num_subplots):
    """
    Find a (rows, columns) grid with at least num_subplots cells.

    Starting from a 1x1 grid, the layout grows through the sequence
    (r, r) -> (r, r+1) -> (r, r+2) -> (r+1, r+1) -> ... and the first
    layout with enough cells is returned.

    :param num_subplots: number of subplots which have to fit in the grid
    :returns: (rows, columns)
    """
    rows, cols = 1, 1
    while rows * cols < num_subplots:
        surplus = cols - rows
        if surplus in (0, 1):
            # square or one column wider: add another column
            cols += 1
        elif surplus == 2:
            # two columns wider: go to the next bigger square
            rows, cols = rows + 1, cols - 1
    return rows, cols
52
def helper_predict_with_model(self, Xgrid, plot_raw, apply_link, percentiles, which_data_ycols, predict_kw, samples=0):
    """
    Make the right decisions for prediction with a model
    based on the standard arguments of plotting.

    This is quite complex and will take a while to understand,
    so do not change anything in here lightly!!!

    :param self: the model; its predict, predict_quantiles and
        posterior_samples methods are called, and likelihood.gp_link.transf
        is used when plot_raw and apply_link are both set
    :param Xgrid: the inputs to predict at. If no 'output_index' is given in
        the Y_metadata, the last column of Xgrid (cast to int) is used for it
    :param plot_raw: if set and predict_kw has no 'likelihood' entry, a
        Gaussian likelihood with a variance of 1e-9 is used for prediction
    :param apply_link: together with plot_raw, push mean, percentiles and
        samples through the link function of the likelihood of the model
    :param percentiles: quantiles handed to predict_quantiles, or None
    :param which_data_ycols: the output columns to return
    :param predict_kw: keyword arguments for the prediction functions, or
        None. Missing 'likelihood' and 'Y_metadata' entries are filled in
        (the passed dict is updated in place)
    :param samples: number of posterior samples to draw (0 for none)
    :returns: (mean, list of percentiles, samples or None), each restricted
        to which_data_ycols
    """
    # Put some standards into the predict_kw so that prediction is done automatically:
    if predict_kw is None:
        predict_kw = {}
    if 'likelihood' not in predict_kw:
        if plot_raw:
            from ...likelihoods import Gaussian
            from ...likelihoods.link_functions import Identity
            lik = Gaussian(Identity(), 1e-9) # Make the likelihood not add any noise
        else:
            lik = None
        predict_kw['likelihood'] = lik
    if 'Y_metadata' not in predict_kw:
        predict_kw['Y_metadata'] = {}
    if 'output_index' not in predict_kw['Y_metadata']:
        # builtin int instead of np.int: that alias was removed in NumPy 1.24
        predict_kw['Y_metadata']['output_index'] = Xgrid[:,-1:].astype(int)

    mu, _ = self.predict(Xgrid, **predict_kw)

    if percentiles is not None:
        percentiles = self.predict_quantiles(Xgrid, quantiles=percentiles, **predict_kw)
    else:
        percentiles = []

    if samples > 0:
        fsamples = self.posterior_samples(Xgrid, size=samples, **predict_kw)
        fsamples = fsamples[:, which_data_ycols, :]
    else:
        fsamples = None

    # Filter out the ycolums which we want to plot:
    retmu = mu[:, which_data_ycols]
    percs = [p[:, which_data_ycols] for p in percentiles]

    if plot_raw and apply_link:
        transf = self.likelihood.gp_link.transf
        for i in range(len(which_data_ycols)):
            # Column i of the filtered arrays belongs to output
            # which_data_ycols[i], so always transform the filtered column
            # (and not column i of the unfiltered mean).
            retmu[:, [i]] = transf(retmu[:, [i]])
            for perc in percs:
                perc[:, [i]] = transf(perc[:, [i]])
            if fsamples is not None:
                for s in range(fsamples.shape[-1]):
                    fsamples[:, i, s] = transf(fsamples[:, i, s])
    return retmu, percs, fsamples
102
def helper_for_plot_data(self, X, plot_limits, visible_dims, fixed_inputs, resolution):
    """
    Work out the fixed and free dimensions and build the grid of inputs
    (Xgrid) on which the prediction for plotting is done.

    Only one or two free dimensions are supported.

    :param self: the model, its input_dim gives the width of Xgrid
    :param X: the inputs, used to determine the plotting frame
    :param plot_limits: limits handed on to x_frame1D/x_frame2D
    :param visible_dims: the dimensions to show (see get_free_dims)
    :param fixed_inputs: list of (dimension, value) tuples, or None
    :param resolution: grid resolution; falls back to 200 (1D) or 35 (2D)
    :returns: (fixed_dims, free_dims, Xgrid, x, y, xmin, xmax, resolution);
        in 1D x is Xgrid and y is None, in 2D x and y are the grids
        returned by x_frame2D
    :raises TypeError: if there are neither one nor two free dimensions
    """
    if fixed_inputs is None:
        fixed_inputs = []
    fixed_dims = get_fixed_dims(fixed_inputs)
    free_dims = get_free_dims(self, visible_dims, fixed_dims)
    num_free = len(free_dims)

    if num_free not in (1, 2):
        raise TypeError("calculated free_dims {} from visible_dims {} and fixed_dims {} is neither 1D nor 2D".format(free_dims, visible_dims, fixed_dims))

    # define the frame on which to plot
    if num_free == 1:
        resolution = resolution or 200
        Xnew, xmin, xmax = x_frame1D(X[:,free_dims], plot_limits=plot_limits, resolution=resolution)
    else:
        resolution = resolution or 35
        Xnew, x, y, xmin, xmax = x_frame2D(X[:,free_dims], plot_limits, resolution)

    # free dimensions take the frame, fixed dimensions their fixed value,
    # everything else stays at zero
    Xgrid = np.zeros((Xnew.shape[0], self.input_dim))
    Xgrid[:,free_dims] = Xnew
    for dim, value in fixed_inputs:
        Xgrid[:,dim] = value

    if num_free == 1:
        x, y = Xgrid, None
    return fixed_dims, free_dims, Xgrid, x, y, xmin, xmax, resolution
139
def scatter_label_generator(labels, X, visible_dims, marker=None):
    """
    Generate the data for a scatter plot, one group per unique label.

    The unique labels are visited in order of first appearance.

    :param labels: one label per row of X (array or plain sequence)
    :param X: the data, indexed as X[rows, dimension]
    :param visible_dims: the dimensions to extract; either three dimensions,
        two dimensions or a single dimension. Dimensions given as None
        (e.g. (0, None, None)) are treated as not given.
    :param marker: iterable of markers to cycle through, or None
    :returns: generator of (x, y, z, label, index, marker) where index are
        the rows of X which belong to the label. Without a second dimension
        y is all zeros and without a third z is None. Numeric labels are
        turned into 'class <label>'.
    """
    from numbers import Number

    ulabels = []
    for lab in labels:
        if lab not in ulabels:
            ulabels.append(lab)

    marker_cycle = None
    if marker is not None:
        marker_cycle = itertools.cycle(list(marker))
    m = None

    # tuple or int? Unpacking an int raises TypeError, unpacking a sequence
    # of the wrong length raises ValueError.
    try:
        input_1, input_2, input_3 = visible_dims
    except (TypeError, ValueError):
        try:
            input_1, input_2 = visible_dims
            input_3 = None
        except (TypeError, ValueError):
            input_1 = visible_dims
            input_2 = input_3 = None

    # an array is needed for the elementwise comparison below,
    # a plain list would compare as a whole
    labels = np.asarray(labels)

    for ul in ulabels:
        if isinstance(ul, str):
            try:
                this_label = unicode(ul)
            except NameError:
                #python3
                this_label = ul
        elif isinstance(ul, Number):
            this_label = 'class {!s}'.format(ul)
        else:
            this_label = ul

        if marker_cycle is not None:
            m = next(marker_cycle)

        index = np.nonzero(labels == ul)[0]

        x = X[index, input_1]
        if input_2 is None:
            y = np.zeros(index.size)
            z = None
        elif input_3 is None:
            y = X[index, input_2]
            z = None
        else:
            y = X[index, input_2]
            z = X[index, input_3]

        yield x, y, z, this_label, index, m
193
def subsample_X(X, labels, num_samples=1000):
    """
    Subsample X (and labels) if X has more than num_samples rows.

    Stratified subsampling if labels are given: every label keeps roughly
    its share of num_samples, but at least two rows (or all of its rows, if
    it has less than two).
    This means due to rounding errors you might get a little differences between the
    num_samples and the returned subsampled X.

    Without labels, num_samples rows are drawn uniformly without replacement.

    :param X: the data, one sample per row
    :param labels: one label per row of X, or None
    :param num_samples: maximal number of rows before subsampling kicks in
    :returns: (X, labels), subsampled if necessary; labels stays None if
        None was given
    """
    if X.shape[0] > num_samples:
        print("Warning: subsampling X, as it has more samples then {}. X.shape={!s}".format(int(num_samples), X.shape))
        if labels is not None:
            fraction = float(num_samples)/X.shape[0]
            subsample = []
            for _, _, _, _, index, _ in scatter_label_generator(labels, X, (0, None, None)):
                # Never ask for more rows than the label has, drawing
                # without replacement would fail otherwise.
                size = min(index.size, max(2, int(index.size*fraction)))
                subsample.append(np.random.choice(index, size=size, replace=False))
            subsample = np.hstack(subsample)
            labels = labels[subsample]
        else:
            subsample = np.random.choice(X.shape[0], size=int(num_samples), replace=False)
        X = X[subsample]
        #=======================================================================
        #     <<<WORK IN PROGRESS>>>
        #     <<<DO NOT DELETE>>>
        #     plt.close('all')
        #     fig, ax = plt.subplots(1,1)
        #     from GPy.plotting.matplot_dep.dim_reduction_plots import most_significant_input_dimensions
        #     import matplotlib.patches as mpatches
        #     i1, i2 = most_significant_input_dimensions(m, None)
        #     xmin, xmax = 100, -100
        #     ymin, ymax = 100, -100
        #     legend_handles = []
        #
        #     X = m.X.mean[:, [i1, i2]]
        #     X = m.X.variance[:, [i1, i2]]
        #
        #     xmin = X[:,0].min(); xmax = X[:,0].max()
        #     ymin = X[:,1].min(); ymax = X[:,1].max()
        #     range_ = [[xmin, xmax], [ymin, ymax]]
        #     ul = np.unique(labels)
        #
        #     for i, l in enumerate(ul):
        #         #cdict = dict(red  =[(0., colors[i][0], colors[i][0]), (1., colors[i][0], colors[i][0])],
        #         #             green=[(0., colors[i][0], colors[i][1]), (1., colors[i][1], colors[i][1])],
        #         #             blue =[(0., colors[i][0], colors[i][2]), (1., colors[i][2], colors[i][2])],
        #         #             alpha=[(0., 0., .0), (.5, .5, .5), (1., .5, .5)])
        #         #cmap = LinearSegmentedColormap('{}'.format(l), cdict)
        #         cmap = LinearSegmentedColormap.from_list('cmap_{}'.format(str(l)), [colors[i], colors[i]], 255)
        #         cmap._init()
        #         #alphas = .5*(1+scipy.special.erf(np.linspace(-2,2, cmap.N+3)))#np.log(np.linspace(np.exp(0), np.exp(1.), cmap.N+3))
        #         alphas = (scipy.special.erf(np.linspace(0,2.4, cmap.N+3)))#np.log(np.linspace(np.exp(0), np.exp(1.), cmap.N+3))
        #         cmap._lut[:, -1] = alphas
        #         print l
        #         x, y = X[labels==l].T
        #
        #         heatmap, xedges, yedges = np.histogram2d(x, y, bins=300, range=range_)
        #         #heatmap, xedges, yedges = np.histogram2d(x, y, bins=100)
        #
        #         im = ax.imshow(heatmap, extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], cmap=cmap, aspect='auto', interpolation='nearest', label=str(l))
        #         legend_handles.append(mpatches.Patch(color=colors[i], label=l))
        #     ax.set_xlim(xmin, xmax)
        #     ax.set_ylim(ymin, ymax)
        #     plt.legend(legend_handles, [l.get_label() for l in legend_handles])
        #     plt.draw()
        #     plt.show()
        #=======================================================================
    return X, labels
257
258
def update_not_existing_kwargs(to_update, update_from):
    """
    Fill to_update with the keyword arguments of update_from, without
    overwriting any key which is already set in to_update.

    This is used to complete user given kwargs with the default dicts.

    :param to_update: the dict to fill in (updated in place), None is
        treated as an empty dict
    :param update_from: the dict to take the missing entries from
    :returns: the updated to_update
    """
    if to_update is None:
        to_update = {}
    for key, value in update_from.items():
        if key not in to_update:
            to_update[key] = value
    return to_update
270
def get_x_y_var(model):
    """
    Get the data out of a model as
    X the inputs,
    X_variance the variance of the inputs (None, if the model has no
    uncertain inputs)
    and Y the outputs.

    For a WarpedGP which does not predict in warped space the normalized
    outputs (model.Y_normalized) are returned as Y. Sparse outputs are
    converted to a dense numpy array.

    :returns: (X, X_variance, Y)
    """
    uncertain = hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs()
    if uncertain:
        inputs = model.X.mean.values
        input_variance = model.X.variance.values
    else:
        input_variance = None
        # use the raw values if X provides them, X itself otherwise
        try:
            inputs = model.X.values
        except AttributeError:
            inputs = model.X

    try:
        outputs = model.Y.values
    except AttributeError:
        outputs = model.Y

    if isinstance(model, WarpedGP) and not model.predict_in_warped_space:
        outputs = model.Y_normalized

    if sparse.issparse(outputs):
        dense = outputs.todense()
        outputs = dense.view(np.ndarray)
    return inputs, input_variance, outputs
302
def get_free_dims(model, visible_dims, fixed_dims):
    """
    Work out which input dimensions are free for plotting.

    The free dimensions are the visible dimensions which are neither
    fixed nor None.

    :param model: the model, its input_dim is used if visible_dims is None
    :param visible_dims: the dimensions which are visible, None for all
        input dimensions of the model
    :param fixed_dims: the dimensions which are fixed, or None
    :returns: array of the free dimensions
    """
    if visible_dims is None:
        visible_dims = np.arange(model.input_dim)

    free = []
    for dim in np.asanyarray(visible_dims):
        if fixed_dims is not None and dim in fixed_dims:
            continue
        if dim is None:
            continue
        free.append(dim)
    return np.asanyarray(free)
318
319
def get_fixed_dims(fixed_inputs):
    """
    Collect the fixed dimensions out of fixed_inputs.

    :param fixed_inputs: list of (dimension, value) tuples
    :returns: array of the dimensions, in the given order
    """
    dims = []
    for dim, _ in fixed_inputs:
        dims.append(dim)
    return np.array(dims)
325
def get_which_data_ycols(model, which_data_ycols):
    """
    Helper to get the data columns to plot.

    :param model: the model, its output_dim is used for 'all' and None
    :param which_data_ycols: 'all', None or the columns to plot
    :returns: all columns (np.arange(model.output_dim)) for 'all' and None,
        which_data_ycols unchanged otherwise
    """
    # Only compare against 'all' if a string was given: comparing an array
    # to a string is elementwise and its truth value is ambiguous.
    if which_data_ycols is None or (
            isinstance(which_data_ycols, (str, type(u''))) and which_data_ycols == 'all'):
        return np.arange(model.output_dim)
    return which_data_ycols
333
def get_which_data_rows(model, which_data_rows):
    """
    Helper to get the data rows to plot.

    :param model: not used
    :param which_data_rows: 'all', None or the rows to plot
    :returns: slice(None) (all rows) for 'all' and None,
        which_data_rows unchanged otherwise
    """
    # Only compare against 'all' if a string was given: comparing an array
    # to a string is elementwise and its truth value is ambiguous.
    if which_data_rows is None or (
            isinstance(which_data_rows, (str, type(u''))) and which_data_rows == 'all'):
        return slice(None)
    return which_data_rows
341
def x_frame1D(X,plot_limits=None,resolution=None):
    """
    Internal helper function for making plots, returns a set of input values to plot as well as lower and upper limits

    :param X: the inputs, must have exactly one column
    :param plot_limits: (lower, upper) limits of the frame. If None, the
        range of X (of X.mean for a VariationalPosterior) is used, extended
        by a quarter of its width on both sides
    :param resolution: number of input values to create, 200 if None or 0
    :returns: (Xnew, xmin, xmax) where Xnew has shape (resolution, 1)
    :raises ValueError: if plot_limits is neither None nor of length two
    """
    assert X.shape[1] ==1, "x_frame1D is defined for one-dimensional inputs"
    if plot_limits is None:
        from GPy.core.parameterization.variational import VariationalPosterior
        if isinstance(X, VariationalPosterior):
            xmin,xmax = X.mean.min(0),X.mean.max(0)
        else:
            xmin,xmax = X.min(0),X.max(0)
        xmin, xmax = xmin-0.25*(xmax-xmin), xmax+0.25*(xmax-xmin)
    elif len(plot_limits) == 2:
        xmin, xmax = map(np.atleast_1d, plot_limits)
    else:
        raise ValueError("Bad limits for plotting")

    # resolution defaults to None, which int() can not handle
    num_points = 200 if resolution is None else (int(resolution) or 200)
    # xmin and xmax are arrays with one element: take that element
    # explicitly, implicit conversion of arrays to float is deprecated
    lower = float(np.asarray(xmin).item())
    upper = float(np.asarray(xmax).item())
    Xnew = np.linspace(lower, upper, num_points)[:,None]
    return Xnew, xmin, xmax
361
def x_frame2D(X,plot_limits=None,resolution=None):
    """
    Internal helper function for making plots, returns a set of input values to plot as well as lower and upper limits

    :param X: the inputs, must have exactly two columns
    :param plot_limits: the limits of the frame, one of

        * None: the range of X, extended by 7.5% of its width on all sides
        * (lower, upper) with one value each: used for both dimensions
        * (lower, upper) with one value per dimension each
        * four values: (lower_0, upper_0, lower_1, upper_1)
    :param resolution: number of grid points per dimension, 50 if None or 0
    :returns: (Xnew, xx, yy, xmin, xmax) where xx and yy are the
        (resolution, resolution) grids and Xnew holds one grid point per row
    :raises ValueError: if plot_limits has a length other than two or four
    """
    assert X.shape[1]==2, "x_frame2D is defined for two-dimensional inputs"
    if plot_limits is None:
        xmin, xmax = X.min(0), X.max(0)
        xmin, xmax = xmin-0.075*(xmax-xmin), xmax+0.075*(xmax-xmin)
    elif len(plot_limits) == 2:
        xmin, xmax = plot_limits
        try:
            xmin = xmin[0], xmin[1]
        except (TypeError, IndexError):
            # only one limit given (a scalar can not be indexed),
            # copy over to other lim
            xmin = [plot_limits[0], plot_limits[0]]
            xmax = [plot_limits[1], plot_limits[1]]
    elif len(plot_limits) == 4:
        xmin, xmax = (plot_limits[0], plot_limits[2]), (plot_limits[1], plot_limits[3])
    else:
        raise ValueError("Bad limits for plotting")

    resolution = resolution or 50
    # an imaginary step makes mgrid create that many points, end point included
    xx, yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution]
    Xnew = np.c_[xx.flat, yy.flat]
    return Xnew, xx, yy, xmin, xmax
387