1 //
2 // predModule.cpp: implementation of predictor module using Eigen
3 //
4 // Copyright (C) 2011-2013 Douglas Bates, Martin Maechler, Ben Bolker and Steve Walker
5 //
6 // This file is part of lme4.
7 
8 #include "predModule.h"
9 
10 namespace lme4 {
11     using    Rcpp::as;
12 
13     using     std::invalid_argument;
14     using     std::runtime_error;
15 
16     using   Eigen::ArrayXd;
17 
18     typedef Eigen::Map<MatrixXd>  MMat;
19     typedef Eigen::Map<VectorXd>  MVec;
20     typedef Eigen::Map<VectorXi>  MiVec;
21 
merPredD(SEXP X,SEXP Lambdat,SEXP LamtUt,SEXP Lind,SEXP RZX,SEXP Ut,SEXP Utr,SEXP V,SEXP VtV,SEXP Vtr,SEXP Xwts,SEXP Zt,SEXP beta0,SEXP delb,SEXP delu,SEXP theta,SEXP u0)22     merPredD::merPredD(SEXP X, SEXP Lambdat, SEXP LamtUt, SEXP Lind,
23                        SEXP RZX, SEXP Ut, SEXP Utr, SEXP V, SEXP VtV,
24                        SEXP Vtr, SEXP Xwts, SEXP Zt, SEXP beta0,
25                        SEXP delb, SEXP delu, SEXP theta, SEXP u0)
26         : d_X(       as<MMat>(X)),
27           d_RZX(     as<MMat>(RZX)),
28           d_V(       as<MMat>(V)),
29           d_VtV(     as<MMat>(VtV)),
30           d_Zt(      as<MSpMatrixd>(Zt)),
31           d_Ut(      as<MSpMatrixd>(Ut)),
32           d_LamtUt(  as<MSpMatrixd>(LamtUt)),
33           d_Lambdat( as<MSpMatrixd>(Lambdat)),
34           d_theta(   as<MVec>(theta)),
35           d_Vtr(     as<MVec>(Vtr)),
36           d_Utr(     as<MVec>(Utr)),
37           d_Xwts(    as<MVec>(Xwts)),
38           d_beta0(   as<MVec>(beta0)),
39           d_delb(    as<MVec>(delb)),
40           d_delu(    as<MVec>(delu)),
41           d_u0(      as<MVec>(u0)),
42           d_Lind(    as<MiVec>(Lind)),
43           d_N(       d_X.rows()),
44           d_p(       d_X.cols()),
45           d_q(       d_Zt.rows()),
46           d_RX(      d_p)
47     {                           // Check consistency of dimensions
48         if (d_N != d_Zt.cols())
49             throw invalid_argument("Z dimension mismatch");
50         if (d_Lind.size() != d_Lambdat.nonZeros())
51             throw invalid_argument("size of Lind does not match nonzeros in Lambda");
52         // checking of the range of Lind is now done in R code for reference class
53                                 // initialize beta0, u0, delb, delu and VtV
54         d_VtV.setZero().selfadjointView<Eigen::Upper>().rankUpdate(d_V.adjoint());
55         d_RX.compute(d_VtV);    // ensure d_RX is initialized even in the 0-column X case
56 
57         setTheta(d_theta);          // starting values into Lambda
58         d_L.cholmod().final_ll = 1; // force an LL' decomposition
59         updateLamtUt();
60         d_L.analyzePattern(d_LamtUt * d_LamtUt.transpose()); // perform symbolic analysis
61         if (d_L.info() != Eigen::Success)
62             throw runtime_error("CholeskyDecomposition.analyzePattern failed");
63     }
64 
updateLamtUt()65     void merPredD::updateLamtUt() {
66         // This complicated code bypasses problems caused by Eigen's
67         // sparse/sparse matrix multiplication pruning zeros.  The
68         // Cholesky decomposition croaks if the structure of d_LamtUt changes.
69         MVec(d_LamtUt.valuePtr(), d_LamtUt.nonZeros()).setZero();
70         for (Index j = 0; j < d_Ut.outerSize(); ++j) {
71             for(MSpMatrixd::InnerIterator rhsIt(d_Ut, j); rhsIt; ++rhsIt) {
72                 Scalar                        y(rhsIt.value());
73                 Index                         k(rhsIt.index());
74                 MSpMatrixd::InnerIterator prdIt(d_LamtUt, j);
75                 for (MSpMatrixd::InnerIterator lhsIt(d_Lambdat, k); lhsIt; ++lhsIt) {
76                     Index                     i = lhsIt.index();
77                     while (prdIt && prdIt.index() != i) ++prdIt;
78                     if (!prdIt) throw runtime_error("logic error in updateLamtUt");
79                     prdIt.valueRef()           += lhsIt.value() * y;
80                 }
81             }
82         }
83     }
84 
b(const double & f) const85     VectorXd merPredD::b(const double& f) const {return d_Lambdat.adjoint() * u(f);}
86 
beta(const double & f) const87     VectorXd merPredD::beta(const double& f) const {return d_beta0 + f * d_delb;}
88 
linPred(const double & f) const89     VectorXd merPredD::linPred(const double& f) const {
90         return d_X * beta(f) + d_Zt.adjoint() * b(f);
91     }
92 
// Compute, for every grouping factor, an array of per-level matrices
// Lv * solve(L, Lv') where Lv' is the block of columns of Lambdat belonging
// to that level.
//
// rho must provide:
//   flist   : list of grouping factors (its names label the result)
//   terms   : list with one element per factor, indicating the
//             corresponding term(s) (1-based)
//   nlevs   : number of levels for each unique factor
//   nctot   : total number of components per factor
//   offsets : where each term starts
//
// Returns a named list with one ncti x ncti x nli numeric array per factor.
//
// Throws std::invalid_argument when terms, nlevs or nctot are shorter than
// flist, or when a term index does not address an element of offsets.
// Throws std::runtime_error for factors associated with more than one term,
// which is not implemented.
Rcpp::List merPredD::condVar(const Rcpp::Environment& rho) const {
    const Rcpp::List ll(as<Rcpp::List>(rho["flist"])), trmlst(as<Rcpp::List>(rho["terms"]));
    const int nf(ll.size());
    const MiVec nl(as<MiVec>(rho["nlevs"])),
        nct(as<MiVec>(rho["nctot"])), off(as<MiVec>(rho["offsets"]));

    // All three are indexed by factor below; refuse to read past their ends.
    if (trmlst.size() < nf || nl.size() < nf || nct.size() < nf)
        throw invalid_argument("condVar: terms, nlevs or nctot shorter than flist");

    Rcpp::List ans(nf);
    ans.names() = clone(as<Rcpp::CharacterVector>(ll.names()));
    for (int i = 0; i < nf; i++) {
        // ncti : total number of components in factor i
        // nli  : number of levels in factor i
        // ansi : array in which to store condVar's for factor i
        // trms : pointers to terms corresponding to factor i
        int         ncti(nct[i]), nli(nl[i]);
        Rcpp::NumericVector ansi(ncti * ncti * nli);
        ansi.attr("dim") = Rcpp::IntegerVector::create(ncti, ncti, nli);
        ans[i] = ansi;
        const MiVec trms(as<MiVec>(trmlst(i)));
        if (trms.size() == 1) { // simple case
            const int term = trms[0] - 1; // terms are 1-based
            if (term < 0 || term >= off.size())
                throw invalid_argument("condVar: term index outside offsets");
            int offset = off[term];
            for (int j = 0; j < nli; ++j) {
                // Columns of Lambdat for level j, and their transpose.
                MatrixXd LvT(d_Lambdat.innerVectors(offset + j * ncti, ncti));
                MatrixXd Lv(LvT.adjoint());
                // Full CHOLMOD_A solve against the factor, in place.
                d_L.solveInPlace(LvT, CHOLMOD_A);
                MatrixXd rr(Lv * LvT);
                // Store the ncti x ncti result as slice j of the array.
                std::copy(rr.data(), rr.data() + rr.size(), &ansi[j * ncti * ncti]);
            }
        } else {
            throw std::runtime_error("multiple terms per factor not yet written");
        }
    }
    return ans;
}
132 
u(const double & f) const133     VectorXd merPredD::u(const double& f) const {return d_u0 + f * d_delu;}
134 
sqrL(const double & f) const135     merPredD::Scalar merPredD::sqrL(const double& f) const {return u(f).squaredNorm();}
136 
updateL()137     void merPredD::updateL() {
138         updateLamtUt();
139         // More complicated code to handle the case of zeros in
140         // potentially nonzero positions.  The factorize_p method is
141         // for a SparseMatrix<double>, not a MappedSparseMatrix<double>.
142         SpMatrixd  m(d_LamtUt.rows(), d_LamtUt.cols());
143         m.resizeNonZeros(d_LamtUt.nonZeros());
144         std::copy(d_LamtUt.valuePtr(),
145                   d_LamtUt.valuePtr() + d_LamtUt.nonZeros(),
146                   m.valuePtr());
147         std::copy(d_LamtUt.innerIndexPtr(),
148                   d_LamtUt.innerIndexPtr() + d_LamtUt.nonZeros(),
149                   m.innerIndexPtr());
150         std::copy(d_LamtUt.outerIndexPtr(),
151                   d_LamtUt.outerIndexPtr() + d_LamtUt.cols() + 1,
152                   m.outerIndexPtr());
153         d_L.factorize_p(m, Eigen::ArrayXi(), 1.);
154         d_ldL2 = ::M_chm_factor_ldetL2(d_L.factor());
155     }
156 
setTheta(const VectorXd & theta)157     void merPredD::setTheta(const VectorXd& theta) {
158 
159         if (theta.size() != d_theta.size()) {
160             Rcpp::Rcout << "(" << theta.size() << "!=" <<
161                 d_theta.size() << ")" << std::endl;
162             // char errstr[100];
163             // sprintf(errstr,"theta size mismatch (%d != %d)",
164             //      theta.size(),d_theta.size());
165             throw invalid_argument("theta size mismatch");
166         }
167         // update theta
168         std::copy(theta.data(), theta.data() + theta.size(),
169                   d_theta.data());
170         // update Lambdat
171         int    *lipt = d_Lind.data();
172         double *LamX = d_Lambdat.valuePtr(), *thpt = d_theta.data();
173         for (int i = 0; i < d_Lind.size(); ++i) {
174             LamX[i] = thpt[lipt[i] - 1];
175         }
176     }
177 
setZt(const VectorXd & ZtNonZero)178     void merPredD::setZt(const VectorXd& ZtNonZero) {
179         double *ZtX = d_Zt.valuePtr(); // where the nonzero values of Zt live
180         std::copy(ZtNonZero.data(), ZtNonZero.data() + ZtNonZero.size(), ZtX);
181     }
182 
183 
solve()184     merPredD::Scalar merPredD::solve() {
185         d_delu          = d_Utr - d_u0;
186         d_L.solveInPlace(d_delu, CHOLMOD_P);
187         d_L.solveInPlace(d_delu, CHOLMOD_L);    // d_delu now contains cu
188         d_CcNumer       = d_delu.squaredNorm(); // numerator of convergence criterion
189 
190         d_delb          = d_RX.matrixL().solve(d_Vtr - d_RZX.adjoint() * d_delu);
191         d_CcNumer      += d_delb.squaredNorm(); // increment CcNumer
192         d_RX.matrixU().solveInPlace(d_delb);
193 
194         d_delu         -= d_RZX * d_delb;
195         d_L.solveInPlace(d_delu, CHOLMOD_Lt);
196         d_L.solveInPlace(d_delu, CHOLMOD_Pt);
197         return d_CcNumer;
198     }
199 
solveU()200     merPredD::Scalar merPredD::solveU() {
201         d_delb.setZero(); // in calculation of linPred delb should be zero after solveU
202         d_delu          = d_Utr - d_u0;
203         d_L.solveInPlace(d_delu, CHOLMOD_P);
204         d_L.solveInPlace(d_delu, CHOLMOD_L);    // d_delu now contains cu
205         d_CcNumer       = d_delu.squaredNorm(); // numerator of convergence criterion
206         d_L.solveInPlace(d_delu, CHOLMOD_Lt);
207         d_L.solveInPlace(d_delu, CHOLMOD_Pt);
208   return d_CcNumer;
209     }
210 
updateXwts(const ArrayXd & sqrtXwt)211     void merPredD::updateXwts(const ArrayXd& sqrtXwt) {
212         if (d_Xwts.size() != sqrtXwt.size())
213             throw invalid_argument("updateXwts: dimension mismatch");
214         std::copy(sqrtXwt.data(), sqrtXwt.data() + sqrtXwt.size(), d_Xwts.data());
215         if (sqrtXwt.size() == d_V.rows()) { // W is diagonal
216             d_V              = d_Xwts.asDiagonal() * d_X;
217             for (int j = 0; j < d_N; ++j)
218                 for (MSpMatrixd::InnerIterator Utj(d_Ut, j), Ztj(d_Zt, j);
219                      Utj && Ztj; ++Utj, ++Ztj)
220                     Utj.valueRef() = Ztj.value() * d_Xwts.data()[j];
221         } else {
222             SpMatrixd      W(d_V.rows(), sqrtXwt.size());
223             const double *pt = sqrtXwt.data();
224             W.reserve(sqrtXwt.size());
225             for (Index j = 0; j < W.cols(); ++j, ++pt) {
226                 W.startVec(j);
227                 W.insertBack(j % d_V.rows(), j) = *pt;
228             }
229             W.finalize();
230             d_V              = W * d_X;
231             SpMatrixd      Ut(d_Zt * W.adjoint());
232             if (Ut.cols() != d_Ut.cols())
233                 throw std::runtime_error("Size mismatch in updateXwts");
234 
235             // More complex code to handle the pruning of zeros
236             MVec(d_Ut.valuePtr(), d_Ut.nonZeros()).setZero();
237             for (int j = 0; j < d_Ut.outerSize(); ++j) {
238                 MSpMatrixd::InnerIterator lhsIt(d_Ut, j);
239                 for (SpMatrixd::InnerIterator  rhsIt(Ut, j); rhsIt; ++rhsIt, ++lhsIt) {
240                     Index                         k(rhsIt.index());
241                     while (lhsIt && lhsIt.index() != k) ++lhsIt;
242                     if (lhsIt.index() != k)
243                         throw std::runtime_error("Pattern mismatch in updateXwts");
244                     lhsIt.valueRef() = rhsIt.value();
245                 }
246             }
247         }
248         d_VtV.setZero().selfadjointView<Eigen::Upper>().rankUpdate(d_V.adjoint());
249         updateL();
250     }
251 
updateDecomp()252     void merPredD::updateDecomp() {
253         updateDecomp(NULL);
254     }
255 
    // takes a pointer (rather than a reference) so that NULL can be detected
updateDecomp(const MatrixXd * xPenalty)257     void merPredD::updateDecomp(const MatrixXd* xPenalty) {  // update L, RZX and RX
258         int debug=0;
259 
260         if (debug) Rcpp::Rcout << "start updateDecomp" << std::endl;
261         updateL();
262         if (debug) {
263             Rcpp::Rcout << "updateDecomp 2: dimensions (RZX, LamtUt,V)" <<
264                 d_RZX.cols() << " " << d_RZX.rows() << " " <<
265                 d_LamtUt.cols() << " " << d_LamtUt.rows() << " " <<
266                 d_V.cols() << " " << d_V.rows() << " " <<
267                 std::endl;
268         }
269         if (d_LamtUt.cols() != d_V.rows()) {
270             ::Rf_warning("dimension mismatch in updateDecomp()");
271             // Rcpp::Rcout << "WARNING: dimension mismatch in updateDecomp(): " <<
272             // " LamtUt=" << d_LamtUt.rows() << "x" << d_LamtUt.cols() <<
273             // "; V=" << d_V.rows() << "x" << d_V.cols() << " " <<
274             // std::endl;
275         }
276         d_RZX         = d_LamtUt * d_V;
277         if (debug) Rcpp::Rcout << "updateDecomp 3" << std::endl;
278         if (d_p > 0) {
279             d_L.solveInPlace(d_RZX, CHOLMOD_P);
280             d_L.solveInPlace(d_RZX, CHOLMOD_L);
281             if (debug) Rcpp::Rcout << "updateDecomp 4" << std::endl;
282             MatrixXd      VtVdown(d_VtV);
283 
284             if (xPenalty == NULL)
285                 d_RX.compute(VtVdown.selfadjointView<Eigen::Upper>().rankUpdate(d_RZX.adjoint(), -1));
286             else {
287                 d_RX.compute(VtVdown.selfadjointView<Eigen::Upper>().rankUpdate(d_RZX.adjoint(), -1).rankUpdate(*xPenalty, 1));
288             }
289             if (debug) Rcpp::Rcout << "updateDecomp 5" << std::endl;
290             if (d_RX.info() != Eigen::Success)
291                 ::Rf_error("Downdated VtV is not positive definite");
292             d_ldRX2         = 2. * d_RX.matrixLLT().diagonal().array().abs().log().sum();
293             if (debug) Rcpp::Rcout << "updateDecomp 6" << std::endl;
294         }
295     }
296 
updateRes(const VectorXd & wtres)297     void merPredD::updateRes(const VectorXd& wtres) {
298         if (d_V.rows() != wtres.size())
299             throw invalid_argument("updateRes: dimension mismatch");
300         d_Vtr           = d_V.adjoint() * wtres;
301         d_Utr           = d_LamtUt * wtres;
302     }
303 
installPars(const Scalar & f)304     void merPredD::installPars(const Scalar& f) {
305         d_u0 = u(f);
306         d_beta0 = beta(f);
307         d_delb.setZero();
308         d_delu.setZero();
309     }
310 
setBeta0(const VectorXd & nBeta)311     void merPredD::setBeta0(const VectorXd& nBeta) {
312         if (nBeta.size() != d_p) throw invalid_argument("setBeta0: dimension mismatch");
313         std::copy(nBeta.data(), nBeta.data() + d_p, d_beta0.data());
314     }
315 
setDelb(const VectorXd & newDelb)316     void merPredD::setDelb(const VectorXd& newDelb) {
317         if (newDelb.size() != d_p)
318             throw invalid_argument("setDelb: dimension mismatch");
319         std::copy(newDelb.data(), newDelb.data() + d_p, d_delb.data());
320     }
321 
setDelu(const VectorXd & newDelu)322     void merPredD::setDelu(const VectorXd& newDelu) {
323         if (newDelu.size() != d_q)
324             throw invalid_argument("setDelu: dimension mismatch");
325         std::copy(newDelu.data(), newDelu.data() + d_q, d_delu.data());
326     }
327 
setU0(const VectorXd & newU0)328     void merPredD::setU0(const VectorXd& newU0) {
329         if (newU0.size() != d_q) throw invalid_argument("setU0: dimension mismatch");
330         std::copy(newU0.data(), newU0.data() + d_q, d_u0.data());
331     }
332 
333     template <typename T>
334     struct Norm_Rand : std::unary_function<T, T> {
operator ()lme4::Norm_Rand335         const T operator()(const T& x) const {return ::norm_rand();}
336     };
337 
Random_Normal(int size,double sigma)338     inline static VectorXd Random_Normal(int size, double sigma) {
339         return ArrayXd(size).unaryExpr(Norm_Rand<double>()) * sigma;
340     }
341 
MCMC_beta_u(const Scalar & sigma)342     void merPredD::MCMC_beta_u(const Scalar& sigma) {
343         VectorXd del2(d_RX.matrixU().solve(Random_Normal(d_p, sigma)));
344         d_delb += del2;
345         VectorXd del1(Random_Normal(d_q, sigma) - d_RZX * del2);
346         d_L.solveInPlace(del1, CHOLMOD_Lt);
347         d_delu += del1;
348     }
349 
Pvec() const350     VectorXi merPredD::Pvec() const {
351         int*                 ppt((int*)d_L.factor()->Perm);
352         VectorXi             ans(d_q);
353         std::copy(ppt, ppt + d_q, ans.data());
354         return ans;
355     }
356 
RX() const357     MatrixXd merPredD::RX() const {
358         return d_RX.matrixU();
359     }
360 
RXi() const361     MatrixXd merPredD::RXi() const { // inverse RX
362         return d_RX.matrixU().solve(MatrixXd::Identity(d_p,d_p));
363     }
364 
unsc() const365     MatrixXd merPredD::unsc() const { // unscaled var-cov mat of FE
366         // R translation: tcrossprod(RXi)
367         return MatrixXd(MatrixXd(d_p, d_p).setZero().
368                         selfadjointView<Eigen::Lower>().
369                         rankUpdate(RXi()));
370     }
371 
RXdiag() const372     VectorXd merPredD::RXdiag() const {
373         return d_RX.matrixLLT().diagonal();
374     }
375 }
376