1# Copyright (c) 2012-2014, GPy authors (see AUTHORS.txt).
2# Licensed under the BSD 3-clause license (see LICENSE.txt)
3import numpy as np
4from ...util.linalg import jitchol, DSYR, dtrtrs, dtrtri, pdinv, dpotrs, tdot, symmetrify
5from paramz import ObsAr
6from . import ExactGaussianInference, VarDTC
7from ...util import diag
8from .posterior import PosteriorEP as Posterior
9from ...likelihoods import Gaussian
10from . import LatentFunctionInference
11
# Precomputed log(2*pi); used for the Gaussian constant term of the EP log marginal.
log_2_pi = np.log(2 * np.pi)
13
14
# Wrapper classes that help modularise the different EP versions (EP and EPDTC).
class marginalMoments(object):
    """
    Storage for the per-site moments returned by the likelihood's EP moment matching.

    Entry ``i`` of ``Z_hat``, ``mu_hat`` and ``sigma2_hat`` holds the three values
    returned by ``likelihood.moments_match_ep`` for site ``i``. The arrays are
    allocated uninitialised and are filled site by site during the EP sweeps.

    :param num_data: number of data points (EP sites).
    """
    def __init__(self, num_data):
        shape = (num_data,)
        self.Z_hat = np.empty(shape, dtype=np.float64)
        self.mu_hat = np.empty(shape, dtype=np.float64)
        self.sigma2_hat = np.empty(shape, dtype=np.float64)
21
22
class cavityParams(object):
    """
    Natural parameters of the per-site cavity distributions.

    ``tau[i]`` is the cavity precision and ``v[i]`` the cavity
    precision-times-mean of site ``i``; both are float64 arrays of length
    ``num_data``.

    :param num_data: number of data points (EP sites).
    """
    def __init__(self, num_data):
        self.tau = np.empty(num_data, dtype=np.float64)
        self.v = np.empty(num_data, dtype=np.float64)

    def _update_i(self, eta, ga_approx, post_params, i):
        """
        Recompute the cavity of site ``i`` by removing a fraction ``eta`` of the
        site approximation from the current posterior marginal.

        :param eta: fractional EP power (1 gives standard EP).
        :param ga_approx: site approximation exposing ``tau`` and ``v`` arrays.
        :param post_params: posterior exposing ``mu`` and ``Sigma_diag`` arrays.
        :param i: index of the site to update.
        """
        self.tau[i] = 1./post_params.Sigma_diag[i] - eta*ga_approx.tau[i]
        self.v[i] = post_params.mu[i]/post_params.Sigma_diag[i] - eta*ga_approx.v[i]

    def to_dict(self):
        """
        Convert the object into a json serializable dictionary.

        :return dict: dictionary holding the ``tau`` and ``v`` arrays as lists.
        """
        return {"tau": self.tau.tolist(), "v": self.v.tolist()}

    @staticmethod
    def from_dict(input_dict):
        """
        Rebuild a :class:`cavityParams` from the output of :meth:`to_dict`.

        The arrays are forced to float64 (as in ``__init__``) so that later
        in-place updates are not truncated when the lists contain integers.

        :param input_dict: dictionary with ``"tau"`` and ``"v"`` sequences.
        :return cavityParams: the restored object.
        """
        c = cavityParams(0)
        c.tau = np.array(input_dict["tau"], dtype=np.float64)
        c.v = np.array(input_dict["v"], dtype=np.float64)
        return c
46
47
class gaussianApproximation(object):
    """
    Natural parameters of the Gaussian site approximations.

    :param v: site precision-times-mean values, one per data point.
    :param tau: site precisions, one per data point.

    The arrays are stored by reference and modified in place by :meth:`_update_i`.
    """
    def __init__(self, v, tau):
        self.tau = tau
        self.v = v

    def _update_i(self, eta, delta, post_params, marg_moments, i):
        """
        Update the parameters of site ``i`` from its matched marginal moments.

        :param eta: fractional EP power.
        :param delta: damping factor applied to the update.
        :param post_params: posterior exposing ``mu`` and ``Sigma_diag`` arrays.
        :param marg_moments: matched moments exposing ``mu_hat`` and ``sigma2_hat``.
        :param i: index of the site to update.
        :return: tuple ``(delta_tau, delta_v)`` with the changes actually applied.
        """
        # Damped, fractional change of the site natural parameters.
        delta_tau = delta/eta*(1./marg_moments.sigma2_hat[i] - 1./post_params.Sigma_diag[i])
        delta_v = delta/eta*(marg_moments.mu_hat[i]/marg_moments.sigma2_hat[i] - post_params.mu[i]/post_params.Sigma_diag[i])
        tau_tilde_prev = self.tau[i]
        self.tau[i] += delta_tau

        # Enforce positivity of tau_tilde. Even though this is guaranteed for log-concave sites, it is still
        # possible to get negative values due to numerical errors. Moreover, tau_tilde must be positive in
        # order to update the marginal likelihood without running into instability issues.
        if self.tau[i] < np.finfo(float).eps:
            self.tau[i] = np.finfo(float).eps
            # Report the change that was actually applied after clamping.
            delta_tau = self.tau[i] - tau_tilde_prev

        self.v[i] += delta_v

        return (delta_tau, delta_v)

    def to_dict(self):
        """
        Convert the object into a json serializable dictionary.

        :return dict: dictionary holding the ``tau`` and ``v`` arrays as lists.
        """
        return {"tau": self.tau.tolist(), "v": self.v.tolist()}

    @staticmethod
    def from_dict(input_dict):
        """
        Rebuild a :class:`gaussianApproximation` from the output of :meth:`to_dict`.

        The arrays are forced to float64 because :meth:`_update_i` updates them
        in place; an integer array would silently truncate those updates.

        :param input_dict: dictionary with ``"v"`` and ``"tau"`` sequences.
        :return gaussianApproximation: the restored object.
        """
        return gaussianApproximation(np.array(input_dict["v"], dtype=np.float64),
                                     np.array(input_dict["tau"], dtype=np.float64))
