1# Copyright (c) 2017, GPy authors (see AUTHORS.txt).
2# Licensed under the BSD 3-clause license (see LICENSE.txt)
3
4from GPy.util.linalg import jitchol, backsub_both_sides, tdot, dtrtrs, dtrtri,pdinv, dpotri
5from GPy.util import diag
6from GPy.core.parameterization.variational import VariationalPosterior
7import numpy as np
8from GPy.inference.latent_function_inference import LatentFunctionInference
9from GPy.inference.latent_function_inference.posterior import Posterior
10from .vardtc_svi_multiout import PosteriorMultioutput
# Module-level constant: natural log of 2*pi.
log_2_pi = np.log(2.0 * np.pi)
12
13
class VarDTC_SVI_Multiout_Miss(LatentFunctionInference):
    """
    The VarDTC inference method for Multi-output GP regression with missing data (GPy.models.GPMultioutRegressionMD)

    Two kernels are involved: ``kern_r`` over the per-output inputs ``Xr``
    (its psi statistics are indexed by output dimension ``d``) and ``kern_c``
    over the per-observation inputs ``Xc`` (its psi statistics are indexed
    like the rows of ``Y``). ``indexD[n]`` is the output dimension that row
    ``n`` of ``Y`` belongs to, so different outputs may be observed at
    different inputs. The variational parameters of the inducing variables
    are a mean of shape (Mc, Mr), a column covariance (Mc x Mc) and a row
    covariance (Mr x Mr).
    """

    # Jitter added to the diagonals of Kuu_r and Kuu_c before Cholesky.
    const_jitter = 1e-6

    def get_trYYT(self, Y):
        """Return trace(Y Y^T), i.e. the sum of the squared entries of Y."""
        return np.sum(np.square(Y))

    def get_YYTfactor(self, Y):
        """
        Return a factor F with F F^T equal (up to jitter) to Y Y^T.

        If Y has at least as many rows as columns, Y itself is returned as a
        plain ndarray view; otherwise the jittered Cholesky factor of Y Y^T.
        """
        N, D = Y.shape
        if N >= D:
            return Y.view(np.ndarray)
        return jitchol(tdot(Y))

    def gatherPsiStat(self, kern, X, Z, uncertain_inputs):
        """
        Compute the psi statistics of ``kern`` for inputs X and inducing inputs Z.

        :param kern: kernel providing psi0/psi1/psi2n (and K/Kdiag).
        :param X: inputs; a VariationalPosterior when ``uncertain_inputs``.
        :param Z: inducing inputs.
        :param uncertain_inputs: whether X is a VariationalPosterior.
        :returns: (psi0, psi1, psi2). For certain inputs these are
            Kdiag(X), K(X, Z) and the per-row outer products of K(X, Z),
            i.e. psi2 keeps one M x M matrix per row of X.
        """
        if uncertain_inputs:
            psi0 = kern.psi0(Z, X)
            psi1 = kern.psi1(Z, X)
            psi2 = kern.psi2n(Z, X)
        else:
            psi0 = kern.Kdiag(X)
            psi1 = kern.K(X, Z)
            psi2 = psi1[:, :, None] * psi1[:, None, :]

        return psi0, psi1, psi2

    def _init_grad_dict(self, N, D, Mr, Mc):
        """
        Create the zero-initialised gradient accumulator.

        Entries suffixed ``_c`` are per observation (N rows), entries
        suffixed ``_r`` are per output dimension (D rows).
        """
        grad_dict = {
            'dL_dthetaL': np.zeros(D),
            'dL_dqU_mean': np.zeros((Mc, Mr)),
            'dL_dqU_var_c': np.zeros((Mc, Mc)),
            'dL_dqU_var_r': np.zeros((Mr, Mr)),
            'dL_dKuu_c': np.zeros((Mc, Mc)),
            'dL_dKuu_r': np.zeros((Mr, Mr)),
            'dL_dpsi0_c': np.zeros(N),
            'dL_dpsi1_c': np.zeros((N, Mc)),
            'dL_dpsi2_c': np.zeros((N, Mc, Mc)),
            'dL_dpsi0_r': np.zeros(D),
            'dL_dpsi1_r': np.zeros((D, Mr)),
            'dL_dpsi2_r': np.zeros((D, Mr, Mr)),
        }
        return grad_dict

    def inference_d(self, d, beta, Y, indexD, grad_dict, mid_res, uncertain_inputs_r, uncertain_inputs_c, Mr, Mc):
        """
        Compute the data-fit part of the bound for output dimension ``d``.

        Only the rows of ``Y`` with ``indexD == d`` are used. The gradient
        contributions of this output are written into ``grad_dict`` in place:
        ``dL_dthetaL[d]`` is assigned, all other entries are added to.

        :param d: index of the output dimension.
        :param beta: per-output noise precisions (length D).
        :param Y: stacked observations of all outputs.
        :param indexD: output index of every row of ``Y``.
        :param grad_dict: accumulator created by ``_init_grad_dict``.
        :param mid_res: quantities shared across outputs (psi statistics,
            Cholesky factors Lr/Lc and whitened variational parameters).
        :param uncertain_inputs_r: whether Xr is a VariationalPosterior.
        :param uncertain_inputs_c: whether Xc is a VariationalPosterior.
        :param Mr: number of row inducing inputs.
        :param Mc: number of column inducing inputs.
        :returns: the contribution of output ``d`` to the log-likelihood bound.
        """
        idx_d = indexD == d
        Y = Y[idx_d]
        N, D = Y.shape[0], 1
        if N == 0:
            # No observations for this output: psi0_c sums to 0, psi2_c sums
            # to a zero matrix and Y is empty, so the bound term and every
            # gradient contribution below are identically zero. Skip the
            # zero-size triangular solves.
            return 0.
        beta = beta[d]

        psi0_r, psi1_r, psi2_r = mid_res['psi0_r'], mid_res['psi1_r'], mid_res['psi2_r']
        psi0_c, psi1_c, psi2_c = mid_res['psi0_c'], mid_res['psi1_c'], mid_res['psi2_c']
        # Row statistics of output d; column statistics of its observations
        # (psi0_c and psi2_c summed over those observations).
        psi0_r, psi1_r, psi2_r = psi0_r[d], psi1_r[d:d+1], psi2_r[d]
        psi0_c, psi1_c, psi2_c = psi0_c[idx_d].sum(), psi1_c[idx_d], psi2_c[idx_d].sum(0)

        Lr = mid_res['Lr']
        Lc = mid_res['Lc']
        LcInvMLrInvT = mid_res['LcInvMLrInvT']
        LcInvScLcInvT = mid_res['LcInvScLcInvT']
        LrInvSrLrInvT = mid_res['LrInvSrLrInvT']

        # Psi statistics whitened by the Cholesky factors of Kuu_c / Kuu_r.
        LcInvPsi2_cLcInvT = backsub_both_sides(Lc, psi2_c, 'right')
        LrInvPsi2_rLrInvT = backsub_both_sides(Lr, psi2_r, 'right')
        LcInvPsi1_cT = dtrtrs(Lc, psi1_c.T)[0]
        LrInvPsi1_rT = dtrtrs(Lr, psi1_r.T)[0]

        tr_LrInvPsi2_rLrInvT_LrInvSrLrInvT = (LrInvPsi2_rLrInvT * LrInvSrLrInvT).sum()
        tr_LcInvPsi2_cLcInvT_LcInvScLcInvT = (LcInvPsi2_cLcInvT * LcInvScLcInvT).sum()
        tr_LrInvPsi2_rLrInvT = np.trace(LrInvPsi2_rLrInvT)
        tr_LcInvPsi2_cLcInvT = np.trace(LcInvPsi2_cLcInvT)

        logL_A = - np.square(Y).sum() \
                 - (LcInvMLrInvT.T.dot(LcInvPsi2_cLcInvT).dot(LcInvMLrInvT) * LrInvPsi2_rLrInvT).sum() \
                 - tr_LrInvPsi2_rLrInvT_LrInvSrLrInvT * tr_LcInvPsi2_cLcInvT_LcInvScLcInvT \
                 + 2 * (Y * LcInvPsi1_cT.T.dot(LcInvMLrInvT).dot(LrInvPsi1_rT)).sum() - psi0_c * psi0_r \
                 + tr_LrInvPsi2_rLrInvT * tr_LcInvPsi2_cLcInvT

        logL = -N*D/2.*(np.log(2.*np.pi)-np.log(beta)) + beta/2. * logL_A

        # ======= Gradients w.r.t. Kuu_c and Kuu_r =====

        tmp = beta * LcInvPsi2_cLcInvT.dot(LcInvMLrInvT).dot(LrInvPsi2_rLrInvT).dot(LcInvMLrInvT.T) \
            + beta * tr_LrInvPsi2_rLrInvT_LrInvSrLrInvT * LcInvPsi2_cLcInvT.dot(LcInvScLcInvT) \
            - beta * LcInvMLrInvT.dot(LrInvPsi1_rT).dot(Y.T).dot(LcInvPsi1_cT.T) \
            - beta/2. * tr_LrInvPsi2_rLrInvT * LcInvPsi2_cLcInvT

        dL_dKuu_c = backsub_both_sides(Lc, tmp, 'left')
        # tmp is not symmetric; symmetrise the gradient.
        dL_dKuu_c += dL_dKuu_c.T
        dL_dKuu_c *= 0.5

        tmp = beta * LcInvMLrInvT.T.dot(LcInvPsi2_cLcInvT).dot(LcInvMLrInvT).dot(LrInvPsi2_rLrInvT) \
            + beta * tr_LcInvPsi2_cLcInvT_LcInvScLcInvT * LrInvPsi2_rLrInvT.dot(LrInvSrLrInvT) \
            - beta * LrInvPsi1_rT.dot(Y.T).dot(LcInvPsi1_cT.T).dot(LcInvMLrInvT) \
            - beta/2. * tr_LcInvPsi2_cLcInvT * LrInvPsi2_rLrInvT

        dL_dKuu_r = backsub_both_sides(Lr, tmp, 'left')
        dL_dKuu_r += dL_dKuu_r.T
        dL_dKuu_r *= 0.5

        # ======= dL_dthetaL (gradient w.r.t. the noise variance 1/beta) =====

        dL_dthetaL = -D*N*beta/2. - logL_A*beta*beta/2.

        # ======= dL_dqU =====

        tmp = -beta * LcInvPsi2_cLcInvT.dot(LcInvMLrInvT).dot(LrInvPsi2_rLrInvT) \
            + beta * LcInvPsi1_cT.dot(Y).dot(LrInvPsi1_rT.T)

        dL_dqU_mean = dtrtrs(Lc, dtrtrs(Lr, tmp.T, trans=1)[0].T, trans=1)[0]

        tmp = -beta/2.*tr_LrInvPsi2_rLrInvT_LrInvSrLrInvT * LcInvPsi2_cLcInvT
        dL_dqU_var_c = backsub_both_sides(Lc, tmp, 'left')

        tmp = -beta/2.*tr_LcInvPsi2_cLcInvT_LcInvScLcInvT * LrInvPsi2_rLrInvT
        dL_dqU_var_r = backsub_both_sides(Lr, tmp, 'left')

        # ======= dL_dpsi =====

        dL_dpsi0_r = - psi0_c * beta/2. * np.ones((D,))
        dL_dpsi0_c = - psi0_r * beta/2. * np.ones((N,))

        dL_dpsi1_c = beta * dtrtrs(Lc, (Y.dot(LrInvPsi1_rT.T).dot(LcInvMLrInvT.T)).T, trans=1)[0].T
        dL_dpsi1_r = beta * dtrtrs(Lr, (Y.T.dot(LcInvPsi1_cT.T).dot(LcInvMLrInvT)).T, trans=1)[0].T

        tmp = beta/2.*(-LcInvMLrInvT.dot(LrInvPsi2_rLrInvT).dot(LcInvMLrInvT.T)
                       - tr_LrInvPsi2_rLrInvT_LrInvSrLrInvT * LcInvScLcInvT
                       + tr_LrInvPsi2_rLrInvT * np.eye(Mc))
        dL_dpsi2_c = backsub_both_sides(Lc, tmp, 'left')
        tmp = beta/2.*(-LcInvMLrInvT.T.dot(LcInvPsi2_cLcInvT).dot(LcInvMLrInvT)
                       - tr_LcInvPsi2_cLcInvT_LcInvScLcInvT * LrInvSrLrInvT
                       + tr_LcInvPsi2_cLcInvT * np.eye(Mr))
        dL_dpsi2_r = backsub_both_sides(Lr, tmp, 'left')

        grad_dict['dL_dthetaL'][d:d+1] = dL_dthetaL
        grad_dict['dL_dqU_mean'] += dL_dqU_mean
        grad_dict['dL_dqU_var_c'] += dL_dqU_var_c
        grad_dict['dL_dqU_var_r'] += dL_dqU_var_r
        grad_dict['dL_dKuu_c'] += dL_dKuu_c
        grad_dict['dL_dKuu_r'] += dL_dKuu_r

        # With certain inputs psi2 is the sum of outer products of psi1
        # rows, so fold the psi2 gradient into the psi1 gradient.
        if not uncertain_inputs_r:
            dL_dpsi1_r += psi1_r.dot(dL_dpsi2_r + dL_dpsi2_r.T)
        if not uncertain_inputs_c:
            dL_dpsi1_c += psi1_c.dot(dL_dpsi2_c + dL_dpsi2_c.T)

        # psi0_c / psi2_c were summed over the rows of output d, so each of
        # those rows receives the same gradient.
        grad_dict['dL_dpsi0_c'][idx_d] += dL_dpsi0_c
        grad_dict['dL_dpsi1_c'][idx_d] += dL_dpsi1_c
        grad_dict['dL_dpsi2_c'][idx_d] += dL_dpsi2_c

        grad_dict['dL_dpsi0_r'][d:d+1] += dL_dpsi0_r
        grad_dict['dL_dpsi1_r'][d:d+1] += dL_dpsi1_r
        grad_dict['dL_dpsi2_r'][d] += dL_dpsi2_r

        return logL

    def inference(self, kern_r, kern_c, Xr, Xc, Zr, Zc, likelihood, Y, qU_mean, qU_var_r, qU_var_c, indexD, output_dim):
        """
        The SVI-VarDTC inference

        :param kern_r: kernel over the per-output inputs ``Xr``.
        :param kern_c: kernel over the per-observation inputs ``Xc``.
        :param Xr: per-output inputs (array or VariationalPosterior).
        :param Xc: per-observation inputs (array or VariationalPosterior).
        :param Zr: row inducing inputs (Mr rows).
        :param Zc: column inducing inputs (Mc rows).
        :param likelihood: provides ``variance``; a length-1 variance is
            shared by all outputs, otherwise one value per output.
        :param Y: stacked observations of all outputs (N rows).
        :param qU_mean: variational mean of the inducing variables, (Mc, Mr).
        :param qU_var_r: variational row covariance, (Mr, Mr).
        :param qU_var_c: variational column covariance, (Mc, Mc).
        :param indexD: output index of every row of ``Y``.
        :param output_dim: number of outputs D.
        :returns: (posterior, log-likelihood bound, gradient dict).
        """
        N, D, Mr, Mc = Y.shape[0], output_dim, Zr.shape[0], Zc.shape[0]

        uncertain_inputs_r = isinstance(Xr, VariationalPosterior)
        uncertain_inputs_c = isinstance(Xc, VariationalPosterior)

        grad_dict = self._init_grad_dict(N, D, Mr, Mc)

        beta = 1./likelihood.variance
        if len(beta) == 1:
            # Shared noise variance: broadcast the precision to every output.
            beta = np.zeros(D)+beta

        psi0_r, psi1_r, psi2_r = self.gatherPsiStat(kern_r, Xr, Zr, uncertain_inputs_r)
        psi0_c, psi1_c, psi2_c = self.gatherPsiStat(kern_c, Xc, Zc, uncertain_inputs_c)

        #======================================================================
        # Compute Common Components
        #======================================================================

        Kuu_r = kern_r.K(Zr).copy()
        diag.add(Kuu_r, self.const_jitter)
        Lr = jitchol(Kuu_r)

        Kuu_c = kern_c.K(Zc).copy()
        diag.add(Kuu_c, self.const_jitter)
        Lc = jitchol(Kuu_c)

        mu, Sr, Sc = qU_mean, qU_var_r, qU_var_c
        LSr = jitchol(Sr)
        LSc = jitchol(Sc)

        # Variational parameters whitened by the Cholesky factors of Kuu.
        LcInvMLrInvT = dtrtrs(Lc, dtrtrs(Lr, mu.T)[0].T)[0]
        LcInvLSc = dtrtrs(Lc, LSc)[0]
        LrInvLSr = dtrtrs(Lr, LSr)[0]
        LcInvScLcInvT = tdot(LcInvLSc)
        LrInvSrLrInvT = tdot(LrInvLSr)
        tr_LrInvSrLrInvT = np.square(LrInvLSr).sum()
        tr_LcInvScLcInvT = np.square(LcInvLSc).sum()

        mid_res = {
            'psi0_r': psi0_r,
            'psi1_r': psi1_r,
            'psi2_r': psi2_r,
            'psi0_c': psi0_c,
            'psi1_c': psi1_c,
            'psi2_c': psi2_c,
            'Lr': Lr,
            'Lc': Lc,
            'LcInvMLrInvT': LcInvMLrInvT,
            'LcInvScLcInvT': LcInvScLcInvT,
            'LrInvSrLrInvT': LrInvSrLrInvT,
        }

        #======================================================================
        # Compute log-likelihood
        #======================================================================

        # Data-fit terms, one output at a time (also accumulates gradients).
        logL = 0.
        for d in range(D):
            logL += self.inference_d(d, beta, Y, indexD, grad_dict, mid_res, uncertain_inputs_r, uncertain_inputs_c, Mr, Mc)

        # Terms involving only the inducing variables (log-determinants,
        # traces and the mean's squared norm in whitened coordinates).
        logL += -Mc * (np.log(np.diag(Lr)).sum()-np.log(np.diag(LSr)).sum()) - Mr * (np.log(np.diag(Lc)).sum()-np.log(np.diag(LSc)).sum()) \
                - np.square(LcInvMLrInvT).sum()/2. - tr_LrInvSrLrInvT * tr_LcInvScLcInvT/2. + Mr*Mc/2.

        #======================================================================
        # Compute dL_dKuu (inducing-variable terms)
        #======================================================================

        tmp = tdot(LcInvMLrInvT)/2. + tr_LrInvSrLrInvT/2. * LcInvScLcInvT - Mr/2.*np.eye(Mc)

        dL_dKuu_c = backsub_both_sides(Lc, tmp, 'left')
        dL_dKuu_c += dL_dKuu_c.T
        dL_dKuu_c *= 0.5

        tmp = tdot(LcInvMLrInvT.T)/2. + tr_LcInvScLcInvT/2. * LrInvSrLrInvT - Mc/2.*np.eye(Mr)

        dL_dKuu_r = backsub_both_sides(Lr, tmp, 'left')
        dL_dKuu_r += dL_dKuu_r.T
        dL_dKuu_r *= 0.5

        #======================================================================
        # Compute dL_dqU (inducing-variable terms)
        #======================================================================

        tmp = - LcInvMLrInvT
        dL_dqU_mean = dtrtrs(Lc, dtrtrs(Lr, tmp.T, trans=1)[0].T, trans=1)[0]

        # tdot(LScInv.T) = LSc^-T LSc^-1, the inverse of Sc.
        LScInv = dtrtri(LSc)
        tmp = -tr_LrInvSrLrInvT/2.*np.eye(Mc)
        dL_dqU_var_c = backsub_both_sides(Lc, tmp, 'left') + tdot(LScInv.T) * Mr/2.

        LSrInv = dtrtri(LSr)
        tmp = -tr_LcInvScLcInvT/2.*np.eye(Mr)
        dL_dqU_var_r = backsub_both_sides(Lr, tmp, 'left') + tdot(LSrInv.T) * Mc/2.

        #======================================================================
        # Compute the Posterior distribution of inducing points p(u|Y)
        #======================================================================

        post = PosteriorMultioutput(LcInvMLrInvT=LcInvMLrInvT, LcInvScLcInvT=LcInvScLcInvT,
                                    LrInvSrLrInvT=LrInvSrLrInvT, Lr=Lr, Lc=Lc, kern_r=kern_r, Xr=Xr, Zr=Zr)

        #======================================================================
        # Accumulate the inducing-variable gradients into grad_dict
        #======================================================================

        grad_dict['dL_dqU_mean'] += dL_dqU_mean
        grad_dict['dL_dqU_var_c'] += dL_dqU_var_c
        grad_dict['dL_dqU_var_r'] += dL_dqU_var_r
        grad_dict['dL_dKuu_c'] += dL_dKuu_c
        grad_dict['dL_dKuu_r'] += dL_dKuu_r

        # For certain inputs the psi statistics are Kdiag / K, so expose
        # their gradients under the kernel-gradient names as well.
        if not uncertain_inputs_c:
            grad_dict['dL_dKdiag_c'] = grad_dict['dL_dpsi0_c']
            grad_dict['dL_dKfu_c'] = grad_dict['dL_dpsi1_c']

        if not uncertain_inputs_r:
            grad_dict['dL_dKdiag_r'] = grad_dict['dL_dpsi0_r']
            grad_dict['dL_dKfu_r'] = grad_dict['dL_dpsi1_r']

        return post, logL, grad_dict
310