# Copyright (c) 2017, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt)

from GPy.util.linalg import jitchol, backsub_both_sides, tdot, dtrtrs, dtrtri,pdinv, dpotri
from GPy.util import diag
from GPy.core.parameterization.variational import VariationalPosterior
import numpy as np
from GPy.inference.latent_function_inference import LatentFunctionInference
from GPy.inference.latent_function_inference.posterior import Posterior
from .vardtc_svi_multiout import PosteriorMultioutput
log_2_pi = np.log(2*np.pi)


class VarDTC_SVI_Multiout_Miss(LatentFunctionInference):
    """
    The VarDTC inference method for Multi-output GP regression with missing data (GPy.models.GPMultioutRegressionMD)

    Naming convention used throughout:
      * "_r" quantities (kern_r, Xr, Zr, Lr, Sr, ...) belong to the row side,
        where row d of Xr corresponds to output d.
      * "_c" quantities (kern_c, Xc, Zc, Lc, Sc, ...) belong to the column side,
        where row n of Xc corresponds to row n of Y.
      * indexD[n] names the output that row n of Y belongs to, which is how
        missing data is represented (each output only sees its own rows).
    """
    # Jitter added to the diagonals of Kuu_r and Kuu_c before their Cholesky factorisations.
    const_jitter = 1e-6

    def get_trYYT(self, Y):
        """Return trace(Y Y^T), i.e. the sum of the squared entries of Y."""
        return np.sum(np.square(Y))

    def get_YYTfactor(self, Y):
        """
        Return a matrix L with L L^T = Y Y^T.

        When Y has at least as many rows as columns, Y itself (as a plain
        ndarray view) already satisfies this; otherwise the jittered Cholesky
        factor of Y Y^T is returned.
        """
        N, D = Y.shape
        if N >= D:
            return Y.view(np.ndarray)
        return jitchol(tdot(Y))

    def gatherPsiStat(self, kern, X, Z, uncertain_inputs):
        """
        Compute the kernel statistics (psi0, psi1, psi2) of inputs X against inducing inputs Z.

        With uncertain inputs the kernel's psi-statistics are used, psi2 being
        kept per data point via ``psi2n``. Otherwise the deterministic
        equivalents are used: psi0 = Kdiag(X), psi1 = K(X, Z) and
        psi2[n] = outer(psi1[n], psi1[n]).
        """
        if uncertain_inputs:
            psi0 = kern.psi0(Z, X)
            psi1 = kern.psi1(Z, X)
            psi2 = kern.psi2n(Z, X)
        else:
            psi0 = kern.Kdiag(X)
            psi1 = kern.K(X, Z)
            psi2 = psi1[:, :, None]*psi1[:, None, :]

        return psi0, psi1, psi2

    def _init_grad_dict(self, N, D, Mr, Mc):
        """
        Return zero-initialised gradient accumulators.

        N is the number of rows of Y (the column side), D the number of outputs
        (the row side), and Mr/Mc the numbers of row/column inducing inputs.
        """
        grad_dict = {
            'dL_dthetaL': np.zeros(D),          # one entry per output noise variance
            'dL_dqU_mean': np.zeros((Mc, Mr)),
            'dL_dqU_var_c': np.zeros((Mc, Mc)),
            'dL_dqU_var_r': np.zeros((Mr, Mr)),
            'dL_dKuu_c': np.zeros((Mc, Mc)),
            'dL_dKuu_r': np.zeros((Mr, Mr)),
            'dL_dpsi0_c': np.zeros(N),
            'dL_dpsi1_c': np.zeros((N, Mc)),
            'dL_dpsi2_c': np.zeros((N, Mc, Mc)),
            'dL_dpsi0_r': np.zeros(D),
            'dL_dpsi1_r': np.zeros((D, Mr)),
            'dL_dpsi2_r': np.zeros((D, Mr, Mr)),
        }
        return grad_dict

    def inference_d(self, d, beta, Y, indexD, grad_dict, mid_res, uncertain_inputs_r, uncertain_inputs_c, Mr, Mc):
        """
        Compute the data term of output ``d`` and accumulate its gradients.

        Only the rows of Y with ``indexD == d`` take part. They are treated as
        a single column, so Y is presumably (N_total, 1) — confirm against the
        calling model.

        Parameters
        ----------
        d : int
            Output index.
        beta : ndarray
            Per-output noise precisions (length D).
        Y, indexD : ndarray
            All observations and the output index of each row.
        grad_dict : dict
            Accumulators from ``_init_grad_dict``. They are updated IN PLACE.
        mid_res : dict
            Shared quantities precomputed by ``inference`` (psi statistics,
            Cholesky factors and whitened posterior moments).
        uncertain_inputs_r, uncertain_inputs_c : bool
            Whether Xr / Xc are variational posteriors.
        Mr, Mc : int
            Numbers of row / column inducing inputs.

        Returns
        -------
        float
            This output's contribution to the objective.
        """
        idx_d = indexD == d
        Y = Y[idx_d]
        N, D = Y.shape[0], 1
        beta = beta[d]

        # Row side: pick the statistics of output d.
        # Column side: sum psi0/psi2 over the rows belonging to output d.
        psi0_r, psi1_r, psi2_r = mid_res['psi0_r'], mid_res['psi1_r'], mid_res['psi2_r']
        psi0_c, psi1_c, psi2_c = mid_res['psi0_c'], mid_res['psi1_c'], mid_res['psi2_c']
        psi0_r, psi1_r, psi2_r = psi0_r[d], psi1_r[d:d+1], psi2_r[d]
        psi0_c, psi1_c, psi2_c = psi0_c[idx_d].sum(), psi1_c[idx_d], psi2_c[idx_d].sum(0)

        Lr = mid_res['Lr']
        Lc = mid_res['Lc']
        LcInvMLrInvT = mid_res['LcInvMLrInvT']
        LcInvScLcInvT = mid_res['LcInvScLcInvT']
        LrInvSrLrInvT = mid_res['LrInvSrLrInvT']

        # Whiten the psi statistics with the Kuu Cholesky factors.
        LcInvPsi2_cLcInvT = backsub_both_sides(Lc, psi2_c, 'right')
        LrInvPsi2_rLrInvT = backsub_both_sides(Lr, psi2_r, 'right')
        LcInvPsi1_cT = dtrtrs(Lc, psi1_c.T)[0]
        LrInvPsi1_rT = dtrtrs(Lr, psi1_r.T)[0]

        # (A*B).sum() == trace(A B^T); the matrices involved here are symmetric.
        tr_LrInvPsi2_rLrInvT_LrInvSrLrInvT = (LrInvPsi2_rLrInvT*LrInvSrLrInvT).sum()
        tr_LcInvPsi2_cLcInvT_LcInvScLcInvT = (LcInvPsi2_cLcInvT*LcInvScLcInvT).sum()
        tr_LrInvPsi2_rLrInvT = np.trace(LrInvPsi2_rLrInvT)
        tr_LcInvPsi2_cLcInvT = np.trace(LcInvPsi2_cLcInvT)

        # logL_A collects the data-fit terms that are multiplied by beta/2.
        logL_A = (- np.square(Y).sum()
                  - (LcInvMLrInvT.T.dot(LcInvPsi2_cLcInvT).dot(LcInvMLrInvT)*LrInvPsi2_rLrInvT).sum()
                  - tr_LrInvPsi2_rLrInvT_LrInvSrLrInvT * tr_LcInvPsi2_cLcInvT_LcInvScLcInvT
                  + 2 * (Y * LcInvPsi1_cT.T.dot(LcInvMLrInvT).dot(LrInvPsi1_rT)).sum() - psi0_c * psi0_r
                  + tr_LrInvPsi2_rLrInvT * tr_LcInvPsi2_cLcInvT)

        logL = -N*D/2.*(np.log(2.*np.pi)-np.log(beta)) + beta/2. * logL_A

        #======================================================================
        # Compute dL_dKuu
        #======================================================================

        tmp = (beta * LcInvPsi2_cLcInvT.dot(LcInvMLrInvT).dot(LrInvPsi2_rLrInvT).dot(LcInvMLrInvT.T)
               + beta * tr_LrInvPsi2_rLrInvT_LrInvSrLrInvT * LcInvPsi2_cLcInvT.dot(LcInvScLcInvT)
               - beta * LcInvMLrInvT.dot(LrInvPsi1_rT).dot(Y.T).dot(LcInvPsi1_cT.T)
               - beta/2. * tr_LrInvPsi2_rLrInvT * LcInvPsi2_cLcInvT)

        dL_dKuu_c = backsub_both_sides(Lc, tmp, 'left')
        # Symmetrise, since Kuu_c is symmetric.
        dL_dKuu_c += dL_dKuu_c.T
        dL_dKuu_c *= 0.5

        tmp = (beta * LcInvMLrInvT.T.dot(LcInvPsi2_cLcInvT).dot(LcInvMLrInvT).dot(LrInvPsi2_rLrInvT)
               + beta * tr_LcInvPsi2_cLcInvT_LcInvScLcInvT * LrInvPsi2_rLrInvT.dot(LrInvSrLrInvT)
               - beta * LrInvPsi1_rT.dot(Y.T).dot(LcInvPsi1_cT.T).dot(LcInvMLrInvT)
               - beta/2. * tr_LcInvPsi2_cLcInvT * LrInvPsi2_rLrInvT)

        dL_dKuu_r = backsub_both_sides(Lr, tmp, 'left')
        dL_dKuu_r += dL_dKuu_r.T
        dL_dKuu_r *= 0.5

        #======================================================================
        # Compute dL_dthetaL
        #======================================================================

        # Derivative of logL w.r.t. the noise variance 1/beta (chain rule through beta).
        dL_dthetaL = -D*N*beta/2. - logL_A*beta*beta/2.

        #======================================================================
        # Compute dL_dqU
        #======================================================================

        tmp = (-beta * LcInvPsi2_cLcInvT.dot(LcInvMLrInvT).dot(LrInvPsi2_rLrInvT)
               + beta * LcInvPsi1_cT.dot(Y).dot(LrInvPsi1_rT.T))

        dL_dqU_mean = dtrtrs(Lc, dtrtrs(Lr, tmp.T, trans=1)[0].T, trans=1)[0]

        tmp = -beta/2.*tr_LrInvPsi2_rLrInvT_LrInvSrLrInvT * LcInvPsi2_cLcInvT
        dL_dqU_var_c = backsub_both_sides(Lc, tmp, 'left')

        tmp = -beta/2.*tr_LcInvPsi2_cLcInvT_LcInvScLcInvT * LrInvPsi2_rLrInvT
        dL_dqU_var_r = backsub_both_sides(Lr, tmp, 'left')

        #======================================================================
        # Compute dL_dpsi
        #======================================================================

        dL_dpsi0_r = - psi0_c * beta/2. * np.ones((D,))
        dL_dpsi0_c = - psi0_r * beta/2. * np.ones((N,))

        dL_dpsi1_c = beta * dtrtrs(Lc, (Y.dot(LrInvPsi1_rT.T).dot(LcInvMLrInvT.T)).T, trans=1)[0].T
        dL_dpsi1_r = beta * dtrtrs(Lr, (Y.T.dot(LcInvPsi1_cT.T).dot(LcInvMLrInvT)).T, trans=1)[0].T

        tmp = beta/2.*(-LcInvMLrInvT.dot(LrInvPsi2_rLrInvT).dot(LcInvMLrInvT.T)
                       - tr_LrInvPsi2_rLrInvT_LrInvSrLrInvT * LcInvScLcInvT
                       + tr_LrInvPsi2_rLrInvT * np.eye(Mc))
        dL_dpsi2_c = backsub_both_sides(Lc, tmp, 'left')
        tmp = beta/2.*(-LcInvMLrInvT.T.dot(LcInvPsi2_cLcInvT).dot(LcInvMLrInvT)
                       - tr_LcInvPsi2_cLcInvT_LcInvScLcInvT * LrInvSrLrInvT
                       + tr_LcInvPsi2_cLcInvT * np.eye(Mr))
        dL_dpsi2_r = backsub_both_sides(Lr, tmp, 'left')

        grad_dict['dL_dthetaL'][d:d+1] = dL_dthetaL
        grad_dict['dL_dqU_mean'] += dL_dqU_mean
        grad_dict['dL_dqU_var_c'] += dL_dqU_var_c
        grad_dict['dL_dqU_var_r'] += dL_dqU_var_r
        grad_dict['dL_dKuu_c'] += dL_dKuu_c
        grad_dict['dL_dKuu_r'] += dL_dKuu_r

        # With deterministic inputs psi2[n] = outer(psi1[n], psi1[n]) (see
        # gatherPsiStat), so its gradient is folded back into psi1 by the chain
        # rule: dL/dpsi1[n] += psi1[n] (G + G^T).
        if not uncertain_inputs_r:
            dL_dpsi1_r += psi1_r.dot(dL_dpsi2_r+dL_dpsi2_r.T)
        if not uncertain_inputs_c:
            dL_dpsi1_c += psi1_c.dot(dL_dpsi2_c+dL_dpsi2_c.T)

        # psi0_c/psi2_c were summed over the rows of output d, so the same
        # gradient applies to every one of those rows (broadcast).
        grad_dict['dL_dpsi0_c'][idx_d] += dL_dpsi0_c
        grad_dict['dL_dpsi1_c'][idx_d] += dL_dpsi1_c
        grad_dict['dL_dpsi2_c'][idx_d] += dL_dpsi2_c

        grad_dict['dL_dpsi0_r'][d:d+1] += dL_dpsi0_r
        grad_dict['dL_dpsi1_r'][d:d+1] += dL_dpsi1_r
        grad_dict['dL_dpsi2_r'][d] += dL_dpsi2_r

        return logL

    def inference(self, kern_r, kern_c, Xr, Xc, Zr, Zc, likelihood, Y, qU_mean, qU_var_r, qU_var_c, indexD, output_dim):
        """
        The SVI-VarDTC inference

        Parameters
        ----------
        kern_r, Xr, Zr :
            Row-side kernel, inputs (one row per output) and inducing inputs.
        kern_c, Xc, Zc :
            Column-side kernel, inputs (one row per row of Y) and inducing inputs.
        likelihood :
            Its ``variance`` is either a single shared value or one value per
            output (a single value is broadcast to all outputs).
        Y, indexD :
            Observations and the output index of each row.
        qU_mean : ndarray, (Mc, Mr)
            Variational mean of the inducing variables.
        qU_var_r, qU_var_c : ndarray
            Row (Mr x Mr) and column (Mc x Mc) variational covariance factors.
        output_dim : int
            Number of outputs D.

        Returns
        -------
        post : PosteriorMultioutput
        logL : float
            The sum of the per-output data terms plus the KL-related terms for q(U).
        grad_dict : dict
            The gradients. When Xr / Xc are not variational posteriors, it also
            contains the aliases dL_dKdiag_* and dL_dKfu_*.
        """
        N, D, Mr, Mc = Y.shape[0], output_dim, Zr.shape[0], Zc.shape[0]

        uncertain_inputs_r = isinstance(Xr, VariationalPosterior)
        uncertain_inputs_c = isinstance(Xc, VariationalPosterior)

        grad_dict = self._init_grad_dict(N, D, Mr, Mc)

        beta = 1./likelihood.variance
        if len(beta) == 1:
            # Shared noise variance: use the same precision for every output.
            beta = np.zeros(D)+beta

        psi0_r, psi1_r, psi2_r = self.gatherPsiStat(kern_r, Xr, Zr, uncertain_inputs_r)
        psi0_c, psi1_c, psi2_c = self.gatherPsiStat(kern_c, Xc, Zc, uncertain_inputs_c)

        #======================================================================
        # Compute Common Components
        #======================================================================

        Kuu_r = kern_r.K(Zr).copy()
        diag.add(Kuu_r, self.const_jitter)
        Lr = jitchol(Kuu_r)

        Kuu_c = kern_c.K(Zc).copy()
        diag.add(Kuu_c, self.const_jitter)
        Lc = jitchol(Kuu_c)

        mu, Sr, Sc = qU_mean, qU_var_r, qU_var_c
        LSr = jitchol(Sr)
        LSc = jitchol(Sc)

        # Whitened variational moments: Lc^-1 M Lr^-T, Lc^-1 Sc Lc^-T, Lr^-1 Sr Lr^-T.
        LcInvMLrInvT = dtrtrs(Lc, dtrtrs(Lr, mu.T)[0].T)[0]
        LcInvLSc = dtrtrs(Lc, LSc)[0]
        LrInvLSr = dtrtrs(Lr, LSr)[0]
        LcInvScLcInvT = tdot(LcInvLSc)
        LrInvSrLrInvT = tdot(LrInvLSr)
        tr_LrInvSrLrInvT = np.square(LrInvLSr).sum()
        tr_LcInvScLcInvT = np.square(LcInvLSc).sum()

        mid_res = {
            'psi0_r': psi0_r,
            'psi1_r': psi1_r,
            'psi2_r': psi2_r,
            'psi0_c': psi0_c,
            'psi1_c': psi1_c,
            'psi2_c': psi2_c,
            'Lr': Lr,
            'Lc': Lc,
            'LcInvMLrInvT': LcInvMLrInvT,
            'LcInvScLcInvT': LcInvScLcInvT,
            'LrInvSrLrInvT': LrInvSrLrInvT,
        }

        #======================================================================
        # Compute log-likelihood
        #======================================================================

        # Data terms; each call also accumulates into grad_dict in place.
        logL = 0.
        for d in range(D):
            logL += self.inference_d(d, beta, Y, indexD, grad_dict, mid_res, uncertain_inputs_r, uncertain_inputs_c, Mr, Mc)

        # Negated KL terms: log-determinant ratios (via Cholesky diagonals),
        # the whitened mean norm, the trace term and the constant Mr*Mc/2.
        logL += (-Mc * (np.log(np.diag(Lr)).sum()-np.log(np.diag(LSr)).sum())
                 - Mr * (np.log(np.diag(Lc)).sum()-np.log(np.diag(LSc)).sum())
                 - np.square(LcInvMLrInvT).sum()/2. - tr_LrInvSrLrInvT * tr_LcInvScLcInvT/2. + Mr*Mc/2.)

        #======================================================================
        # Compute dL_dKuu
        #======================================================================

        tmp = tdot(LcInvMLrInvT)/2. + tr_LrInvSrLrInvT/2. * LcInvScLcInvT - Mr/2.*np.eye(Mc)

        dL_dKuu_c = backsub_both_sides(Lc, tmp, 'left')
        dL_dKuu_c += dL_dKuu_c.T
        dL_dKuu_c *= 0.5

        tmp = tdot(LcInvMLrInvT.T)/2. + tr_LcInvScLcInvT/2. * LrInvSrLrInvT - Mc/2.*np.eye(Mr)

        dL_dKuu_r = backsub_both_sides(Lr, tmp, 'left')
        dL_dKuu_r += dL_dKuu_r.T
        dL_dKuu_r *= 0.5

        #======================================================================
        # Compute dL_dqU
        #======================================================================

        tmp = - LcInvMLrInvT
        dL_dqU_mean = dtrtrs(Lc, dtrtrs(Lr, tmp.T, trans=1)[0].T, trans=1)[0]

        # tdot(LSInv.T) = LS^-T LS^-1 = S^-1, the gradient of the log|S| terms.
        LScInv = dtrtri(LSc)
        tmp = -tr_LrInvSrLrInvT/2.*np.eye(Mc)
        dL_dqU_var_c = backsub_both_sides(Lc, tmp, 'left') + tdot(LScInv.T) * Mr/2.

        LSrInv = dtrtri(LSr)
        tmp = -tr_LcInvScLcInvT/2.*np.eye(Mr)
        dL_dqU_var_r = backsub_both_sides(Lr, tmp, 'left') + tdot(LSrInv.T) * Mc/2.

        #======================================================================
        # Compute the Posterior distribution of inducing points p(u|Y)
        #======================================================================

        post = PosteriorMultioutput(LcInvMLrInvT=LcInvMLrInvT, LcInvScLcInvT=LcInvScLcInvT,
                                    LrInvSrLrInvT=LrInvSrLrInvT, Lr=Lr, Lc=Lc, kern_r=kern_r, Xr=Xr, Zr=Zr)

        #======================================================================
        # Accumulate KL gradients and expose kernel-gradient aliases
        #======================================================================

        grad_dict['dL_dqU_mean'] += dL_dqU_mean
        grad_dict['dL_dqU_var_c'] += dL_dqU_var_c
        grad_dict['dL_dqU_var_r'] += dL_dqU_var_r
        grad_dict['dL_dKuu_c'] += dL_dKuu_c
        grad_dict['dL_dKuu_r'] += dL_dKuu_r

        # With deterministic inputs psi0/psi1 are Kdiag/K (see gatherPsiStat),
        # so their gradients are also exposed under the kernel-gradient names.
        if not uncertain_inputs_c:
            grad_dict['dL_dKdiag_c'] = grad_dict['dL_dpsi0_c']
            grad_dict['dL_dKfu_c'] = grad_dict['dL_dpsi1_c']

        if not uncertain_inputs_r:
            grad_dict['dL_dKdiag_r'] = grad_dict['dL_dpsi0_r']
            grad_dict['dL_dKfu_r'] = grad_dict['dL_dpsi1_r']

        return post, logL, grad_dict