82
83
class posteriorParamsBase(object):
    """
    Minimal container for the posterior marginals used by EP.

    :param mu: posterior means.
    :param Sigma_diag: posterior marginal variances.
    """
    def __init__(self, mu, Sigma_diag):
        self.Sigma_diag = Sigma_diag
        self.mu = mu

    def _update_rank1(self, *arg):
        """Hook for the per-site posterior update; a no-op here."""
        return None

    def _recompute(self, *arg):
        """Hook for the full posterior recomputation; a no-op here."""
        return None
93
class posteriorParams(posteriorParamsBase):
    """
    Full-covariance Gaussian posterior q(f) = N(mu, Sigma) used by exact EP.

    :param mu: posterior mean vector.
    :param Sigma: posterior covariance matrix, modified in place by :meth:`_update_rank1`.
    :param L: optional Cholesky factor of B = I + S^(1/2) K S^(1/2) as built by :meth:`_recompute`.
    """
    def __init__(self, mu, Sigma, L=None):
        self.Sigma = Sigma
        self.L = L
        # np.diag of a 2-D array returns a view of its diagonal, so Sigma_diag
        # follows the in-place rank-one updates applied to Sigma.
        super(posteriorParams, self).__init__(mu, np.diag(self.Sigma))

    def _update_rank1(self, delta_tau, delta_v, ga_approx, i):
        """
        Sherman-Morrison update after site ``i`` changes its precision by
        ``delta_tau`` and its precision-times-mean by ``delta_v``.
        """
        sigma_col = self.Sigma[i, :].copy()
        coef = delta_tau / (1. + delta_tau * sigma_col[i])
        shift = coef * (self.mu[i] + sigma_col[i] * delta_v) - delta_v
        self.mu = self.mu - shift * sigma_col
        # Sigma <- Sigma - coef * s s^T, applied in place.
        DSYR(self.Sigma, sigma_col, -coef)

    def to_dict(self):
        """
        Convert the object into a json serializable dictionary.

        :return dict: dictionary with ``mu`` and ``Sigma`` (and ``L`` when set) as lists.
        """
        # TODO: Implement a more memory efficient variant
        payload = {"mu": self.mu.tolist(), "Sigma": self.Sigma.tolist()}
        if self.L is not None:
            payload["L"] = self.L.tolist()
        return payload

    @staticmethod
    def from_dict(input_dict):
        """Rebuild a :class:`posteriorParams` from the output of :meth:`to_dict`."""
        extra = (np.array(input_dict["L"]),) if "L" in input_dict else ()
        return posteriorParams(np.array(input_dict["mu"]), np.array(input_dict["Sigma"]), *extra)

    @staticmethod
    def _recompute(mean_prior, K, ga_approx):
        """
        Recompute the posterior from scratch given the prior and the site parameters.

        :param mean_prior: prior mean vector.
        :param K: prior covariance matrix.
        :param ga_approx: site approximation exposing ``tau`` and ``v``.
        :return posteriorParams: posterior with mean, covariance and Cholesky factor of B.
        """
        n = len(ga_approx.tau)
        sqrt_tau = np.sqrt(ga_approx.tau)
        scaled_K = sqrt_tau[:, None] * K                 # S^(1/2) K
        B = np.eye(n) + scaled_K * sqrt_tau[None, :]     # I + S^(1/2) K S^(1/2)
        chol_B = jitchol(B)
        V, _ = dtrtrs(chol_B, scaled_K, lower=1)
        # K - K S^(1/2) B^(-1) S^(1/2) K = (K^(-1) + S)^(-1)
        Sigma = K - np.dot(V.T, V)

        rhs = sqrt_tau * (np.dot(K, ga_approx.v) + mean_prior)
        aux_alpha, _ = dpotrs(chol_B, rhs, lower=1)
        # (K + S^(-1))^(-1) (mu_tilde - mean_prior)
        alpha = ga_approx.v - sqrt_tau * aux_alpha
        mu = np.dot(K, alpha) + mean_prior

        return posteriorParams(mu=mu, Sigma=Sigma, L=chol_B)
