1#   Copyright 2020 The PyMC Developers
2#
3#   Licensed under the Apache License, Version 2.0 (the "License");
4#   you may not use this file except in compliance with the License.
5#   You may obtain a copy of the License at
6#
7#       http://www.apache.org/licenses/LICENSE-2.0
8#
9#   Unless required by applicable law or agreed to in writing, software
10#   distributed under the License is distributed on an "AS IS" BASIS,
11#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12#   See the License for the specific language governing permissions and
13#   limitations under the License.
14
15from theano import tensor as tt
16
17from pymc3.theanof import floatX
18from pymc3.variational.opvi import TestFunction
19
# Public API of this module: only the ready-made ``rbf`` kernel instance.
__all__ = ["rbf"]
21
22
class Kernel(TestFunction):
    r"""
    Base class for kernels used by kernel SVGD; currently a placeholder in
    case more kernels are implemented.

    A kernel maps the particles to the kernel values and their gradient:

    .. math::

        f(x) -> (k(x,.), \nabla_x k(x,.))

    """
32
33
class RBF(Kernel):
    """
    Radial basis function kernel for SVGD, with a median-heuristic bandwidth.

    Calling the instance on a 2-D tensor ``X`` returns ``(Kxy, dxkxy)``.
    The code presumably expects one particle per row, as implied by the
    row-wise reductions below; confirm with callers.

    * ``Kxy[i, j] = exp(-||x_i - x_j||^2 / (2 h))``
    * ``dxkxy[i] = sum_j Kxy[i, j] * (x_i - x_j) / h``

    Here ``h = 0.5 * median(||x_i - x_j||^2) / log(n + 1)``. The median is
    taken over all ``n * n`` entries of the squared-distance matrix,
    diagonal included.
    """

    @staticmethod
    def _median(values):
        """Return the median of a 1-D tensor; when the count is even, average the two middle entries."""
        ordered = tt.sort(values)
        size = ordered.shape[0]
        mid = size // 2
        return tt.switch(
            tt.eq(size % 2, 0),
            tt.mean(ordered[mid - 1 : mid + 1]),  # even count
            ordered[mid],  # odd count
        )

    def __call__(self, X):
        n = X.shape[0]

        # Squared Euclidean distances via ||a||^2 + ||b||^2 - 2 <a, b>.
        inner = X.dot(X.T)
        norms = tt.repeat(tt.sum(X ** 2, axis=1).dimshuffle(0, "x"), n, axis=1)
        sq_dist = norms + norms.T - 2.0 * inner

        # Bandwidth from the median heuristic.
        # NOTE(review): if the median is 0 (e.g. all particles identical),
        # h == 0 and the divisions below yield NaN/inf. Confirm whether
        # callers rule this out.
        med = self._median(sq_dist.flatten())
        h = 0.5 * med / tt.log(floatX(n) + floatX(1))

        kxy = tt.exp(-sq_dist / h / 2.0)

        # Row i of the gradient term is (x_i * sum_j k_ij - sum_j k_ij x_j) / h.
        weight = tt.sum(kxy, axis=-1, keepdims=True)
        grad = (X * weight - tt.dot(kxy, X)) / h
        return kxy, grad
63
64
# Shared RBF kernel instance; this is the name exported via ``__all__``.
rbf = RBF()
66