1# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
2# Licensed under the BSD 3-clause license (see LICENSE.txt)
3
4import numpy as np
5from scipy.special import cbrt
6from .config import *
7
# Largest finite float64, followed by the upper bounds that the helpers below
# clip their inputs to before exponentiating, raising to a power or tripling.
_lim_val = np.finfo(np.float64).max
_lim_val_exp = np.log(_lim_val)
_lim_val_square = np.sqrt(_lim_val)
# The root-based limits are stepped one representable float towards -inf.
# (An earlier variant computed the cube limit with scipy.special.cbrt.)
_lim_val_cube, _lim_val_quad = (
    np.nextafter(_lim_val ** (1 / degree), -np.inf) for degree in (3.0, 4.0)
)
_lim_val_three_times = np.nextafter(_lim_val / 3.0, -np.inf)
15
def safe_exp(f):
    """
    Exponentiate f after capping it from above at _lim_val_exp.

    Only the upper side is bounded (the lower clip bound is -inf), so large
    negative inputs are passed to np.exp unchanged.

    :param f: scalar or array of exponents
    :returns: np.exp of the capped input
    """
    return np.exp(np.clip(f, -np.inf, _lim_val_exp))
19
def safe_square(f):
    """
    Square f without overflowing.

    The input is clipped to [-_lim_val_square, _lim_val_square] before it is
    squared. Clipping both sides matters because squaring discards the sign:
    a large negative input overflows just like a large positive one.

    :param f: scalar or array to square
    :returns: square of the clipped input
    """
    f = np.clip(f, -_lim_val_square, _lim_val_square)
    return f**2
23
def safe_cube(f):
    """
    Cube f without overflowing.

    The input is clipped to [-_lim_val_cube, _lim_val_cube] before it is
    cubed, so inputs of large magnitude and either sign give a finite result
    instead of +/-inf.

    :param f: scalar or array to cube
    :returns: cube of the clipped input
    """
    f = np.clip(f, -_lim_val_cube, _lim_val_cube)
    return f**3
27
def safe_quad(f):
    """
    Raise f to the fourth power without overflowing.

    The input is clipped to [-_lim_val_quad, _lim_val_quad] first. Both sides
    are bounded because an even power discards the sign, so a large negative
    input would overflow just like a large positive one.

    :param f: scalar or array
    :returns: fourth power of the clipped input
    """
    f = np.clip(f, -_lim_val_quad, _lim_val_quad)
    return f**4
31
def safe_three_times(f):
    """
    Multiply f by three without overflowing.

    The input is clipped to [-_lim_val_three_times, _lim_val_three_times]
    first, so inputs of large magnitude and either sign give a finite result
    instead of +/-inf.

    :param f: scalar or array
    :returns: three times the clipped input
    """
    f = np.clip(f, -_lim_val_three_times, _lim_val_three_times)
    return 3*f
35
def chain_1(df_dg, dg_dx):
    r"""
    Chain rule for the first derivative of a composition f(g(x))

    .. math::
        \frac{d(f \circ g)}{dx} = \frac{df}{dg} \frac{dg}{dx}

    :param df_dg: derivative of f with respect to g
    :param dg_dx: derivative of g with respect to x
    :returns: df_dg itself (no multiplication, no copy) when every entry of
        dg_dx equals 1, otherwise the product df_dg * dg_dx
    """
    inner_is_identity = np.all(dg_dx == 1.)
    return df_dg if inner_is_identity else df_dg * dg_dx
46
def chain_2(d2f_dg2, dg_dx, df_dg, d2g_dx2):
    r"""
    Chain rule for the second derivative of a composition f(g(x))

    .. math::
        \frac{d^{2}(f \circ g)}{dx^{2}} = \frac{d^{2}f}{dg^{2}}\left(\frac{dg}{dx}\right)^{2} + \frac{df}{dg}\frac{d^{2}g}{dx^{2}}

    :param d2f_dg2: second derivative of f with respect to g
    :param dg_dx: first derivative of g with respect to x
    :param df_dg: first derivative of f with respect to g
    :param d2g_dx2: second derivative of g with respect to x
    :returns: d2f_dg2 itself when dg_dx is all ones and d2g_dx2 is all zeros,
        otherwise the chained second derivative
    """
    if np.all(dg_dx == 1.) and np.all(d2g_dx2 == 0):
        return d2f_dg2
    # Clip on both sides before squaring: the square discards the sign, so a
    # large negative slope would overflow just like a large positive one.
    dg_dx_2 = np.clip(dg_dx, -_lim_val_square, _lim_val_square)**2
    return d2f_dg2*dg_dx_2 + df_dg*d2g_dx2