144
class posteriorParamsDTC(posteriorParamsBase):
    """
    Diagonal posterior summary used by EP-DTC.

    Only the posterior means and marginal variances are stored. The covariance is
    Kmn^T (LLT)^(-1) Kmn with LLT = LLT0 + Kmn diag(tau_tilde) Kmn^T, and that
    matrix is handled by the caller.
    """
    def __init__(self, mu, Sigma_diag):
        super(posteriorParamsDTC, self).__init__(mu, Sigma_diag)

    def _update_rank1(self, LLT, Kmn, delta_v, delta_tau, i):
        """
        Update the posterior after site ``i`` changes by ``(delta_tau, delta_v)``.

        ``LLT`` is modified in place (``LLT += delta_tau * k_i k_i^T``, where
        ``k_i`` is column ``i`` of ``Kmn``); ``mu`` is updated in place.
        """
        DSYR(LLT, Kmn[:, i].copy(), delta_tau)
        chol = jitchol(LLT)
        V, _ = dtrtrs(chol, Kmn, lower=1)
        # diag(K_nm (L L^T)^(-1) K_mn), floored at machine epsilon.
        self.Sigma_diag = np.maximum(np.sum(V*V, -2), np.finfo(float).eps)
        col_i = np.sum(V.T*V[:, i], -1)  # column i of V^T V
        self.mu += (delta_v - delta_tau*self.mu[i]) * col_i

    def to_dict(self):
        """
        Convert the object into a json serializable dictionary.

        :return dict: dictionary with ``mu`` and ``Sigma_diag`` as lists.
        """
        return {"mu": self.mu.tolist(), "Sigma_diag": self.Sigma_diag.tolist()}

    @staticmethod
    def from_dict(input_dict):
        """Rebuild a :class:`posteriorParamsDTC` from the output of :meth:`to_dict`."""
        mu = np.array(input_dict["mu"])
        Sigma_diag = np.array(input_dict["Sigma_diag"])
        return posteriorParamsDTC(mu, Sigma_diag)

    @staticmethod
    def _recompute(LLT0, Kmn, ga_approx):
        """
        Recompute the posterior from scratch.

        :return: tuple ``(posteriorParamsDTC, LLT)`` where ``LLT = LLT0 + Kmn diag(tau) Kmn^T``.
        """
        LLT = LLT0 + np.dot(Kmn*ga_approx.tau[None, :], Kmn.T)
        V, _ = dtrtrs(jitchol(LLT), Kmn, lower=1)
        Sigma = np.dot(V.T, V)
        mu = np.dot(Sigma, ga_approx.v)
        return posteriorParamsDTC(mu, np.diag(Sigma).copy()), LLT
