# Copyright (c) 2012-2014, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt)
import numpy as np
from ...util.linalg import jitchol, DSYR, dtrtrs, dtrtri, pdinv, dpotrs, tdot, symmetrify
from paramz import ObsAr
from . import ExactGaussianInference, VarDTC
from ...util import diag
from .posterior import PosteriorEP as Posterior
from ...likelihoods import Gaussian
from . import LatentFunctionInference

log_2_pi = np.log(2*np.pi)


# Four wrapper classes to help modularisation of the different EP versions
class marginalMoments(object):
    def __init__(self, num_data):
        self.Z_hat = np.empty(num_data, dtype=np.float64)
        self.mu_hat = np.empty(num_data, dtype=np.float64)
        self.sigma2_hat = np.empty(num_data, dtype=np.float64)


class cavityParams(object):
    def __init__(self, num_data):
        self.tau = np.empty(num_data, dtype=np.float64)
        self.v = np.empty(num_data, dtype=np.float64)

    def _update_i(self, eta, ga_approx, post_params, i):
        self.tau[i] = 1./post_params.Sigma_diag[i] - eta*ga_approx.tau[i]
        self.v[i] = post_params.mu[i]/post_params.Sigma_diag[i] - eta*ga_approx.v[i]

    def to_dict(self):
        """
        Convert the object into a json serializable dictionary.

        :return dict: json serializable dictionary containing the needed information to instantiate the object
        """
        return {"tau": self.tau.tolist(), "v": self.v.tolist()}

    @staticmethod
    def from_dict(input_dict):
        c = cavityParams(len(input_dict["tau"]))
        c.tau = np.array(input_dict["tau"])
        c.v = np.array(input_dict["v"])
        return c
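
# Note on the parameterisation used throughout this file: sites and cavities are
# stored in natural (precision) parameters, tau = 1/sigma^2 and v = mu/sigma^2.
# The cavity for point i removes a fraction eta of site i from the current
# posterior marginal, i.e.
#     tau_cav_i = 1/Sigma_ii - eta * tau_site_i
#     v_cav_i   = mu_i/Sigma_ii - eta * v_site_i
# which is what cavityParams._update_i above computes.
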
class gaussianApproximation(object):
    def __init__(self, v, tau):
        self.tau = tau
        self.v = v

    def _update_i(self, eta, delta, post_params, marg_moments, i):
        # Site parameters update
        delta_tau = delta/eta*(1./marg_moments.sigma2_hat[i] - 1./post_params.Sigma_diag[i])
        delta_v = delta/eta*(marg_moments.mu_hat[i]/marg_moments.sigma2_hat[i] - post_params.mu[i]/post_params.Sigma_diag[i])
        tau_tilde_prev = self.tau[i]
        self.tau[i] += delta_tau

        # Enforce positivity of tau_tilde. Even though this is guaranteed for log-concave sites, it is still possible
        # to get negative values due to numerical errors. Moreover, tau_tilde should be positive in order to
        # update the marginal likelihood without running into instability issues.
        if self.tau[i] < np.finfo(float).eps:
            self.tau[i] = np.finfo(float).eps
            delta_tau = self.tau[i] - tau_tilde_prev

        self.v[i] += delta_v

        return (delta_tau, delta_v)

    def to_dict(self):
        """
        Convert the object into a json serializable dictionary.

        :return dict: json serializable dictionary containing the needed information to instantiate the object
        """
        return {"tau": self.tau.tolist(), "v": self.v.tolist()}

    @staticmethod
    def from_dict(input_dict):
        return gaussianApproximation(np.array(input_dict["v"]), np.array(input_dict["tau"]))
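
# gaussianApproximation holds the site (tilde) factors. _update_i performs a
# damped, fractional site update: with power eta and damping delta,
#     delta_tau = (delta/eta) * (1/sigma2_hat_i - 1/Sigma_ii)
#     delta_v   = (delta/eta) * (mu_hat_i/sigma2_hat_i - mu_i/Sigma_ii)
# so eta = 1, delta = 1 recovers standard EP, eta != 1 gives fractional (power)
# EP and delta < 1 damps the updates for extra numerical stability.
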
class posteriorParamsBase(object):
    def __init__(self, mu, Sigma_diag):
        self.mu = mu
        self.Sigma_diag = Sigma_diag

    def _update_rank1(self, *arg):
        pass

    def _recompute(self, *arg):
        pass


class posteriorParams(posteriorParamsBase):
    def __init__(self, mu, Sigma, L=None):
        self.Sigma = Sigma
        self.L = L
        Sigma_diag = np.diag(self.Sigma)
        super(posteriorParams, self).__init__(mu, Sigma_diag)

    def _update_rank1(self, delta_tau, delta_v, ga_approx, i):
        si = self.Sigma[i, :].copy()
        ci = delta_tau/(1. + delta_tau*si[i])
        self.mu = self.mu - (ci*(self.mu[i] + si[i]*delta_v) - delta_v) * si
        DSYR(self.Sigma, si, -ci)

    def to_dict(self):
        """
        Convert the object into a json serializable dictionary.

        :return dict: json serializable dictionary containing the needed information to instantiate the object
        """
        # TODO: Implement a more memory efficient variant
        if self.L is None:
            return {"mu": self.mu.tolist(), "Sigma": self.Sigma.tolist()}
        else:
            return {"mu": self.mu.tolist(), "Sigma": self.Sigma.tolist(), "L": self.L.tolist()}

    @staticmethod
    def from_dict(input_dict):
        if "L" in input_dict:
            return posteriorParams(np.array(input_dict["mu"]), np.array(input_dict["Sigma"]), np.array(input_dict["L"]))
        else:
            return posteriorParams(np.array(input_dict["mu"]), np.array(input_dict["Sigma"]))

    @staticmethod
    def _recompute(mean_prior, K, ga_approx):
        num_data = len(ga_approx.tau)
        tau_tilde_root = np.sqrt(ga_approx.tau)
        Sroot_tilde_K = tau_tilde_root[:, None] * K
        B = np.eye(num_data) + Sroot_tilde_K * tau_tilde_root[None, :]
        L = jitchol(B)
        V, _ = dtrtrs(L, Sroot_tilde_K, lower=1)
        Sigma = K - np.dot(V.T, V)  # K - K S^(1/2) B^(-1) S^(1/2) K = (K^(-1) + S_tilde)^(-1)

        aux_alpha, _ = dpotrs(L, tau_tilde_root * (np.dot(K, ga_approx.v) + mean_prior), lower=1)
        alpha = ga_approx.v - tau_tilde_root * aux_alpha  # (K + Sigma_tilde)^(-1) (mu_tilde - mu_p)
        mu = np.dot(K, alpha) + mean_prior

        return posteriorParams(mu=mu, Sigma=Sigma, L=L)
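
# posteriorParams._recompute follows the standard numerically stable EP
# parameterisation (cf. Rasmussen & Williams 2006): with site precisions
# S = diag(tau_tilde) and prior mean m,
#     B     = I + S^(1/2) K S^(1/2)            (Cholesky factor L via jitchol)
#     Sigma = K - K S^(1/2) B^(-1) S^(1/2) K
#     alpha = v_tilde - S^(1/2) B^(-1) S^(1/2) (K v_tilde + m)
#     mu    = K alpha + m
# so only the well-conditioned matrix B is ever factorised.
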
class posteriorParamsDTC(posteriorParamsBase):
    def __init__(self, mu, Sigma_diag):
        super(posteriorParamsDTC, self).__init__(mu, Sigma_diag)

    def _update_rank1(self, LLT, Kmn, delta_v, delta_tau, i):
        # DSYR(Sigma, Sigma[:,i].copy(), -delta_tau/(1.+ delta_tau*Sigma[i,i]))
        DSYR(LLT, Kmn[:, i].copy(), delta_tau)
        L = jitchol(LLT)
        V, info = dtrtrs(L, Kmn, lower=1)
        self.Sigma_diag = np.maximum(np.sum(V*V, -2), np.finfo(float).eps)  # diag(K_nm (L L^T)^(-1) K_mn)
        si = np.sum(V.T*V[:, i], -1)  # (V^T V)[:,i]
        self.mu += (delta_v - delta_tau*self.mu[i])*si
        # mu = np.dot(Sigma, v_tilde)

    def to_dict(self):
        """
        Convert the object into a json serializable dictionary.

        :return dict: json serializable dictionary containing the needed information to instantiate the object
        """
        return {"mu": self.mu.tolist(), "Sigma_diag": self.Sigma_diag.tolist()}

    @staticmethod
    def from_dict(input_dict):
        return posteriorParamsDTC(np.array(input_dict["mu"]), np.array(input_dict["Sigma_diag"]))

    @staticmethod
    def _recompute(LLT0, Kmn, ga_approx):
        LLT = LLT0 + np.dot(Kmn*ga_approx.tau[None, :], Kmn.T)
        L = jitchol(LLT)
        V, _ = dtrtrs(L, Kmn, lower=1)
        # Sigma_diag = np.sum(V*V,-2)
        # Knmv_tilde = np.dot(Kmn,v_tilde)
        # mu = np.dot(V2.T,Knmv_tilde)
        Sigma = np.dot(V.T, V)
        mu = np.dot(Sigma, ga_approx.v)
        Sigma_diag = np.diag(Sigma).copy()
        return posteriorParamsDTC(mu, Sigma_diag), LLT
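
# In the sparse (DTC) case the full n x n Sigma is never stored. The posterior
# is kept through the m x m matrix
#     LLT = Kmm + Kmn diag(tau_tilde) Knm
# (updated in place by rank-1 DSYR calls and refactorised with jitchol), from
# which
#     Sigma_diag = diag(Knm LLT^(-1) Kmn)   and   mu = Knm LLT^(-1) Kmn v_tilde
# are recovered, as in posteriorParamsDTC._recompute above.
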
class EPBase(object):
    def __init__(self, epsilon=1e-6, eta=1., delta=1., always_reset=False, max_iters=np.inf, ep_mode="alternated", parallel_updates=False, loading=False):
        """
        The expectation-propagation algorithm.
        For nomenclature see Rasmussen & Williams 2006.

        :param epsilon: convergence criterion; maximum squared difference allowed between mean updates to stop iterations
        :type epsilon: float
        :param eta: parameter for fractional EP updates.
        :type eta: float64
        :param delta: damping factor for the EP updates.
        :type delta: float64
        :param always_reset: setting to always reset the approximation at the beginning of every inference call.
        :type always_reset: boolean
        :param max_iters: maximum number of EP iterations.
        :type max_iters: int
        :param ep_mode: either "nested" (EP is run every time the hyperparameters change) or "alternated" (EP is run at the beginning and the hyperparameters are then optimised with the approximation fixed).
        :type ep_mode: string
        :param parallel_updates: if True, the site parameters are updated in parallel.
        :type parallel_updates: boolean
        :param loading: if True, prevents the EP parameters from changing. Hack used when loading a serialized model.
        :type loading: boolean
        """
        super(EPBase, self).__init__()

        self.always_reset = always_reset
        self.epsilon, self.eta, self.delta, self.max_iters = epsilon, eta, delta, max_iters
        self.ep_mode = ep_mode
        self.parallel_updates = parallel_updates
        # FIXME: Hack for serialization. If True, prevents the EP parameters from changing when loading a serialized model
        self.loading = loading
        self.reset()

    def reset(self):
        self.ga_approx_old = None
        self._ep_approximation = None

    def on_optimization_start(self):
        self._ep_approximation = None

    def on_optimization_end(self):
        # TODO: update the approximation at the end as well? Maybe even with a switch?
        pass

    def _stop_criteria(self, ga_approx):
        tau_diff = np.mean(np.square(ga_approx.tau - self.ga_approx_old.tau))
        v_diff = np.mean(np.square(ga_approx.v - self.ga_approx_old.v))
        return ((tau_diff < self.epsilon) and (v_diff < self.epsilon))

    def __setstate__(self, state):
        super(EPBase, self).__setstate__(state[0])
        self.epsilon, self.eta, self.delta = state[1]
        self.reset()

    def __getstate__(self):
        return [super(EPBase, self).__getstate__(), [self.epsilon, self.eta, self.delta]]

    def _save_to_input_dict(self):
        input_dict = super(EPBase, self)._save_to_input_dict()
        input_dict["epsilon"] = self.epsilon
        input_dict["eta"] = self.eta
        input_dict["delta"] = self.delta
        input_dict["always_reset"] = self.always_reset
        input_dict["max_iters"] = self.max_iters
        input_dict["ep_mode"] = self.ep_mode
        input_dict["parallel_updates"] = self.parallel_updates
        input_dict["loading"] = True
        return input_dict
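
# A minimal usage sketch (assuming the usual GPy model constructors; exact
# argument names may differ between versions): EP is selected as the inference
# method of a GP model with a non-Gaussian likelihood, e.g.
#
#   import GPy
#   lik = GPy.likelihoods.Bernoulli()
#   m = GPy.core.GP(X, Y, kernel=GPy.kern.RBF(X.shape[1]), likelihood=lik,
#                   inference_method=GPy.inference.latent_function_inference.expectation_propagation.EP())
#   m.optimize()
#
# EPDTC further down is the sparse counterpart built on VarDTC and is used with
# sparse GP models that carry inducing inputs Z.
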
class EP(EPBase, ExactGaussianInference):
    def inference(self, kern, X, likelihood, Y, mean_function=None, Y_metadata=None, precision=None, K=None):
        if self.always_reset and not self.loading:
            self.reset()

        num_data, output_dim = Y.shape
        assert output_dim == 1, "ep in 1D only (for now!)"

        if mean_function is None:
            mean_prior = np.zeros(X.shape[0])
        else:
            mean_prior = mean_function.f(X).flatten()

        if K is None:
            K = kern.K(X)

        if self.ep_mode == "nested" and not self.loading:
            # Force EP at each step of the optimization
            self._ep_approximation = None
            post_params, ga_approx, cav_params, log_Z_tilde = self._ep_approximation = self.expectation_propagation(mean_prior, K, Y, likelihood, Y_metadata)
        elif self.ep_mode == "alternated" or self.loading:
            if getattr(self, '_ep_approximation', None) is None:
                # if we don't yet have the results of running EP, run EP and store the computed factors in self._ep_approximation
                post_params, ga_approx, cav_params, log_Z_tilde = self._ep_approximation = self.expectation_propagation(mean_prior, K, Y, likelihood, Y_metadata)
            else:
                # if we've already run EP, just use the existing approximation stored in self._ep_approximation
                post_params, ga_approx, cav_params, log_Z_tilde = self._ep_approximation
        else:
            raise ValueError("ep_mode value not valid")

        self.loading = False

        return self._inference(Y, mean_prior, K, ga_approx, cav_params, likelihood, Y_metadata=Y_metadata, Z_tilde=log_Z_tilde)

    def expectation_propagation(self, mean_prior, K, Y, likelihood, Y_metadata):

        num_data, data_dim = Y.shape
        assert data_dim == 1, "This EP method only works for 1D outputs"

        # Computations are quicker if we work with plain numpy arrays rather than ObsAr arrays
        Y = Y.values.copy()

        # Initial values - Marginal moments, cavity params, Gaussian approximation params and posterior params
        marg_moments = marginalMoments(num_data)
        cav_params = cavityParams(num_data)
        ga_approx, post_params = self._init_approximations(mean_prior, K, num_data)

        # Approximation
        stop = False
        iterations = 0
        while not stop and (iterations < self.max_iters):
            self._local_updates(num_data, cav_params, post_params, marg_moments, ga_approx, likelihood, Y, Y_metadata)

            # (re)compute Sigma and mu using a full Cholesky decomposition
            post_params = posteriorParams._recompute(mean_prior, K, ga_approx)

            # monitor convergence
            if iterations > 0:
                stop = self._stop_criteria(ga_approx)
            self.ga_approx_old = gaussianApproximation(ga_approx.v.copy(), ga_approx.tau.copy())
            iterations += 1

        log_Z_tilde = self._log_Z_tilde(marg_moments, ga_approx, cav_params)

        return (post_params, ga_approx, cav_params, log_Z_tilde)

    def _init_approximations(self, mean_prior, K, num_data):
        # Initial values - Gaussian factors
        # Initial values - Posterior distribution parameters: q(f|X,Y) = N(f|mu,Sigma)
        if self.ga_approx_old is None:
            v_tilde, tau_tilde = np.zeros((2, num_data))
            ga_approx = gaussianApproximation(v_tilde, tau_tilde)
            Sigma = K.copy()
            diag.add(Sigma, 1e-7)
            mu = mean_prior
            post_params = posteriorParams(mu, Sigma)
        else:
            assert self.ga_approx_old.v.size == num_data, "data size mismatch: did you change the data? try resetting!"
            ga_approx = gaussianApproximation(self.ga_approx_old.v, self.ga_approx_old.tau)
            post_params = posteriorParams._recompute(mean_prior, K, ga_approx)
            diag.add(post_params.Sigma, 1e-7)
        # TODO: Check the log-marginal under both conditions and choose the best one
        return (ga_approx, post_params)

    def _local_updates(self, num_data, cav_params, post_params, marg_moments, ga_approx, likelihood, Y, Y_metadata, update_order=None):
        if update_order is None:
            update_order = np.random.permutation(num_data)
        for i in update_order:
            # Cavity distribution parameters
            cav_params._update_i(self.eta, ga_approx, post_params, i)

            if Y_metadata is not None:
                # Pick out the relevant metadata for Yi
                Y_metadata_i = {}
                for key in Y_metadata.keys():
                    Y_metadata_i[key] = Y_metadata[key][i, :]
            else:
                Y_metadata_i = None

            # Marginal moments
            marg_moments.Z_hat[i], marg_moments.mu_hat[i], marg_moments.sigma2_hat[i] = likelihood.moments_match_ep(Y[i], cav_params.tau[i], cav_params.v[i], Y_metadata_i=Y_metadata_i)

            # Site parameters update
            delta_tau, delta_v = ga_approx._update_i(self.eta, self.delta, post_params, marg_moments, i)

            if self.parallel_updates == False:
                post_params._update_rank1(delta_tau, delta_v, ga_approx, i)

    def _log_Z_tilde(self, marg_moments, ga_approx, cav_params):
        # Z_tilde after removing the terms that can become infinite when tau_tilde is close to zero.
        # These terms cancel with the corresponding terms in the marginal loglikelihood.
        return np.sum((
            np.log(marg_moments.Z_hat)
            + 0.5*np.log(2*np.pi) + 0.5*np.log(1 + ga_approx.tau/cav_params.tau)
            - 0.5 * ((ga_approx.v)**2 * 1./(cav_params.tau + ga_approx.tau))
            + 0.5*(cav_params.v * (((ga_approx.tau/cav_params.tau) * cav_params.v - 2.0 * ga_approx.v) * 1./(cav_params.tau + ga_approx.tau)))
        ))

    def _ep_marginal(self, mean_prior, K, ga_approx, Z_tilde):
        post_params = posteriorParams._recompute(mean_prior, K, ga_approx)

        # Gaussian log marginal excluding terms that can go to infinity due to arbitrarily small tau_tilde.
        # These terms cancel out with the terms excluded from Z_tilde.
        B_logdet = np.sum(2.0*np.log(np.diag(post_params.L)))
        S_mean_prior = ga_approx.tau * mean_prior
        v_centered = ga_approx.v - S_mean_prior
        log_marginal = 0.5*(
            -len(ga_approx.tau) * log_2_pi - B_logdet
            + np.sum(v_centered * np.dot(post_params.Sigma, v_centered))
            - np.dot(mean_prior, (S_mean_prior - 2*ga_approx.v))
        )
        log_marginal += Z_tilde

        return log_marginal, post_params

    def _inference(self, Y, mean_prior, K, ga_approx, cav_params, likelihood, Z_tilde, Y_metadata=None):
        log_marginal, post_params = self._ep_marginal(mean_prior, K, ga_approx, Z_tilde)

        tau_tilde_root = np.sqrt(ga_approx.tau)
        Sroot_tilde_K = tau_tilde_root[:, None] * K

        aux_alpha, _ = dpotrs(post_params.L, tau_tilde_root * (np.dot(K, ga_approx.v) + mean_prior), lower=1)
        alpha = (ga_approx.v - tau_tilde_root * aux_alpha)[:, None]  # (K + Sigma_tilde)^(-1) (mu_tilde - mu_p)

        LWi, _ = dtrtrs(post_params.L, np.diag(tau_tilde_root), lower=1)
        Wi = np.dot(LWi.T, LWi)
        symmetrify(Wi)  # (K + Sigma_tilde)^(-1)

        dL_dK = 0.5 * (tdot(alpha) - Wi)
        dL_dthetaL = likelihood.ep_gradients(Y, cav_params.tau, cav_params.v, np.diag(dL_dK), Y_metadata=Y_metadata, quad_mode='gh')
        return Posterior(woodbury_inv=Wi, woodbury_vector=alpha, K=K), log_marginal, {'dL_dK': dL_dK, 'dL_dthetaL': dL_dthetaL, 'dL_dm': alpha}

    def to_dict(self):
        """
        Convert the object into a json serializable dictionary.

        Note: It uses the private method _save_to_input_dict of the parent.

        :return dict: json serializable dictionary containing the needed information to instantiate the object
        """
        input_dict = super(EP, self)._save_to_input_dict()
        input_dict["class"] = "GPy.inference.latent_function_inference.expectation_propagation.EP"
        if self.ga_approx_old is not None:
            input_dict["ga_approx_old"] = self.ga_approx_old.to_dict()
        if self._ep_approximation is not None:
            input_dict["_ep_approximation"] = {}
            input_dict["_ep_approximation"]["post_params"] = self._ep_approximation[0].to_dict()
            input_dict["_ep_approximation"]["ga_approx"] = self._ep_approximation[1].to_dict()
            input_dict["_ep_approximation"]["cav_params"] = self._ep_approximation[2].to_dict()
            input_dict["_ep_approximation"]["log_Z_tilde"] = self._ep_approximation[3].tolist()

        return input_dict

    @staticmethod
    def _build_from_input_dict(inference_class, input_dict):
        ga_approx_old = input_dict.pop('ga_approx_old', None)
        if ga_approx_old is not None:
            ga_approx_old = gaussianApproximation.from_dict(ga_approx_old)
        _ep_approximation_dict = input_dict.pop('_ep_approximation', None)
        _ep_approximation = None
        if _ep_approximation_dict is not None:
            _ep_approximation = []
            _ep_approximation.append(posteriorParams.from_dict(_ep_approximation_dict["post_params"]))
            _ep_approximation.append(gaussianApproximation.from_dict(_ep_approximation_dict["ga_approx"]))
            _ep_approximation.append(cavityParams.from_dict(_ep_approximation_dict["cav_params"]))
            _ep_approximation.append(np.array(_ep_approximation_dict["log_Z_tilde"]))
        ee = EP(**input_dict)
        ee.ga_approx_old = ga_approx_old
        ee._ep_approximation = _ep_approximation
        return ee
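
# EPDTC combines the EP site updates above with the variational DTC sparse
# approximation: the O(n^3) posterior recomputation is replaced by operations
# on the inducing-point matrices Kmm and Kmn, and the converged sites are
# handed to VarDTC.inference as heteroscedastic pseudo-observations.
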
class EPDTC(EPBase, VarDTC):
    def inference(self, kern, X, Z, likelihood, Y, mean_function=None, Y_metadata=None, Lm=None, dL_dKmm=None, psi0=None, psi1=None, psi2=None):
        if self.always_reset and not self.loading:
            self.reset()

        num_data, output_dim = Y.shape
        assert output_dim == 1, "ep in 1D only (for now!)"

        if Lm is None:
            Kmm = kern.K(Z)
            Lm = jitchol(Kmm)

        if psi1 is None:
            try:
                Kmn = kern.K(Z, X)
            except TypeError:
                Kmn = kern.psi1(Z, X).T
        else:
            Kmn = psi1.T

        if self.ep_mode == "nested" and not self.loading:
            # Force EP at each step of the optimization
            self._ep_approximation = None
            post_params, ga_approx, log_Z_tilde = self._ep_approximation = self.expectation_propagation(Kmm, Kmn, Y, likelihood, Y_metadata)
        elif self.ep_mode == "alternated" or self.loading:
            if getattr(self, '_ep_approximation', None) is None:
                # if we don't yet have the results of running EP, run EP and store the computed factors in self._ep_approximation
                post_params, ga_approx, log_Z_tilde = self._ep_approximation = self.expectation_propagation(Kmm, Kmn, Y, likelihood, Y_metadata)
            else:
                # if we've already run EP, just use the existing approximation stored in self._ep_approximation
                post_params, ga_approx, log_Z_tilde = self._ep_approximation
        else:
            raise ValueError("ep_mode value not valid")

        self.loading = False

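        # The converged sites define Gaussian pseudo-observations with mean
        # mu_tilde = v_tilde / tau_tilde and per-point precision tau_tilde;
        # these are passed to the standard VarDTC inference together with the
        # EP correction term log_Z_tilde.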
        mu_tilde = ga_approx.v / ga_approx.tau.astype(float)

        return super(EPDTC, self).inference(kern, X, Z, likelihood, ObsAr(mu_tilde[:, None]),
                                            mean_function=mean_function,
                                            Y_metadata=Y_metadata,
                                            precision=ga_approx.tau,
                                            Lm=Lm, dL_dKmm=dL_dKmm,
                                            psi0=psi0, psi1=psi1, psi2=psi2, Z_tilde=log_Z_tilde)

    def expectation_propagation(self, Kmm, Kmn, Y, likelihood, Y_metadata):

        num_data, output_dim = Y.shape
        assert output_dim == 1, "This EP method only works for 1D outputs"

        # Computations are quicker if we work with plain numpy arrays rather than ObsAr arrays
        Y = Y.values.copy()

        # Initial values - Marginal moments, cavity params, Gaussian approximation params and posterior params
        marg_moments = marginalMoments(num_data)
        cav_params = cavityParams(num_data)
        ga_approx, post_params, LLT0, LLT = self._init_approximations(Kmm, Kmn, num_data)

        # Approximation
        stop = False
        iterations = 0
        while not stop and (iterations < self.max_iters):
            self._local_updates(num_data, LLT0, LLT, Kmn, cav_params, post_params, marg_moments, ga_approx, likelihood, Y, Y_metadata)

            # (re)compute Sigma, Sigma_diag and mu using a full Cholesky decomposition
            post_params, LLT = posteriorParamsDTC._recompute(LLT0, Kmn, ga_approx)
            post_params.Sigma_diag = np.maximum(post_params.Sigma_diag, np.finfo(float).eps)

            # monitor convergence
            if iterations > 0:
                stop = self._stop_criteria(ga_approx)
            self.ga_approx_old = gaussianApproximation(ga_approx.v.copy(), ga_approx.tau.copy())
            iterations += 1

        log_Z_tilde = self._log_Z_tilde(marg_moments, ga_approx, cav_params)

        return post_params, ga_approx, log_Z_tilde

    def _log_Z_tilde(self, marg_moments, ga_approx, cav_params):
        mu_tilde = ga_approx.v/ga_approx.tau
        mu_cav = cav_params.v/cav_params.tau
        sigma2_sigma2tilde = 1./cav_params.tau + 1./ga_approx.tau

        return np.sum((np.log(marg_moments.Z_hat) + 0.5*np.log(2*np.pi) + 0.5*np.log(sigma2_sigma2tilde)
                       + 0.5*((mu_cav - mu_tilde)**2) / (sigma2_sigma2tilde)))

    def _init_approximations(self, Kmm, Kmn, num_data):
        # Initial values - Gaussian factors
        # Initial values - Posterior distribution parameters: q(f|X,Y) = N(f|mu,Sigma)
        LLT0 = Kmm.copy()
        Lm = jitchol(LLT0)  # K_mm = L_m L_m^T
        Vm, info = dtrtrs(Lm, Kmn, lower=1)
        # Lmi = dtrtri(Lm)
        # Kmmi = np.dot(Lmi.T,Lmi)
        # KmmiKmn = np.dot(Kmmi,Kmn)
        # Qnn_diag = np.sum(Kmn*KmmiKmn,-2)
        Qnn_diag = np.sum(Vm*Vm, -2)  # diag(Knm Kmm^(-1) Kmn)
        # diag.add(LLT0, 1e-8)
        if self.ga_approx_old is None:
            # Initial values - Posterior distribution parameters: q(f|X,Y) = N(f|mu,Sigma)
            LLT = LLT0.copy()  # Sigma = K.copy()
            mu = np.zeros(num_data)
            Sigma_diag = Qnn_diag.copy() + 1e-8
            v_tilde, tau_tilde = np.zeros((2, num_data))
            ga_approx = gaussianApproximation(v_tilde, tau_tilde)
            post_params = posteriorParamsDTC(mu, Sigma_diag)
        else:
            assert self.ga_approx_old.v.size == num_data, "data size mismatch: did you change the data? try resetting!"
            ga_approx = gaussianApproximation(self.ga_approx_old.v, self.ga_approx_old.tau)
            post_params, LLT = posteriorParamsDTC._recompute(LLT0, Kmn, ga_approx)
            post_params.Sigma_diag += 1e-8

        # TODO: Check the log-marginal under both conditions and choose the best one

        return (ga_approx, post_params, LLT0, LLT)

    def _local_updates(self, num_data, LLT0, LLT, Kmn, cav_params, post_params, marg_moments, ga_approx, likelihood, Y, Y_metadata, update_order=None):
        if update_order is None:
            update_order = np.random.permutation(num_data)
        for i in update_order:
            # Cavity distribution parameters
            cav_params._update_i(self.eta, ga_approx, post_params, i)

            if Y_metadata is not None:
                # Pick out the relevant metadata for Yi
                Y_metadata_i = {}
                for key in Y_metadata.keys():
                    Y_metadata_i[key] = Y_metadata[key][i, :]
            else:
                Y_metadata_i = None

            # Marginal moments
            marg_moments.Z_hat[i], marg_moments.mu_hat[i], marg_moments.sigma2_hat[i] = likelihood.moments_match_ep(Y[i], cav_params.tau[i], cav_params.v[i], Y_metadata_i=Y_metadata_i)

            # Site parameters update
            delta_tau, delta_v = ga_approx._update_i(self.eta, self.delta, post_params, marg_moments, i)

            # Posterior distribution parameters update
            if self.parallel_updates == False:
                post_params._update_rank1(LLT, Kmn, delta_v, delta_tau, i)
    def to_dict(self):
        """
        Convert the object into a json serializable dictionary.

        Note: It uses the private method _save_to_input_dict of the parent.

        :return dict: json serializable dictionary containing the needed information to instantiate the object
        """
        input_dict = super(EPDTC, self)._save_to_input_dict()
        input_dict["class"] = "GPy.inference.latent_function_inference.expectation_propagation.EPDTC"
        if self.ga_approx_old is not None:
            input_dict["ga_approx_old"] = self.ga_approx_old.to_dict()
        if self._ep_approximation is not None:
            input_dict["_ep_approximation"] = {}
            input_dict["_ep_approximation"]["post_params"] = self._ep_approximation[0].to_dict()
            input_dict["_ep_approximation"]["ga_approx"] = self._ep_approximation[1].to_dict()
            input_dict["_ep_approximation"]["log_Z_tilde"] = self._ep_approximation[2]

        return input_dict

    @staticmethod
    def _build_from_input_dict(inference_class, input_dict):
        ga_approx_old = input_dict.pop('ga_approx_old', None)
        if ga_approx_old is not None:
            ga_approx_old = gaussianApproximation.from_dict(ga_approx_old)
        _ep_approximation_dict = input_dict.pop('_ep_approximation', None)
        _ep_approximation = None
        if _ep_approximation_dict is not None:
            _ep_approximation = []
            _ep_approximation.append(posteriorParamsDTC.from_dict(_ep_approximation_dict["post_params"]))
            _ep_approximation.append(gaussianApproximation.from_dict(_ep_approximation_dict["ga_approx"]))
            _ep_approximation.append(_ep_approximation_dict["log_Z_tilde"])
        ee = EPDTC(**input_dict)
        ee.ga_approx_old = ga_approx_old
        ee._ep_approximation = _ep_approximation
        return ee