1# Copyright (c) 2018, GPy authors (see AUTHORS.txt).
2# Licensed under the BSD 3-clause license (see LICENSE.txt)
3from .kern import Kern
4import numpy as np
5from paramz.caching import Cache_this
6
7class DiffKern(Kern):
8    """
9    Diff kernel is a thin wrapper for using partial derivatives of kernels as kernels. Eg. in combination with
10    Multioutput kernel this allows the user to train GPs with observations of latent function and latent
11    function derivatives. NOTE: DiffKern only works when used with Multioutput kernel. Do not use the kernel as standalone
12
13    The parameters the kernel needs are:
14    -'base_kern': a member of Kernel class that is used for observations
15    -'dimension': integer that indigates in which dimensions the partial derivative observations are
16    """
17    def __init__(self, base_kern, dimension):
18        super(DiffKern, self).__init__(base_kern.active_dims.size, base_kern.active_dims, name='DiffKern')
19        self.base_kern = base_kern
20        self.dimension = dimension
21
22    def parameters_changed(self):
23        self.base_kern.parameters_changed()
24
25    @Cache_this(limit=3, ignore_args=())
26    def K(self, X, X2=None, dimX2 = None): #X in dimension self.dimension
27        if X2 is None:
28            X2 = X
29        if dimX2 is None:
30            dimX2 = self.dimension
31        return self.base_kern.dK2_dXdX2(X,X2, self.dimension, dimX2)
32
33    @Cache_this(limit=3, ignore_args=())
34    def Kdiag(self, X):
35        return np.diag(self.base_kern.dK2_dXdX2(X,X, self.dimension, self.dimension))
36
37    @Cache_this(limit=3, ignore_args=())
38    def dK_dX_wrap(self, X, X2): #X in dimension self.dimension
39        return self.base_kern.dK_dX(X,X2, self.dimension)
40
41    @Cache_this(limit=3, ignore_args=())
42    def dK_dX2_wrap(self, X, X2): #X in dimension self.dimension
43        return self.base_kern.dK_dX2(X,X2, self.dimension)
44
45    def reset_gradients(self):
46        self.base_kern.reset_gradients()
47
48    @property
49    def gradient(self):
50        return self.base_kern.gradient
51
52    @gradient.setter
53    def gradient(self, gradient):
54        self.base_kern.gradient = gradient
55
56    def update_gradients_full(self, dL_dK, X, X2=None, dimX2=None):
57        if dimX2 is None:
58            dimX2 = self.dimension
59        gradients = self.base_kern.dgradients2_dXdX2(X,X2,self.dimension,dimX2)
60        self.base_kern.update_gradients_direct(*[self._convert_gradients(dL_dK, gradient) for gradient in gradients])
61
62    def update_gradients_diag(self, dL_dK_diag, X):
63        gradients = self.base_kern.dgradients2_dXdX2(X,X, self.dimension, self.dimension)
64        self.base_kern.update_gradients_direct(*[self._convert_gradients(dL_dK_diag, gradient, f=np.diag) for gradient in gradients])
65
66    def update_gradients_dK_dX(self, dL_dK, X, X2=None):
67        if X2 is None:
68            X2 = X
69        gradients = self.base_kern.dgradients_dX(X,X2, self.dimension)
70        self.base_kern.update_gradients_direct(*[self._convert_gradients(dL_dK, gradient) for gradient in gradients])
71
72    def update_gradients_dK_dX2(self, dL_dK, X, X2=None):
73        gradients = self.base_kern.dgradients_dX2(X,X2, self.dimension)
74        self.base_kern.update_gradients_direct(*[self._convert_gradients(dL_dK, gradient) for gradient in gradients])
75
76    def gradients_X(self, dL_dK, X, X2):
77        tmp = self.base_kern.gradients_XX(dL_dK, X, X2)[:,:,:, self.dimension]
78        return np.sum(tmp, axis=1)
79
80    def gradients_X2(self, dL_dK, X, X2):
81        tmp = self.base_kern.gradients_XX(dL_dK, X, X2)[:, :, self.dimension, :]
82        return np.sum(tmp, axis=1)
83
84    def _convert_gradients(self, l,g, f = lambda x:x):
85        if type(g) is np.ndarray:
86            return np.sum(f(l)*f(g))
87        else:
88            return np.array([np.sum(f(l)*f(gi)) for gi in g])