186
class EPBase(object):
    def __init__(self, epsilon=1e-6, eta=1., delta=1., always_reset=False, max_iters=np.inf, ep_mode="alternated", parallel_updates=False, loading=False):
        """
        The expectation-propagation algorithm.
        For nomenclature see Rasmussen & Williams 2006.

        :param epsilon: convergence criterion; EP stops once the mean squared change of both site
            parameter vectors (tau and v) between two consecutive sweeps falls below this value.
        :type epsilon: float
        :param eta: parameter for fractional EP updates.
        :type eta: float64
        :param delta: damping factor for the EP updates.
        :type delta: float64
        :param always_reset: if True, reset the approximation at the beginning of every inference call.
        :type always_reset: boolean
        :param max_iters: maximum number of EP sweeps over the data.
        :type max_iters: int
        :param ep_mode: "nested" (EP is run at every inference call, i.e. every time the hyperparameters
            change) or "alternated" (EP runs once and its result is reused until it is reset).
        :type ep_mode: string
        :param parallel_updates: if True, the posterior is not updated after each site update; it is
            only recomputed after each full sweep.
        :type parallel_updates: boolean
        :param loading: if True, prevents the EP parameters from changing. Hack used when loading a
            serialized model.
        :type loading: boolean
        """
        super(EPBase, self).__init__()

        self.always_reset = always_reset
        self.epsilon, self.eta, self.delta, self.max_iters = epsilon, eta, delta, max_iters
        self.ep_mode = ep_mode
        self.parallel_updates = parallel_updates
        # FIXME: Hack for serialization. If True, prevents the EP parameters from changing when loading a serialized model
        self.loading = loading
        self.reset()

    def reset(self):
        """Forget the previous site parameters and any stored EP approximation."""
        self.ga_approx_old = None
        self._ep_approximation = None

    def on_optimization_start(self):
        """Drop the stored approximation so EP is rerun at the start of an optimization."""
        self._ep_approximation = None

    def on_optimization_end(self):
        # TODO: update approximation in the end as well? Maybe even with a switch?
        pass

    def _stop_criteria(self, ga_approx):
        """Return True when both site parameter vectors changed less than ``epsilon`` (mean squared)."""
        tau_diff = np.mean(np.square(ga_approx.tau-self.ga_approx_old.tau))
        v_diff = np.mean(np.square(ga_approx.v-self.ga_approx_old.v))
        return ((tau_diff < self.epsilon) and (v_diff < self.epsilon))

    def __setstate__(self, state):
        """
        Restore from ``[parent_state, ep_state]`` as produced by :meth:`__getstate__`.

        ``ep_state`` starts with (epsilon, eta, delta); older pickles stop there, so the
        remaining settings fall back to their constructor defaults when absent.
        """
        parent_state, ep_state = state[0], state[1]
        parent_setstate = getattr(super(EPBase, self), '__setstate__', None)
        if parent_setstate is not None:
            parent_setstate(parent_state)
        elif isinstance(parent_state, dict):
            # The parent relies on the default pickling of the instance dictionary.
            self.__dict__.update(parent_state)
        self.epsilon, self.eta, self.delta = ep_state[0], ep_state[1], ep_state[2]
        optional = (("always_reset", False), ("max_iters", np.inf),
                    ("ep_mode", "alternated"), ("parallel_updates", False))
        for position, (name, default) in enumerate(optional, start=3):
            if position < len(ep_state):
                setattr(self, name, ep_state[position])
            elif not hasattr(self, name):
                setattr(self, name, default)
        # reset() discards any cached approximation, so there is nothing left to "load".
        self.loading = False
        self.reset()

    def __getstate__(self):
        """Return ``[parent_state, [epsilon, eta, delta, always_reset, max_iters, ep_mode, parallel_updates]]``."""
        parent_getstate = getattr(super(EPBase, self), '__getstate__', None)
        parent_state = parent_getstate() if parent_getstate is not None else self.__dict__.copy()
        return [parent_state, [self.epsilon, self.eta, self.delta, self.always_reset,
                               self.max_iters, self.ep_mode, self.parallel_updates]]

    def _save_to_input_dict(self):
        """Extend the parent's serialization dictionary with the EP settings."""
        input_dict = super(EPBase, self)._save_to_input_dict()
        input_dict["epsilon"]=self.epsilon
        input_dict["eta"]=self.eta
        input_dict["delta"]=self.delta
        input_dict["always_reset"]=self.always_reset
        input_dict["max_iters"]=self.max_iters
        input_dict["ep_mode"]=self.ep_mode
        input_dict["parallel_updates"]=self.parallel_updates
        # A deserialized object must keep its stored approximation on the next inference call.
        input_dict["loading"]=True
        return input_dict