59
def chain_3(d3f_dg3, dg_dx, d2f_dg2, d2g_dx2, df_dg, d3g_dx3):
    r"""
    Chain rule for the third derivative of a composition f(g(x))

    .. math::
        \frac{d^{3}(f \circ g)}{dx^{3}} = \frac{d^{3}f}{dg^{3}}\left(\frac{dg}{dx}\right)^{3} + 3\frac{d^{2}f}{dg^{2}}\frac{dg}{dx}\frac{d^{2}g}{dx^{2}} + \frac{df}{dg}\frac{d^{3}g}{dx^{3}}

    :param d3f_dg3: third derivative of f with respect to g
    :param dg_dx: first derivative of g with respect to x
    :param d2f_dg2: second derivative of f with respect to g
    :param d2g_dx2: second derivative of g with respect to x
    :param df_dg: first derivative of f with respect to g
    :param d3g_dx3: third derivative of g with respect to x
    :returns: d3f_dg3 itself when dg_dx is all ones and both higher
        derivatives of g are all zeros, otherwise the chained third derivative
    """
    if np.all(dg_dx == 1.) and np.all(d2g_dx2 == 0) and np.all(d3g_dx3 == 0):
        return d3f_dg3
    # Clip on both sides before cubing so that slopes of large magnitude and
    # either sign stay finite (a one-sided clip lets -inf through).
    dg_dx_3 = np.clip(dg_dx, -_lim_val_cube, _lim_val_cube)**3
    return d3f_dg3*dg_dx_3 + 3*d2f_dg2*dg_dx*d2g_dx2 + df_dg*d3g_dx3
71
def opt_wrapper(m, **kwargs):
    """
    Run the optimization of a GPy object from a plain module-level function,
    so that the call is pickleable (necessary for multiprocessing).

    :param m: object exposing optimize() and an optimization_runs sequence
    :param kwargs: keyword arguments forwarded unchanged to m.optimize
    :returns: the last entry of m.optimization_runs after optimizing
    """
    m.optimize(**kwargs)
    latest_run = m.optimization_runs[-1]
    return latest_run
79
80
def linear_grid(D, n=100, min_max=(-100, 100)):
    """
    Build an (n, D) array whose D columns all hold the same n linearly
    spaced values running from min_max[0] to min_max[1].

    :param D: number of columns (dimension of each point)
    :param n: number of points (rows)
    :param min_max: (min, max) pair giving the end points of the spacing
    :returns: array of shape (n, D)

    """
    ticks = np.linspace(min_max[0], min_max[1], n)
    # broadcast the column of ticks across all D columns
    return np.ones((n, D)) * ticks[:, np.newaxis]
95
def kmm_init(X, m=10):
    """
    k-means++ style seeding, useful for initializing the locations of the
    inducing points in sparse GPs.

    The first point is drawn uniformly at random. Every further point is
    drawn with probability proportional to its squared distance to the
    closest point chosen so far, so a point that is already selected has
    zero weight and cannot be drawn again while unselected distinct points
    remain. If every remaining weight is zero (e.g. all rows of X coincide,
    or more points are requested than there are distinct rows) the draw
    falls back to a uniform one instead of dividing by zero.

    Randomness comes from the global np.random state.

    :param X: data, a 2-D array with one point per row
    :param m: number of inducing points
    :returns: array holding the selected rows of X (at least one row is
        always selected)

    """
    num_points = X.shape[0]

    # Squared Euclidean distances from the Gram matrix. Roundoff can make
    # entries slightly negative, which would be invalid as sampling weights,
    # so they are clipped at zero.
    gram = np.dot(X, X.T)
    sq_norms = np.diag(gram)
    sq_dist = np.maximum(
        sq_norms[:, np.newaxis] + sq_norms[np.newaxis, :] - 2.*gram, 0.)

    # select the first point
    s = np.random.randint(num_points)
    inducing = [s]
    closest = sq_dist[s].copy()  # squared distance to the nearest chosen point
    closest[s] = 0.

    for _ in range(m - 1):
        total = closest.sum()
        if np.isfinite(total) and total > 0:
            s = np.random.choice(num_points, p=closest / total)
        else:
            # degenerate weights: nothing to prefer, draw uniformly
            s = np.random.randint(num_points)
        inducing.append(s)
        closest = np.minimum(closest, sq_dist[s])
        closest[s] = 0.

    return X[np.array(inducing)]
