#   Copyright 2020 The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

from theano import tensor as tt

from pymc3.theanof import floatX
from pymc3.variational.opvi import TestFunction

__all__ = ["rbf"]


class Kernel(TestFunction):
    # Raw docstring: in a normal string "\nabla" is read as a newline
    # escape followed by "abla", which garbles the rendered formula.
    r"""
    Base class for the kernels used by kernel SVGD.

    In this module :class:`RBF` is the only subclass; the base class exists
    so that further kernels can be added. A kernel maps its input to the
    kernel values together with their gradient term:

    .. math::

        f(x) \to (k(x, \cdot), \nabla_x k(x, \cdot))

    """


class RBF(Kernel):
    """Radial basis function (Gaussian) kernel with a median-heuristic bandwidth."""

    def __call__(self, X):
        r"""Evaluate the RBF kernel matrix and its gradient term on the rows of ``X``.

        Parameters
        ----------
        X : 2-D tensor
            Matrix whose rows are the points the kernel is evaluated on
            (presumably the SVGD particles -- confirm against callers). It
            must be a matrix: the code relies on ``X.T`` and row-wise sums.

        Returns
        -------
        Kxy : tensor, shape (n, n)
            ``Kxy[i, j] = exp(-||x_i - x_j||^2 / (2 h))`` with ``n = X.shape[0]``.
        dxkxy : tensor, same shape as ``X``
            Row ``i`` is ``sum_j Kxy[i, j] * (x_i - x_j) / h``, i.e.
            :math:`\sum_j \nabla_{x_j} k(x_j, x_i)`.

        Notes
        -----
        The bandwidth follows the median heuristic
        ``h = 0.5 * median(H) / log(n + 1)``, where ``H`` holds all ``n * n``
        pairwise squared distances (the zero diagonal included).
        """
        # Pairwise squared Euclidean distances:
        #   H[i, j] = |x_i|^2 + |x_j|^2 - 2 <x_i, x_j>
        XY = X.dot(X.T)
        # Column vector of squared row norms; the "x" axis is broadcastable,
        # so x2 + x2.T broadcasts (n, 1) + (1, n) -> (n, n) without building
        # an explicit repeated copy.
        x2 = tt.sum(X ** 2, axis=1).dimshuffle(0, "x")
        H = x2 + x2.T - 2.0 * XY

        # Median of all pairwise squared distances.
        V = tt.sort(H.flatten())
        length = V.shape[0]
        half = length // 2
        m = tt.switch(
            tt.eq(length % 2, 0),
            # even number of entries: mean of the two middle values
            tt.mean(V[(half - 1) : (half + 1)]),
            # odd number of entries: the middle value
            V[half],
        )

        # Median-heuristic bandwidth. NOTE(review): if the median squared
        # distance is zero (e.g. all rows of X coincide), h is zero and the
        # kernel below becomes NaN/inf; callers presumably avoid that case.
        h = 0.5 * m / tt.log(floatX(H.shape[0]) + floatX(1))

        # Gaussian kernel values.
        Kxy = tt.exp(-H / h / 2.0)

        # Gradient term: (x_i * sum_j K_ij - sum_j K_ij x_j) / h.
        dxkxy = -tt.dot(Kxy, X)
        sumkxy = tt.sum(Kxy, axis=-1, keepdims=True)
        dxkxy = tt.add(dxkxy, tt.mul(X, sumkxy)) / h

        return Kxy, dxkxy


rbf = RBF()