251
class EP(EPBase, ExactGaussianInference):
    """
    Expectation propagation with a full Gaussian process prior.

    The approximate posterior keeps the full n x n covariance and is recomputed
    with a Cholesky factorisation after every EP sweep.
    """
    def inference(self, kern, X, likelihood, Y, mean_function=None, Y_metadata=None, precision=None, K=None):
        """
        Run (or reuse) EP and return the posterior, log marginal likelihood and gradients.

        :param kern: kernel; ``kern.K(X)`` is only evaluated when ``K`` is None.
        :param X: inputs.
        :param likelihood: likelihood providing ``moments_match_ep`` and ``ep_gradients``.
        :param Y: targets; must have exactly one output column.
        :param mean_function: optional prior mean; ``mean_function.f(X)`` is flattened into the prior mean.
        :param Y_metadata: optional dict of per-datum metadata arrays.
        :param precision: unused; kept for interface compatibility.
        :param K: optional precomputed prior covariance.
        :return: ``(posterior, log_marginal, gradients_dict)`` as built by :meth:`_inference`.
        :raises ValueError: if ``ep_mode`` is neither "nested" nor "alternated".
        """
        if self.always_reset and not self.loading:
            self.reset()

        _, output_dim = Y.shape
        assert output_dim == 1, "ep in 1D only (for now!)"

        if mean_function is None:
            mean_prior = np.zeros(X.shape[0])
        else:
            mean_prior = mean_function.f(X).flatten()

        if K is None:
            K = kern.K(X)

        if self.ep_mode == "nested" and not self.loading:
            # Force EP at each step of the optimization.
            self._ep_approximation = None
            run_ep = True
        elif self.ep_mode == "alternated" or self.loading:
            # Run EP only if there is no stored approximation yet; otherwise reuse it.
            run_ep = getattr(self, '_ep_approximation', None) is None
        else:
            raise ValueError("ep_mode value not valid")

        if run_ep:
            self._ep_approximation = self.expectation_propagation(mean_prior, K, Y, likelihood, Y_metadata)
        _, ga_approx, cav_params, log_Z_tilde = self._ep_approximation

        self.loading = False

        return self._inference(Y, mean_prior, K, ga_approx, cav_params, likelihood, Y_metadata=Y_metadata, Z_tilde=log_Z_tilde)

    def expectation_propagation(self, mean_prior, K, Y, likelihood, Y_metadata):
        """
        Run EP sweeps until convergence or ``max_iters``.

        :return: tuple ``(post_params, ga_approx, cav_params, log_Z_tilde)``.
        """
        num_data, data_dim = Y.shape
        assert data_dim == 1, "This EP methods only works for 1D outputs"

        # Work on a plain ndarray copy: computing the sign is quicker than on ObsAr,
        # and np.array also accepts ndarrays that have no ``.values`` attribute.
        Y = np.array(Y)

        # Initial values - marginal moments, cavity params, gaussian approximation params and posterior params
        marg_moments = marginalMoments(num_data)
        cav_params = cavityParams(num_data)
        ga_approx, post_params = self._init_approximations(mean_prior, K, num_data)

        stop = False
        iterations = 0
        while not stop and (iterations < self.max_iters):
            self._local_updates(num_data, cav_params, post_params, marg_moments, ga_approx, likelihood, Y, Y_metadata)

            # (Re)compute Sigma and mu using a full Cholesky decomposition.
            post_params = posteriorParams._recompute(mean_prior, K, ga_approx)

            # Monitor convergence against the previous sweep.
            if iterations > 0:
                stop = self._stop_criteria(ga_approx)
            self.ga_approx_old = gaussianApproximation(ga_approx.v.copy(), ga_approx.tau.copy())
            iterations += 1

        log_Z_tilde = self._log_Z_tilde(marg_moments, ga_approx, cav_params)

        return (post_params, ga_approx, cav_params, log_Z_tilde)

    def _init_approximations(self, mean_prior, K, num_data):
        """
        Initial site parameters and posterior q(f|X,Y) = N(f|mu,Sigma).

        Starts from zero sites (posterior = prior) or warm-starts from ``ga_approx_old``.
        """
        if self.ga_approx_old is None:
            v_tilde, tau_tilde = np.zeros((2, num_data))
            ga_approx = gaussianApproximation(v_tilde, tau_tilde)
            Sigma = K.copy()
            diag.add(Sigma, 1e-7)
            mu = mean_prior
            post_params = posteriorParams(mu, Sigma)
        else:
            assert self.ga_approx_old.v.size == num_data, "data size mis-match: did you change the data? try resetting!"
            # Copy so the in-place site updates do not modify the stored previous parameters.
            ga_approx = gaussianApproximation(self.ga_approx_old.v.copy(), self.ga_approx_old.tau.copy())
            post_params = posteriorParams._recompute(mean_prior, K, ga_approx)
            diag.add(post_params.Sigma, 1e-7)
            # TODO: Check the log-marginal under both conditions and choose the best one
        return (ga_approx, post_params)

    def _local_updates(self, num_data, cav_params, post_params, marg_moments, ga_approx, likelihood, Y, Y_metadata, update_order=None):
        """
        One EP sweep over the sites, in random order unless ``update_order`` is given.

        Cavity, marginal moments and site parameters are updated in place; the posterior
        is also updated after every site unless ``parallel_updates`` is set.
        """
        if update_order is None:
            update_order = np.random.permutation(num_data)
        for i in update_order:
            # Cavity distribution parameters.
            cav_params._update_i(self.eta, ga_approx, post_params, i)

            if Y_metadata is not None:
                # Pick out the relevant metadata for Y[i].
                Y_metadata_i = {key: value[i, :] for key, value in Y_metadata.items()}
            else:
                Y_metadata_i = None

            # Marginal moments.
            marg_moments.Z_hat[i], marg_moments.mu_hat[i], marg_moments.sigma2_hat[i] = likelihood.moments_match_ep(Y[i], cav_params.tau[i], cav_params.v[i], Y_metadata_i=Y_metadata_i)

            # Site parameters update.
            delta_tau, delta_v = ga_approx._update_i(self.eta, self.delta, post_params, marg_moments, i)

            if not self.parallel_updates:
                post_params._update_rank1(delta_tau, delta_v, ga_approx, i)

    def _log_Z_tilde(self, marg_moments, ga_approx, cav_params):
        """
        Sum of the log site normalisers, without the terms that can become infinite
        when tau_tilde is close to zero; these terms cancel with the corresponding
        terms of the marginal log-likelihood.
        """
        return np.sum((
                np.log(marg_moments.Z_hat)
                + 0.5*np.log(2*np.pi) + 0.5*np.log(1+ga_approx.tau/cav_params.tau)
                - 0.5 * ((ga_approx.v)**2 * 1./(cav_params.tau + ga_approx.tau))
                + 0.5*(cav_params.v * ( ( (ga_approx.tau/cav_params.tau) * cav_params.v - 2.0 * ga_approx.v ) * 1./(cav_params.tau + ga_approx.tau)))
                ))

    def _ep_marginal(self, mean_prior, K, ga_approx, Z_tilde):
        """
        EP approximation of the log marginal likelihood.

        :return: tuple ``(log_marginal, post_params)`` with the recomputed posterior.
        """
        post_params = posteriorParams._recompute(mean_prior, K, ga_approx)
        # Gaussian log marginal excluding terms that can go to infinity due to arbitrarily small tau_tilde.
        # These terms cancel out with the terms excluded from Z_tilde.
        B_logdet = np.sum(2.0*np.log(np.diag(post_params.L)))
        S_mean_prior = ga_approx.tau * mean_prior
        v_centered = ga_approx.v - S_mean_prior
        log_marginal = 0.5*(
                        -len(ga_approx.tau) * log_2_pi - B_logdet
                        + np.sum(v_centered * np.dot(post_params.Sigma, v_centered))
                        - np.dot(mean_prior, (S_mean_prior - 2*ga_approx.v))
                        )
        log_marginal += Z_tilde

        return log_marginal, post_params

    def _inference(self, Y, mean_prior, K, ga_approx, cav_params, likelihood, Z_tilde, Y_metadata=None):
        """Build the posterior object, log marginal and gradients from a converged EP approximation."""
        log_marginal, post_params = self._ep_marginal(mean_prior, K, ga_approx, Z_tilde)

        tau_tilde_root = np.sqrt(ga_approx.tau)

        aux_alpha, _ = dpotrs(post_params.L, tau_tilde_root * (np.dot(K, ga_approx.v) + mean_prior), lower=1)
        # (K + Sigma_tilde)^(-1) (mu_tilde - mu_p)
        alpha = (ga_approx.v - tau_tilde_root * aux_alpha)[:, None]

        LWi, _ = dtrtrs(post_params.L, np.diag(tau_tilde_root), lower=1)
        Wi = np.dot(LWi.T, LWi)
        symmetrify(Wi)  # (K + Sigma_tilde)^(-1)

        dL_dK = 0.5 * (tdot(alpha) - Wi)
        dL_dthetaL = likelihood.ep_gradients(Y, cav_params.tau, cav_params.v, np.diag(dL_dK), Y_metadata=Y_metadata, quad_mode='gh')
        return Posterior(woodbury_inv=Wi, woodbury_vector=alpha, K=K), log_marginal, {'dL_dK':dL_dK, 'dL_dthetaL':dL_dthetaL, 'dL_dm':alpha}

    def to_dict(self):
        """
        Convert the object into a json serializable dictionary.

        Note: It uses the private method _save_to_input_dict of the parent.

        :return dict: json serializable dictionary containing the needed information to instantiate the object
        """
        input_dict = super(EP, self)._save_to_input_dict()
        input_dict["class"] = "GPy.inference.latent_function_inference.expectation_propagation.EP"
        if self.ga_approx_old is not None:
            input_dict["ga_approx_old"] = self.ga_approx_old.to_dict()
        if self._ep_approximation is not None:
            input_dict["_ep_approximation"] = {
                "post_params": self._ep_approximation[0].to_dict(),
                "ga_approx": self._ep_approximation[1].to_dict(),
                "cav_params": self._ep_approximation[2].to_dict(),
                "log_Z_tilde": float(self._ep_approximation[3]),
            }

        return input_dict

    @staticmethod
    def _build_from_input_dict(inference_class, input_dict):
        """
        Rebuild an :class:`EP` object from the output of :meth:`to_dict`.

        The stored EP approximation is restored only when one was serialized;
        otherwise ``_ep_approximation`` stays None so EP runs on the next inference call.
        """
        ga_approx_old = input_dict.pop('ga_approx_old', None)
        if ga_approx_old is not None:
            ga_approx_old = gaussianApproximation.from_dict(ga_approx_old)
        ep_approximation_dict = input_dict.pop('_ep_approximation', None)
        ep_approximation = None
        if ep_approximation_dict is not None:
            ep_approximation = (
                posteriorParams.from_dict(ep_approximation_dict["post_params"]),
                gaussianApproximation.from_dict(ep_approximation_dict["ga_approx"]),
                cavityParams.from_dict(ep_approximation_dict["cav_params"]),
                float(ep_approximation_dict["log_Z_tilde"]),
            )
        ee = EP(**input_dict)
        ee.ga_approx_old = ga_approx_old
        ee._ep_approximation = ep_approximation
        return ee