123
### Convert parameter objects to their underlying numpy arrays:
def param_to_array(*param):
    """
    Return plain :class:ndarray views of the given parameter objects.

    Intended for handing parameter objects to scipy.weave.inline routines:
    scipy.weave.blitz does no automatic array detection, even for objects
    that inherit from :class:ndarray.

    Emits a DeprecationWarning on every call.

    :param param: one or more objects supporting ``.view(np.ndarray)``
    :returns: a single ndarray view when exactly one argument is given,
        otherwise a list of ndarray views in argument order
    """
    import warnings
    warnings.warn("Please use param.values, as this function will be deprecated in the next release.", DeprecationWarning)
    assert len(param) > 0, "At least one parameter needed"
    views = [p.view(np.ndarray) for p in param]
    return views[0] if len(views) == 1 else views
138
def blockify_hessian(func):
    """
    Decorator for methods returning a Hessian that may be given only by its
    diagonal entries.

    When ``self.not_block_really`` is true and the wrapped method's result is
    not a square matrix (including a plain 1-D vector), the flattened result
    is placed on the diagonal of a square matrix with np.diagflat. In every
    other case the result is returned untouched.

    :param func: method taking ``self`` and returning an array
    :returns: the wrapping method, carrying func's name and docstring
    """
    from functools import wraps

    @wraps(func)
    def wrapper_func(self, *args, **kwargs):
        # Invoke the wrapped function first
        retval = func(self, *args, **kwargs)
        if not self.not_block_really:
            return retval
        shape = retval.shape
        # Fewer than two axes can never be square; checking the rank first
        # avoids an IndexError on shape[1].
        if len(shape) < 2 or shape[0] != shape[1]:
            return np.diagflat(retval)
        return retval
    return wrapper_func
149
def blockify_third(func):
    """
    Decorator for methods returning a third-derivative tensor that may be
    given only by its diagonal entries.

    When ``self.not_block_really`` is true and the wrapped method's result has
    fewer than three axes, the squeezed result is written onto the
    [i, i, i] diagonal of a zero array of shape (N, N, N), where N is the
    length of the result's first axis. Otherwise the result is returned
    untouched.

    :param func: method taking ``self`` and returning an array
    :returns: the wrapping method, carrying func's name and docstring
    """
    from functools import wraps

    @wraps(func)
    def wrapper_func(self, *args, **kwargs):
        # Invoke the wrapped function first
        retval = func(self, *args, **kwargs)
        if self.not_block_really and len(retval.shape) < 3:
            num_data = retval.shape[0]
            d3_block_cache = np.zeros((num_data, num_data, num_data))
            diag = np.arange(num_data)
            d3_block_cache[diag, diag, diag] = np.squeeze(retval)
            return d3_block_cache
        return retval
    return wrapper_func
164
def blockify_dhess_dtheta(func):
    """
    Decorator for methods returning the derivative of a Hessian with respect
    to parameters, where the Hessian may be given only by its diagonal.

    When ``self.not_block_really`` is true and the wrapped method's result has
    fewer than three axes, the result is treated as an (N, P) array (N is the
    length of its first axis, P of its last) and column p is written onto the
    diagonal of slice [:, :, p] of a zero array of shape (N, N, P). Otherwise
    the result is returned untouched.

    :param func: method taking ``self`` and returning an array
    :returns: the wrapping method, carrying func's name and docstring
    """
    from functools import wraps

    @wraps(func)
    def wrapper_func(self, *args, **kwargs):
        # Invoke the wrapped function first
        retval = func(self, *args, **kwargs)
        if self.not_block_really and len(retval.shape) < 3:
            num_data = retval.shape[0]
            num_params = retval.shape[-1]
            dhess_dtheta = np.zeros((num_data, num_data, num_params))
            diag = np.arange(num_data)
            # Fill every parameter slice's diagonal in one assignment; the
            # reshape rejects results that are not laid out as (N, P).
            dhess_dtheta[diag, diag, :] = np.reshape(retval, (num_data, num_params))
            return dhess_dtheta
        return retval
    return wrapper_func
181
182