1 // 2 // predModule.cpp: implementation of predictor module using Eigen 3 // 4 // Copyright (C) 2011-2013 Douglas Bates, Martin Maechler, Ben Bolker and Steve Walker 5 // 6 // This file is part of lme4. 7 8 #include "predModule.h" 9 10 namespace lme4 { 11 using Rcpp::as; 12 13 using std::invalid_argument; 14 using std::runtime_error; 15 16 using Eigen::ArrayXd; 17 18 typedef Eigen::Map<MatrixXd> MMat; 19 typedef Eigen::Map<VectorXd> MVec; 20 typedef Eigen::Map<VectorXi> MiVec; 21 merPredD(SEXP X,SEXP Lambdat,SEXP LamtUt,SEXP Lind,SEXP RZX,SEXP Ut,SEXP Utr,SEXP V,SEXP VtV,SEXP Vtr,SEXP Xwts,SEXP Zt,SEXP beta0,SEXP delb,SEXP delu,SEXP theta,SEXP u0)22 merPredD::merPredD(SEXP X, SEXP Lambdat, SEXP LamtUt, SEXP Lind, 23 SEXP RZX, SEXP Ut, SEXP Utr, SEXP V, SEXP VtV, 24 SEXP Vtr, SEXP Xwts, SEXP Zt, SEXP beta0, 25 SEXP delb, SEXP delu, SEXP theta, SEXP u0) 26 : d_X( as<MMat>(X)), 27 d_RZX( as<MMat>(RZX)), 28 d_V( as<MMat>(V)), 29 d_VtV( as<MMat>(VtV)), 30 d_Zt( as<MSpMatrixd>(Zt)), 31 d_Ut( as<MSpMatrixd>(Ut)), 32 d_LamtUt( as<MSpMatrixd>(LamtUt)), 33 d_Lambdat( as<MSpMatrixd>(Lambdat)), 34 d_theta( as<MVec>(theta)), 35 d_Vtr( as<MVec>(Vtr)), 36 d_Utr( as<MVec>(Utr)), 37 d_Xwts( as<MVec>(Xwts)), 38 d_beta0( as<MVec>(beta0)), 39 d_delb( as<MVec>(delb)), 40 d_delu( as<MVec>(delu)), 41 d_u0( as<MVec>(u0)), 42 d_Lind( as<MiVec>(Lind)), 43 d_N( d_X.rows()), 44 d_p( d_X.cols()), 45 d_q( d_Zt.rows()), 46 d_RX( d_p) 47 { // Check consistency of dimensions 48 if (d_N != d_Zt.cols()) 49 throw invalid_argument("Z dimension mismatch"); 50 if (d_Lind.size() != d_Lambdat.nonZeros()) 51 throw invalid_argument("size of Lind does not match nonzeros in Lambda"); 52 // checking of the range of Lind is now done in R code for reference class 53 // initialize beta0, u0, delb, delu and VtV 54 d_VtV.setZero().selfadjointView<Eigen::Upper>().rankUpdate(d_V.adjoint()); 55 d_RX.compute(d_VtV); // ensure d_RX is initialized even in the 0-column X case 56 57 setTheta(d_theta); // starting values into Lambda 58 d_L.cholmod().final_ll = 1; // force an LL' decomposition 59 updateLamtUt(); 60 d_L.analyzePattern(d_LamtUt * d_LamtUt.transpose()); // perform symbolic analysis 61 if (d_L.info() != Eigen::Success) 62 throw runtime_error("CholeskyDecomposition.analyzePattern failed"); 63 } 64 updateLamtUt()65 void merPredD::updateLamtUt() { 66 // This complicated code bypasses problems caused by Eigen's 67 // sparse/sparse matrix multiplication pruning zeros. The 68 // Cholesky decomposition croaks if the structure of d_LamtUt changes. 69 MVec(d_LamtUt.valuePtr(), d_LamtUt.nonZeros()).setZero(); 70 for (Index j = 0; j < d_Ut.outerSize(); ++j) { 71 for(MSpMatrixd::InnerIterator rhsIt(d_Ut, j); rhsIt; ++rhsIt) { 72 Scalar y(rhsIt.value()); 73 Index k(rhsIt.index()); 74 MSpMatrixd::InnerIterator prdIt(d_LamtUt, j); 75 for (MSpMatrixd::InnerIterator lhsIt(d_Lambdat, k); lhsIt; ++lhsIt) { 76 Index i = lhsIt.index(); 77 while (prdIt && prdIt.index() != i) ++prdIt; 78 if (!prdIt) throw runtime_error("logic error in updateLamtUt"); 79 prdIt.valueRef() += lhsIt.value() * y; 80 } 81 } 82 } 83 } 84 b(const double & f) const85 VectorXd merPredD::b(const double& f) const {return d_Lambdat.adjoint() * u(f);} 86 beta(const double & f) const87 VectorXd merPredD::beta(const double& f) const {return d_beta0 + f * d_delb;} 88 linPred(const double & f) const89 VectorXd merPredD::linPred(const double& f) const { 90 return d_X * beta(f) + d_Zt.adjoint() * b(f); 91 } 92 condVar(const Rcpp::Environment & rho) const93 Rcpp::List merPredD::condVar(const Rcpp::Environment& rho) const { 94 const Rcpp::List ll(as<Rcpp::List>(rho["flist"])), trmlst(as<Rcpp::List>(rho["terms"])); 95 const int nf(ll.size()); 96 const MiVec nl(as<MiVec>(rho["nlevs"])), 97 nct(as<MiVec>(rho["nctot"])), off(as<MiVec>(rho["offsets"])); 98 // ll : flist 99 // trmlst : terms : list with one element per factor, indicating corresponding term 100 // nf : : number of unique factors 101 // nl : nlevs : number of levels for each unique factor 102 // nct : nctot : total number of components per factor 103 // off : offsets : points to where each term starts 104 Rcpp::List ans(nf); 105 ans.names() = clone(as<Rcpp::CharacterVector>(ll.names())); 106 const SpMatrixd d_Lambda(d_Lambdat.adjoint()); 107 for (int i = 0; i < nf; i++) { 108 int ncti(nct[i]), nli(nl[i]); 109 Rcpp::NumericVector ansi(ncti * ncti * nli); 110 ansi.attr("dim") = Rcpp::IntegerVector::create(ncti, ncti, nli); 111 ans[i] = ansi; 112 const MiVec trms(as<MiVec>(trmlst(i))); 113 // ncti : total number of components in factor i 114 // nli : number of levels in factor i 115 // ansi : array in which to store condVar's for factor i 116 // trms : pointers to terms corresponding to factor i 117 if (trms.size() == 1) { // simple case 118 int offset = off[trms[0] - 1]; 119 for (int j = 0; j < nli; ++j) { 120 MatrixXd LvT(d_Lambdat.innerVectors(offset + j * ncti, ncti)); 121 MatrixXd Lv(LvT.adjoint()); 122 d_L.solveInPlace(LvT, CHOLMOD_A); 123 MatrixXd rr(Lv * LvT); 124 std::copy(rr.data(), rr.data() + rr.size(), &ansi[j * ncti * ncti]); 125 } 126 } else { 127 throw std::runtime_error("multiple terms per factor not yet written"); 128 } 129 } 130 return ans; 131 } 132 u(const double & f) const133 VectorXd merPredD::u(const double& f) const {return d_u0 + f * d_delu;} 134 sqrL(const double & f) const135 merPredD::Scalar merPredD::sqrL(const double& f) const {return u(f).squaredNorm();} 136 updateL()137 void merPredD::updateL() { 138 updateLamtUt(); 139 // More complicated code to handle the case of zeros in 140 // potentially nonzero positions. The factorize_p method is 141 // for a SparseMatrix<double>, not a MappedSparseMatrix<double>. 142 SpMatrixd m(d_LamtUt.rows(), d_LamtUt.cols()); 143 m.resizeNonZeros(d_LamtUt.nonZeros()); 144 std::copy(d_LamtUt.valuePtr(), 145 d_LamtUt.valuePtr() + d_LamtUt.nonZeros(), 146 m.valuePtr()); 147 std::copy(d_LamtUt.innerIndexPtr(), 148 d_LamtUt.innerIndexPtr() + d_LamtUt.nonZeros(), 149 m.innerIndexPtr()); 150 std::copy(d_LamtUt.outerIndexPtr(), 151 d_LamtUt.outerIndexPtr() + d_LamtUt.cols() + 1, 152 m.outerIndexPtr()); 153 d_L.factorize_p(m, Eigen::ArrayXi(), 1.); 154 d_ldL2 = ::M_chm_factor_ldetL2(d_L.factor()); 155 } 156 setTheta(const VectorXd & theta)157 void merPredD::setTheta(const VectorXd& theta) { 158 159 if (theta.size() != d_theta.size()) { 160 Rcpp::Rcout << "(" << theta.size() << "!=" << 161 d_theta.size() << ")" << std::endl; 162 // char errstr[100]; 163 // sprintf(errstr,"theta size mismatch (%d != %d)", 164 // theta.size(),d_theta.size()); 165 throw invalid_argument("theta size mismatch"); 166 } 167 // update theta 168 std::copy(theta.data(), theta.data() + theta.size(), 169 d_theta.data()); 170 // update Lambdat 171 int *lipt = d_Lind.data(); 172 double *LamX = d_Lambdat.valuePtr(), *thpt = d_theta.data(); 173 for (int i = 0; i < d_Lind.size(); ++i) { 174 LamX[i] = thpt[lipt[i] - 1]; 175 } 176 } 177 setZt(const VectorXd & ZtNonZero)178 void merPredD::setZt(const VectorXd& ZtNonZero) { 179 double *ZtX = d_Zt.valuePtr(); // where the nonzero values of Zt live 180 std::copy(ZtNonZero.data(), ZtNonZero.data() + ZtNonZero.size(), ZtX); 181 } 182 183 solve()184 merPredD::Scalar merPredD::solve() { 185 d_delu = d_Utr - d_u0; 186 d_L.solveInPlace(d_delu, CHOLMOD_P); 187 d_L.solveInPlace(d_delu, CHOLMOD_L); // d_delu now contains cu 188 d_CcNumer = d_delu.squaredNorm(); // numerator of convergence criterion 189 190 d_delb = d_RX.matrixL().solve(d_Vtr - d_RZX.adjoint() * d_delu); 191 d_CcNumer += d_delb.squaredNorm(); // increment CcNumer 192 d_RX.matrixU().solveInPlace(d_delb); 193 194 d_delu -= d_RZX * d_delb; 195 d_L.solveInPlace(d_delu, CHOLMOD_Lt); 196 d_L.solveInPlace(d_delu, CHOLMOD_Pt); 197 return d_CcNumer; 198 } 199 solveU()200 merPredD::Scalar merPredD::solveU() { 201 d_delb.setZero(); // in calculation of linPred delb should be zero after solveU 202 d_delu = d_Utr - d_u0; 203 d_L.solveInPlace(d_delu, CHOLMOD_P); 204 d_L.solveInPlace(d_delu, CHOLMOD_L); // d_delu now contains cu 205 d_CcNumer = d_delu.squaredNorm(); // numerator of convergence criterion 206 d_L.solveInPlace(d_delu, CHOLMOD_Lt); 207 d_L.solveInPlace(d_delu, CHOLMOD_Pt); 208 return d_CcNumer; 209 } 210 updateXwts(const ArrayXd & sqrtXwt)211 void merPredD::updateXwts(const ArrayXd& sqrtXwt) { 212 if (d_Xwts.size() != sqrtXwt.size()) 213 throw invalid_argument("updateXwts: dimension mismatch"); 214 std::copy(sqrtXwt.data(), sqrtXwt.data() + sqrtXwt.size(), d_Xwts.data()); 215 if (sqrtXwt.size() == d_V.rows()) { // W is diagonal 216 d_V = d_Xwts.asDiagonal() * d_X; 217 for (int j = 0; j < d_N; ++j) 218 for (MSpMatrixd::InnerIterator Utj(d_Ut, j), Ztj(d_Zt, j); 219 Utj && Ztj; ++Utj, ++Ztj) 220 Utj.valueRef() = Ztj.value() * d_Xwts.data()[j]; 221 } else { 222 SpMatrixd W(d_V.rows(), sqrtXwt.size()); 223 const double *pt = sqrtXwt.data(); 224 W.reserve(sqrtXwt.size()); 225 for (Index j = 0; j < W.cols(); ++j, ++pt) { 226 W.startVec(j); 227 W.insertBack(j % d_V.rows(), j) = *pt; 228 } 229 W.finalize(); 230 d_V = W * d_X; 231 SpMatrixd Ut(d_Zt * W.adjoint()); 232 if (Ut.cols() != d_Ut.cols()) 233 throw std::runtime_error("Size mismatch in updateXwts"); 234 235 // More complex code to handle the pruning of zeros 236 MVec(d_Ut.valuePtr(), d_Ut.nonZeros()).setZero(); 237 for (int j = 0; j < d_Ut.outerSize(); ++j) { 238 MSpMatrixd::InnerIterator lhsIt(d_Ut, j); 239 for (SpMatrixd::InnerIterator rhsIt(Ut, j); rhsIt; ++rhsIt, ++lhsIt) { 240 Index k(rhsIt.index()); 241 while (lhsIt && lhsIt.index() != k) ++lhsIt; 242 if (lhsIt.index() != k) 243 throw std::runtime_error("Pattern mismatch in updateXwts"); 244 lhsIt.valueRef() = rhsIt.value(); 245 } 246 } 247 } 248 d_VtV.setZero().selfadjointView<Eigen::Upper>().rankUpdate(d_V.adjoint()); 249 updateL(); 250 } 251 updateDecomp()252 void merPredD::updateDecomp() { 253 updateDecomp(NULL); 254 } 255 256 // using a point so as to detect NULL updateDecomp(const MatrixXd * xPenalty)257 void merPredD::updateDecomp(const MatrixXd* xPenalty) { // update L, RZX and RX 258 int debug=0; 259 260 if (debug) Rcpp::Rcout << "start updateDecomp" << std::endl; 261 updateL(); 262 if (debug) { 263 Rcpp::Rcout << "updateDecomp 2: dimensions (RZX, LamtUt,V)" << 264 d_RZX.cols() << " " << d_RZX.rows() << " " << 265 d_LamtUt.cols() << " " << d_LamtUt.rows() << " " << 266 d_V.cols() << " " << d_V.rows() << " " << 267 std::endl; 268 } 269 if (d_LamtUt.cols() != d_V.rows()) { 270 ::Rf_warning("dimension mismatch in updateDecomp()"); 271 // Rcpp::Rcout << "WARNING: dimension mismatch in updateDecomp(): " << 272 // " LamtUt=" << d_LamtUt.rows() << "x" << d_LamtUt.cols() << 273 // "; V=" << d_V.rows() << "x" << d_V.cols() << " " << 274 // std::endl; 275 } 276 d_RZX = d_LamtUt * d_V; 277 if (debug) Rcpp::Rcout << "updateDecomp 3" << std::endl; 278 if (d_p > 0) { 279 d_L.solveInPlace(d_RZX, CHOLMOD_P); 280 d_L.solveInPlace(d_RZX, CHOLMOD_L); 281 if (debug) Rcpp::Rcout << "updateDecomp 4" << std::endl; 282 MatrixXd VtVdown(d_VtV); 283 284 if (xPenalty == NULL) 285 d_RX.compute(VtVdown.selfadjointView<Eigen::Upper>().rankUpdate(d_RZX.adjoint(), -1)); 286 else { 287 d_RX.compute(VtVdown.selfadjointView<Eigen::Upper>().rankUpdate(d_RZX.adjoint(), -1).rankUpdate(*xPenalty, 1)); 288 } 289 if (debug) Rcpp::Rcout << "updateDecomp 5" << std::endl; 290 if (d_RX.info() != Eigen::Success) 291 ::Rf_error("Downdated VtV is not positive definite"); 292 d_ldRX2 = 2. * d_RX.matrixLLT().diagonal().array().abs().log().sum(); 293 if (debug) Rcpp::Rcout << "updateDecomp 6" << std::endl; 294 } 295 } 296 updateRes(const VectorXd & wtres)297 void merPredD::updateRes(const VectorXd& wtres) { 298 if (d_V.rows() != wtres.size()) 299 throw invalid_argument("updateRes: dimension mismatch"); 300 d_Vtr = d_V.adjoint() * wtres; 301 d_Utr = d_LamtUt * wtres; 302 } 303 installPars(const Scalar & f)304 void merPredD::installPars(const Scalar& f) { 305 d_u0 = u(f); 306 d_beta0 = beta(f); 307 d_delb.setZero(); 308 d_delu.setZero(); 309 } 310 setBeta0(const VectorXd & nBeta)311 void merPredD::setBeta0(const VectorXd& nBeta) { 312 if (nBeta.size() != d_p) throw invalid_argument("setBeta0: dimension mismatch"); 313 std::copy(nBeta.data(), nBeta.data() + d_p, d_beta0.data()); 314 } 315 setDelb(const VectorXd & newDelb)316 void merPredD::setDelb(const VectorXd& newDelb) { 317 if (newDelb.size() != d_p) 318 throw invalid_argument("setDelb: dimension mismatch"); 319 std::copy(newDelb.data(), newDelb.data() + d_p, d_delb.data()); 320 } 321 setDelu(const VectorXd & newDelu)322 void merPredD::setDelu(const VectorXd& newDelu) { 323 if (newDelu.size() != d_q) 324 throw invalid_argument("setDelu: dimension mismatch"); 325 std::copy(newDelu.data(), newDelu.data() + d_q, d_delu.data()); 326 } 327 setU0(const VectorXd & newU0)328 void merPredD::setU0(const VectorXd& newU0) { 329 if (newU0.size() != d_q) throw invalid_argument("setU0: dimension mismatch"); 330 std::copy(newU0.data(), newU0.data() + d_q, d_u0.data()); 331 } 332 333 template <typename T> 334 struct Norm_Rand : std::unary_function<T, T> { operator ()lme4::Norm_Rand335 const T operator()(const T& x) const {return ::norm_rand();} 336 }; 337 Random_Normal(int size,double sigma)338 inline static VectorXd Random_Normal(int size, double sigma) { 339 return ArrayXd(size).unaryExpr(Norm_Rand<double>()) * sigma; 340 } 341 MCMC_beta_u(const Scalar & sigma)342 void merPredD::MCMC_beta_u(const Scalar& sigma) { 343 VectorXd del2(d_RX.matrixU().solve(Random_Normal(d_p, sigma))); 344 d_delb += del2; 345 VectorXd del1(Random_Normal(d_q, sigma) - d_RZX * del2); 346 d_L.solveInPlace(del1, CHOLMOD_Lt); 347 d_delu += del1; 348 } 349 Pvec() const350 VectorXi merPredD::Pvec() const { 351 int* ppt((int*)d_L.factor()->Perm); 352 VectorXi ans(d_q); 353 std::copy(ppt, ppt + d_q, ans.data()); 354 return ans; 355 } 356 RX() const357 MatrixXd merPredD::RX() const { 358 return d_RX.matrixU(); 359 } 360 RXi() const361 MatrixXd merPredD::RXi() const { // inverse RX 362 return d_RX.matrixU().solve(MatrixXd::Identity(d_p,d_p)); 363 } 364 unsc() const365 MatrixXd merPredD::unsc() const { // unscaled var-cov mat of FE 366 // R translation: tcrossprod(RXi) 367 return MatrixXd(MatrixXd(d_p, d_p).setZero(). 368 selfadjointView<Eigen::Lower>(). 369 rankUpdate(RXi())); 370 } 371 RXdiag() const372 VectorXd merPredD::RXdiag() const { 373 return d_RX.matrixLLT().diagonal(); 374 } 375 } 376