442
class EPDTC(EPBase, VarDTC):
    """
    Expectation propagation combined with the DTC sparse approximation.

    EP produces Gaussian sites (mu_tilde = v/tau, precision tau), which are then
    passed as pseudo-observations to the parent's variational DTC inference.
    """
    def inference(self, kern, X, Z, likelihood, Y, mean_function=None, Y_metadata=None, Lm=None, dL_dKmm=None, psi0=None, psi1=None, psi2=None):
        """
        Run (or reuse) EP and delegate to the parent's DTC inference.

        :param Lm: optional Cholesky factor of Kmm; when given, Kmm itself is still
            evaluated via ``kern.K(Z)`` if EP has to run, because EP needs it.
        :raises ValueError: if ``ep_mode`` is neither "nested" nor "alternated".
        """
        if self.always_reset and not self.loading:
            self.reset()

        _, output_dim = Y.shape
        assert output_dim == 1, "ep in 1D only (for now!)"

        Kmm = None
        if Lm is None:
            Kmm = kern.K(Z)
            Lm = jitchol(Kmm)

        if psi1 is None:
            try:
                Kmn = kern.K(Z, X)
            except TypeError:
                Kmn = kern.psi1(Z, X).T
        else:
            Kmn = psi1.T

        if self.ep_mode == "nested" and not self.loading:
            # Force EP at each step of the optimization.
            self._ep_approximation = None
            run_ep = True
        elif self.ep_mode == "alternated" or self.loading:
            # Run EP only if there is no stored approximation yet; otherwise reuse it.
            run_ep = getattr(self, '_ep_approximation', None) is None
        else:
            raise ValueError("ep_mode value not valid")

        if run_ep:
            if Kmm is None:
                # Lm was supplied by the caller, but EP works with Kmm directly.
                Kmm = kern.K(Z)
            self._ep_approximation = self.expectation_propagation(Kmm, Kmn, Y, likelihood, Y_metadata)
        _, ga_approx, log_Z_tilde = self._ep_approximation

        self.loading = False

        mu_tilde = ga_approx.v / ga_approx.tau.astype(float)

        return super(EPDTC, self).inference(kern, X, Z, likelihood, ObsAr(mu_tilde[:,None]),
                                            mean_function=mean_function,
                                            Y_metadata=Y_metadata,
                                            precision=ga_approx.tau,
                                            Lm=Lm, dL_dKmm=dL_dKmm,
                                            psi0=psi0, psi1=psi1, psi2=psi2, Z_tilde=log_Z_tilde)

    def expectation_propagation(self, Kmm, Kmn, Y, likelihood, Y_metadata):
        """
        Run EP sweeps until convergence or ``max_iters``.

        :return: tuple ``(post_params, ga_approx, log_Z_tilde)``.
        """
        num_data, output_dim = Y.shape
        assert output_dim == 1, "This EP methods only works for 1D outputs"

        # Work on a plain ndarray copy: computing the sign is quicker than on ObsAr,
        # and np.array also accepts ndarrays that have no ``.values`` attribute.
        Y = np.array(Y)

        # Initial values - marginal moments, cavity params, gaussian approximation params and posterior params
        marg_moments = marginalMoments(num_data)
        cav_params = cavityParams(num_data)
        ga_approx, post_params, LLT0, LLT = self._init_approximations(Kmm, Kmn, num_data)

        stop = False
        iterations = 0
        while not stop and (iterations < self.max_iters):
            self._local_updates(num_data, LLT0, LLT, Kmn, cav_params, post_params, marg_moments, ga_approx, likelihood, Y, Y_metadata)
            # (Re)compute Sigma_diag and mu using a full Cholesky decomposition.
            post_params, LLT = posteriorParamsDTC._recompute(LLT0, Kmn, ga_approx)
            post_params.Sigma_diag = np.maximum(post_params.Sigma_diag, np.finfo(float).eps)

            # Monitor convergence against the previous sweep.
            if iterations > 0:
                stop = self._stop_criteria(ga_approx)
            self.ga_approx_old = gaussianApproximation(ga_approx.v.copy(), ga_approx.tau.copy())
            iterations += 1

        log_Z_tilde = self._log_Z_tilde(marg_moments, ga_approx, cav_params)

        return post_params, ga_approx, log_Z_tilde

    def _log_Z_tilde(self, marg_moments, ga_approx, cav_params):
        """Sum of the log site normalisers, written with site/cavity means and variances."""
        mu_tilde = ga_approx.v/ga_approx.tau
        mu_cav = cav_params.v/cav_params.tau
        sigma2_sigma2tilde = 1./cav_params.tau + 1./ga_approx.tau

        return np.sum((np.log(marg_moments.Z_hat) + 0.5*np.log(2*np.pi) + 0.5*np.log(sigma2_sigma2tilde)
                         + 0.5*((mu_cav - mu_tilde)**2) / (sigma2_sigma2tilde)))

    def _init_approximations(self, Kmm, Kmn, num_data):
        """
        Initial site parameters and posterior summary.

        :return: tuple ``(ga_approx, post_params, LLT0, LLT)`` where ``LLT0`` is a copy of
            Kmm and ``LLT = LLT0 + Kmn diag(tau_tilde) Kmn^T``.
        """
        LLT0 = Kmm.copy()
        Lm = jitchol(LLT0)  # K_m = L_m L_m^T
        Vm, _ = dtrtrs(Lm, Kmn, lower=1)
        Qnn_diag = np.sum(Vm*Vm, -2)  # diag(Knm Kmm^(-1) Kmn)
        if self.ga_approx_old is None:
            # Zero sites: the posterior equals the (DTC) prior.
            LLT = LLT0.copy()
            mu = np.zeros(num_data)
            Sigma_diag = Qnn_diag + 1e-8
            v_tilde, tau_tilde = np.zeros((2, num_data))
            ga_approx = gaussianApproximation(v_tilde, tau_tilde)
            post_params = posteriorParamsDTC(mu, Sigma_diag)
        else:
            assert self.ga_approx_old.v.size == num_data, "data size mis-match: did you change the data? try resetting!"
            # Copy so the in-place site updates do not modify the stored previous parameters.
            ga_approx = gaussianApproximation(self.ga_approx_old.v.copy(), self.ga_approx_old.tau.copy())
            post_params, LLT = posteriorParamsDTC._recompute(LLT0, Kmn, ga_approx)
            post_params.Sigma_diag += 1e-8
            # TODO: Check the log-marginal under both conditions and choose the best one

        return (ga_approx, post_params, LLT0, LLT)

    def _local_updates(self, num_data, LLT0, LLT, Kmn, cav_params, post_params, marg_moments, ga_approx, likelihood, Y, Y_metadata, update_order=None):
        """
        One EP sweep over the sites, in random order unless ``update_order`` is given.

        Unless ``parallel_updates`` is set, ``LLT`` and the posterior are updated after every site.
        """
        if update_order is None:
            update_order = np.random.permutation(num_data)
        for i in update_order:
            # Cavity distribution parameters.
            cav_params._update_i(self.eta, ga_approx, post_params, i)

            if Y_metadata is not None:
                # Pick out the relevant metadata for Y[i].
                Y_metadata_i = {key: value[i, :] for key, value in Y_metadata.items()}
            else:
                Y_metadata_i = None

            # Marginal moments.
            marg_moments.Z_hat[i], marg_moments.mu_hat[i], marg_moments.sigma2_hat[i] = likelihood.moments_match_ep(Y[i], cav_params.tau[i], cav_params.v[i], Y_metadata_i=Y_metadata_i)
            # Site parameters update.
            delta_tau, delta_v = ga_approx._update_i(self.eta, self.delta, post_params, marg_moments, i)

            # Posterior distribution parameters update.
            if not self.parallel_updates:
                post_params._update_rank1(LLT, Kmn, delta_v, delta_tau, i)

    def to_dict(self):
        """
        Convert the object into a json serializable dictionary.

        Note: It uses the private method _save_to_input_dict of the parent.

        :return dict: json serializable dictionary containing the needed information to instantiate the object
        """
        input_dict = super(EPDTC, self)._save_to_input_dict()
        input_dict["class"] = "GPy.inference.latent_function_inference.expectation_propagation.EPDTC"
        if self.ga_approx_old is not None:
            input_dict["ga_approx_old"] = self.ga_approx_old.to_dict()
        if self._ep_approximation is not None:
            input_dict["_ep_approximation"] = {
                "post_params": self._ep_approximation[0].to_dict(),
                "ga_approx": self._ep_approximation[1].to_dict(),
                "log_Z_tilde": float(self._ep_approximation[2]),
            }

        return input_dict

    @staticmethod
    def _build_from_input_dict(inference_class, input_dict):
        """
        Rebuild an :class:`EPDTC` object from the output of :meth:`to_dict`.

        The stored EP approximation is restored only when one was serialized;
        otherwise ``_ep_approximation`` stays None so EP runs on the next inference call.
        """
        ga_approx_old = input_dict.pop('ga_approx_old', None)
        if ga_approx_old is not None:
            ga_approx_old = gaussianApproximation.from_dict(ga_approx_old)
        ep_approximation_dict = input_dict.pop('_ep_approximation', None)
        ep_approximation = None
        if ep_approximation_dict is not None:
            ep_approximation = (
                posteriorParamsDTC.from_dict(ep_approximation_dict["post_params"]),
                gaussianApproximation.from_dict(ep_approximation_dict["ga_approx"]),
                float(ep_approximation_dict["log_Z_tilde"]),
            )
        ee = EPDTC(**input_dict)
        ee.ga_approx_old = ga_approx_old
        ee._ep_approximation = ep_approximation
        return ee
623