1 
2 #include "em.h"
3 #include <RcppArmadillo.h>
4 
5 
6 using namespace Rcpp ;
7 
emcore(SEXP xs,SEXP AMr1s,SEXP os,SEXP ms,SEXP ivec,SEXP thetas,SEXP tols,SEXP emburns,SEXP p2ss,SEXP empris,SEXP autos,SEXP alls,SEXP prs)8 SEXP emcore(SEXP xs, SEXP AMr1s, SEXP os, SEXP ms, SEXP ivec, SEXP thetas, SEXP tols, SEXP emburns, SEXP p2ss, SEXP empris, SEXP autos, SEXP alls, SEXP prs){
9 
10   //, SEXP p2ss, SEXP prs, SEXP empris, SEXP fends, SEXP alls, SEXP autos, SEXP emburns//
11 
12   NumericMatrix xr(xs);
13   NumericMatrix thetar(thetas);
14   NumericVector tol(tols);
15   NumericMatrix AMr1r(AMr1s);
16   NumericMatrix orr(os);
17   NumericMatrix mr(ms);
18   NumericVector ir(ivec);
19   NumericVector emburn(emburns);
20   NumericVector p2sr(p2ss);
21 
22   NumericVector emprir(empris);
23   // NumericVector frontend(fends);
24   NumericVector allthetas(alls);
25   NumericVector autopri(autos);
26 
27   int p2s = p2sr(0), empri = emprir(0);
28   int n = xr.nrow(), k = xr.ncol();
29   int const AMn = n;
30   int npatt = orr.nrow();
31   int cvalue = 1;
32 
33   arma::mat x(xr.begin(), n, k, false);
34   arma::mat thetaold(thetar.begin(), k + 1, k + 1, false);
35   arma::mat AMr1(AMr1r.begin(), n, k, false);
36   arma::mat obsmat(orr.begin(), npatt, k, false);
37   arma::mat mismat(mr.begin(), npatt, k, false);
38   arma::vec ii(ir.begin(), n, false);
39   //Rcpp::Rcout << "Set up arma things. "  << std::endl;
40 
41   // Bring out your priors.
42   NumericMatrix prr;
43   int npr, knr;
44   arma::mat priors;
45   if (!Rf_isNull(prs)) {
46     prr = NumericMatrix(prs);
47     npr = prr.nrow();
48     knr = prr.ncol();
49     priors = arma::mat(prr.begin(), npr, knr, false);
50   }
51 
52   int count = 0;
53   int is, isp;
54 
55   //int nparam = arma::accu(arma::find(arma::trimatu(thetaold)));
56 
57   arma::uvec upperpos = arma::find(arma::trimatu(arma::ones<arma::mat>(k+1,k+1)));
58   arma::mat xplay = arma::zeros<arma::mat>(AMn,k);
59   arma::mat hmcv(k,k);
60   arma::mat imputations(2,k);
61   arma::vec music(k);
62   arma::mat thetanew(k+1, k+1);
63   arma::mat theta(k+1, k+1);
64   arma::vec sweeppos(k+1);
65   arma::uvec mispos;
66   arma::uvec thetaleft;
67   arma::vec etest;
68   arma::mat iterHist(1,3);
69   arma::mat thetaHolder(upperpos.n_elem,1);
70   thetaHolder.col(0) = thetaold.elem(upperpos);
71 
72   iterHist.zeros();
73   sweeppos.zeros();
74   hmcv.zeros();
75   music.zeros();
76   int st, ss, singFlag, monoFlag;
77   if (arma::accu(mismat.row(0)) == 0) {
78     st = 1;
79   } else {
80     st = 0;
81   }
82   //  if (empri > 0) {
83   arma::mat hold = empri * arma::eye(k,k);
84   arma::mat simple(k,k);
85 //}
86 
87   if (p2s > 0) Rcpp::Rcout << std::endl;
88   //Rcpp::Rcout << "Starting loop. "  << std::endl;
89 while ( ( (cvalue > 0) | (count < emburn(0)) )  & ( (count < emburn(1)) | (emburn(1) < 1))) {
90     count++;
91     hmcv.zeros(k,k);
92     music.zeros(k);
93     xplay.zeros(AMn,k);
94 
95     if (p2s > 0) {
96       if (count < 10) {
97         Rcpp::Rcout << "  " << count;
98       } else {
99         Rcpp::Rcout << " " << count;
100       }
101       if (count % 20 == 0) {
102         Rcpp::Rcout << std::endl;
103       }
104     }
105 
106     if (st == 1) {
107       xplay.rows(0,ii(1)-2) = x.rows(0,ii(1)-2);
108     }
109     if (Rf_isNull(prs)) {
110       for (ss = st; ss < obsmat.n_rows; ss++) {
111 
112         is = ii(ss)-1;
113         isp = ii(ss+1)-2;
114 
115         theta = thetaold;
116         sweeppos.zeros();
117         sweeppos(arma::span(1,k)) = arma::trans(obsmat.row(ss));
118 
119         sweep(theta, sweeppos);
120 
121         imputations.zeros();
122         imputations.set_size(isp - is, k);
123 
124         imputations = x.rows(is, isp) * theta(arma::span(1,k), arma::span(1,k));
125         imputations.each_row() += theta(0, arma::span(1,k));
126         imputations = AMr1.rows(is, isp) % imputations;
127 
128         xplay.rows(is, isp) = x.rows(is, isp) + imputations;
129 
130         mispos = arma::find(mismat.row(ss));
131         hmcv(mispos, mispos) += (1+ isp - is) *  theta(mispos+1, mispos+1);
132 
133 
134 
135 
136       }
137     } else {
138       for (ss = st; ss < obsmat.n_rows; ss++) {
139 
140         is = ii(ss)-1;
141         isp = ii(ss+1)-2;
142 
143         theta = thetaold;
144         sweeppos.zeros();
145         sweeppos(arma::span(1,k)) = arma::trans(obsmat.row(ss));
146 
147         sweep(theta, sweeppos);
148 
149         imputations.zeros();
150         imputations.set_size(isp - is, k);
151 
152         imputations = x.rows(is, isp) * theta(arma::span(1,k), arma::span(1,k));
153         imputations.each_row() += theta(0, arma::span(1,k));
154         imputations = AMr1.rows(is, isp) % imputations;
155 
156         mispos = arma::find(mismat.row(ss));
157         arma::mat solveSigma = arma::inv(theta(mispos + 1, mispos + 1));
158         arma::mat diagLambda = arma::zeros<arma::mat>(mispos.n_elem, mispos.n_elem);
159         for (int p = 0; p <= isp-is; p++) {
160           arma::uvec prRow = arma::find(priors.col(0) == p + is + 1);
161           if (prRow.n_elem > 0) {
162             arma::uvec pu(1);
163             pu(0) = p;
164             arma::mat thisPrior = priors.rows(prRow);
165             arma::uvec theseCols = arma::conv_to<arma::uvec>::from(thisPrior.col(1)-1);
166             arma::vec prHolder = arma::zeros<arma::vec>(k);
167             prHolder.elem(theseCols) = thisPrior.col(3);
168             diagLambda.diag() = prHolder.elem(mispos);
169             arma::mat wvar = arma::inv(diagLambda + solveSigma);
170             prHolder.elem(theseCols) = thisPrior.col(2);
171             arma::mat firstInner = solveSigma * arma::trans(imputations(pu, mispos));
172             arma::mat secondInner = prHolder.elem(mispos);
173             arma::mat muMiss = wvar * (secondInner + firstInner);
174 
175             imputations(pu, mispos) = arma::trans(muMiss);
176             hmcv(mispos, mispos) +=  wvar;
177           } else {
178             hmcv(mispos, mispos) += theta(mispos + 1, mispos + 1);
179           }
180 
181         }
182         xplay.rows(is, isp) = x.rows(is, isp) + imputations;
183       }
184 
185     }
186 
187     hmcv += xplay.t() * xplay;
188     music += arma::trans(arma::sum(xplay));
189     if (empri > 0) {
190       simple = (music * arma::trans(music))/AMn;
191       hmcv = (( (double)AMn/(AMn+empri+k+2)) * (hmcv - simple + hold)) + simple;
192     }
193 
194     thetanew(0,0) = AMn;
195     thetanew(0, arma::span(1,k)) = arma::trans(music);
196     thetanew(arma::span(1,k), 0) = music;
197     thetanew(arma::span(1,k), arma::span(1,k)) = hmcv;
198     thetanew = thetanew/AMn;
199     thetanew = symmatu(thetanew);
200 
201     sweeppos.zeros();
202     sweeppos(0) = 1;
203     sweep(thetanew, sweeppos);
204     theta = arma::abs(thetanew - thetaold);
205     thetaleft = arma::find(arma::trimatu(theta) > tol(0));
206     cvalue = thetaleft.n_elem;
207     thetaold = thetanew;
208 
209     if (cvalue > iterHist(count-1,0) && count > 20) {
210       monoFlag = 1;
211       if (autopri(0) > 0) {
212         if (arma::accu(iterHist(arma::span(count - 20, count - 1), 2)) > 3) {
213           if (empri < (autopri(0) * (double)n)) {
214             empri = empri + 0.01 * (double)n;
215           }
216         }
217       }
218     } else {
219       monoFlag = 0;
220     }
221 
222     etest = arma::eig_sym(thetaold(arma::span(1,k), arma::span(1,k)));
223 
224     if (arma::accu(etest <= 0)) {
225       singFlag = 1;
226     } else {
227       singFlag = 0;
228     }
229     if (p2s > 1) {
230       Rcpp::Rcout << "(" << cvalue << ")";
231       if (monoFlag == 1) {
232         Rcpp::Rcout << "*";
233       }
234       if (singFlag == 1) {
235         Rcpp::Rcout << "!";
236       }
237     }
238     iterHist.resize(iterHist.n_rows+1, iterHist.n_cols);
239     iterHist(count, 0) = cvalue;
240     iterHist(count, 1) = monoFlag;
241     iterHist(count, 2) = singFlag;
242     if (allthetas(0) == 1) {
243       thetaHolder.resize(thetaHolder.n_rows, thetaHolder.n_cols + 1);
244       thetaHolder.col(count) = thetaold.elem(upperpos);
245     }
246   }
247   iterHist.shed_row(0);
248 
249   if (p2s > 0) Rcpp::Rcout << std::endl;
250   List z;
251   if (allthetas(0) == 1) {
252     thetaHolder.shed_row(0);
253     z = List::create(Rcpp::Named("thetanew") = thetaHolder,
254                      Rcpp::Named("iter.hist") = iterHist);
255   } else {
256     z = List::create(Rcpp::Named("theta") = thetaold,
257                      Rcpp::Named("iter.hist") = iterHist);
258   }
259   return z ;
260 }
261 
262 // void sweep(arma::mat& g, arma::vec m) {
263 //   int p = g.n_rows, h, j, i;
264 //   arma::uvec k = arma::find(m);
265 
266 //   if (k.n_elem == p) {
267 //     g = -arma::inv(g);
268 //   } else {
269 //     for (h = 0; h < k.n_rows; h++) {
270 //       for (j = 0; j < p; j++) {
271 //         for (i = 0; i <= j; i++) {
272 //           if (i == k(h)) {
273 //             if (j == k(h)) {
274 //               g(i,j) = -1/g(i,j);
275 //               //Rcpp::Rcout << k(h) << ": " << "(i,j): (" <<i<<", "<<j<<"): " << g(i,j) << std::endl;
276 
277 //             } else {
278 //               //Rcpp::Rcout << k(h) << ": " << "(i,j): (" <<i<<", "<<j<<"): " << g(i,j)<<"\t" << g(i,i) << std::endl;
279 //               g(i,j) = g(i,j)/g(i,i);
280 //             }
281 //           } else {
282 //             //Rcpp::Rcout << k(h) << ": " << "(i,j): (" <<i<<", "<<j<<"): " << g(i,j) << std::endl;
283 //             g(i,j) = g(i,j) - g(i, k(h)) * g(k(h), j)/g(k(h), k(h));
284 //           }
285 //         }
286 //       }
287 //     }
288 //     g = arma::symmatu(g);
289 //   }
290 
291 // }
292 
sweep(arma::mat & g,arma::vec m)293 void sweep(arma::mat& g, arma::vec m) {
294   int p = g.n_rows;
295   arma::uvec k = arma::find(m);
296   arma::uvec kcompl = arma::find(1-m);
297   if (k.n_elem == p) {
298     g = -arma::inv_sympd(g);
299   } else {
300     arma::mat h = g(k, k);
301     try {
302       g(k,k) = arma::inv_sympd(h);
303     } catch (std::runtime_error &e){
304       g(k,k) = arma::pinv(h, sqrt(arma::datum::eps));
305     } catch (...) {
306       Rcpp::Rcout << "Caught an unknown exception\n";
307     }
308     g(k,kcompl) = g(k,k) * g(k,kcompl);
309     g(kcompl, kcompl) = g(kcompl, kcompl) - (g(kcompl, k) * g(k,kcompl));
310     g(kcompl,k) = arma::trans(g(k,kcompl));
311     g(k,k) = -g(k,k);
312     g = symmatu(g);
313   }
314 
315 }
316 
317 
ameliaImpute(SEXP xs,SEXP AMr1s,SEXP os,SEXP ms,SEXP ivec,SEXP thetas,SEXP prs,SEXP bdss,SEXP maxres)318 SEXP ameliaImpute(SEXP xs, SEXP AMr1s, SEXP os, SEXP ms, SEXP ivec, SEXP thetas,  SEXP prs, SEXP bdss, SEXP maxres){
319 
320   //, SEXP p2ss, SEXP prs, SEXP empris, SEXP fends, SEXP alls, SEXP autos, SEXP emburns//
321 
322   NumericMatrix xr(xs);
323   NumericMatrix thetar(thetas);
324   NumericMatrix AMr1r(AMr1s);
325   NumericMatrix orr(os);
326   NumericMatrix mr(ms);
327   NumericVector ir(ivec);
328   NumericMatrix bdr;
329   NumericVector maxrr;
330   int maxsamples;
331 
332   int n = xr.nrow(), k = xr.ncol();
333   int const AMn = n;
334   int npatt = orr.nrow();
335 
336   arma::mat x(xr.begin(), n, k, false);
337   arma::mat thetaold(thetar.begin(), k + 1, k + 1, false);
338   arma::mat AMr1(AMr1r.begin(), n, k, false);
339   arma::mat obsmat(orr.begin(), npatt, k, false);
340   arma::mat mismat(mr.begin(), npatt, k, false);
341   arma::vec ii(ir.begin(), npatt + 1, false);
342 
343   // Bring out your priors.
344   NumericMatrix prr;
345   int npr, knr;
346   arma::mat priors;
347   if (!Rf_isNull(prs)) {
348     prr = NumericMatrix(prs);
349     npr = prr.nrow();
350     knr = prr.ncol();
351     priors = arma::mat(prr.begin(), npr, knr, false);
352   }
353 
354   arma::mat bounds;
355   if (!Rf_isNull(bdss)) {
356     bdr = NumericMatrix(bdss);
357     bounds = arma::mat(bdr.begin(), bdr.nrow(), bdr.ncol(), false);
358     maxrr = NumericVector(maxres);
359     maxsamples = maxrr(0);
360   }
361 
362   int is, isp;
363 
364   //int nparam = arma::accu(arma::find(arma::trimatu(thetaold)));
365 
366   //arma::uvec upperpos = arma::find(arma::trimatu(arma::abs(arma::randu<arma::mat>(k+1,k+1))));
367   arma::mat xplay = arma::zeros<arma::mat>(AMn,k);
368   arma::mat imputations(2,k);
369   arma::mat theta(k+1, k+1);
370   arma::mat junk(2,k);
371   arma::mat Ci(k, k);
372   arma::vec sweeppos(k+1);
373   arma::uvec mispos;
374 
375   sweeppos.zeros();
376   int st, ss;
377   if (arma::accu(mismat.row(0)) == 0) {
378     st = 1;
379   } else {
380     st = 0;
381   }
382 
383   if (st == 1) {
384     xplay.rows(0,ii(1)-2) = x.rows(0,ii(1)-2);
385   }
386 
387   if (Rf_isNull(prs)) {
388     for (ss = st; ss < obsmat.n_rows; ss++) {
389       is = ii(ss)-1;
390       isp = ii(ss+1)-2;
391 
392       theta = thetaold;
393       sweeppos.zeros();
394       sweeppos(arma::span(1,k)) = arma::trans(obsmat.row(ss));
395 
396       sweep(theta, sweeppos);
397 
398       mispos = arma::find(mismat.row(ss));
399       Ci.zeros(k, k);
400       Ci(mispos, mispos) = chol(theta(mispos+1, mispos + 1));
401       junk = Rcpp::rnorm((isp - is + 1)* k, 0, 1);
402       junk.reshape(isp - is +1, k);
403       junk = junk * Ci;
404 
405       imputations.zeros();
406       imputations.set_size(isp - is, k);
407       imputations = x.rows(is, isp) * theta(arma::span(1,k), arma::span(1,k));
408       imputations.each_row() += theta(0, arma::span(1,k));
409       imputations = AMr1.rows(is, isp) % imputations;
410 
411       if (Rf_isNull(bdss)) {
412         xplay.rows(is, isp) = x.rows(is, isp) + imputations + junk;
413       } else {
414         xplay.rows(is, isp) = resampler(x.rows(is, isp), Ci, imputations, mispos, bounds, maxsamples);
415       }
416 
417     }
418   } else {
419     for (ss = st; ss < obsmat.n_rows; ss++) {
420       is = ii(ss)-1;
421       isp = ii(ss+1)-2;
422 
423       theta = thetaold;
424       sweeppos.zeros();
425       sweeppos(arma::span(1,k)) = arma::trans(obsmat.row(ss));
426 
427       sweep(theta, sweeppos);
428       junk.zeros(isp - is + 1, k);
429       junk = Rcpp::rnorm((isp - is + 1)* k, 0, 1);
430       junk.reshape(isp - is +1, k);
431       imputations.zeros();
432       imputations.set_size(isp - is, k);
433 
434       imputations = x.rows(is, isp) * theta(arma::span(1,k), arma::span(1,k));
435       imputations.each_row() += theta(0, arma::span(1,k));
436       imputations = AMr1.rows(is, isp) % imputations;
437 
438       mispos = arma::find(mismat.row(ss));
439       arma::mat solveSigma = arma::inv(theta(mispos + 1, mispos + 1));
440       arma::mat diagLambda = arma::zeros<arma::mat>(mispos.n_elem, mispos.n_elem);
441       for (int p = 0; p <= isp-is; p++) {
442         arma::uvec prRow = arma::find(priors.col(0) == p + is + 1);
443         Ci.zeros(k,k);
444 
445         if (prRow.n_elem > 0) {
446           arma::uvec pu(1);
447           pu(0) = p;
448           arma::mat thisPrior = priors.rows(prRow);
449           arma::uvec theseCols = arma::conv_to<arma::uvec>::from(thisPrior.col(1)-1);
450           arma::vec prHolder = arma::zeros<arma::vec>(k);
451           prHolder.elem(theseCols) = thisPrior.col(3);
452           diagLambda.diag() = prHolder.elem(mispos);
453           arma::mat wvar = arma::inv(diagLambda + solveSigma);
454           prHolder.elem(theseCols) = thisPrior.col(2);
455           arma::mat muMiss = wvar * (prHolder.elem(mispos) + solveSigma * arma::trans(imputations(pu, mispos)));
456           imputations(pu, mispos) = arma::trans(muMiss);
457           Ci(mispos, mispos) = chol(wvar);
458         } else {
459           Ci(mispos, mispos) = chol(theta(mispos + 1, mispos + 1));
460         }
461         junk.row(p) = junk.row(p) * Ci;
462         if (Rf_isNull(bdss)) {
463           xplay.row(is + p) = x.row(is + p) + imputations.row(p) + junk.row(p);
464         } else {
465           xplay.row(is + p) = resampler(x.row(is + p), Ci, imputations.row(p), mispos, bounds, maxsamples);
466         }
467 
468       }
469     }
470 
471   }
472 
473 
474   return wrap(xplay);
475 }
476 
resampler(arma::mat x,arma::mat ci,arma::mat imps,arma::uvec mss,arma::mat bounds,int maxsample)477 arma::mat resampler(arma::mat x, arma::mat ci, arma::mat imps, arma::uvec mss,
478                     arma::mat bounds, int maxsample) {
479   int nss = x.n_rows, k = x.n_cols;
480 
481   arma::mat ub(nss, k);
482   arma::mat lb(nss, k);
483   arma::umat utest;
484   arma::umat ltest;
485   ub.fill(arma::datum::inf);
486   lb.fill(-arma::datum::inf);
487   arma::mat xp = arma::zeros<arma::mat>(nss, k);
488 
489   arma::mat junk = Rcpp::rnorm(nss * k, 0, 1);
490   junk.reshape(nss, k);
491   junk = junk * ci;
492 
493   int nb = 0, bdvar;
494   for (int j = 0; j < bounds.n_rows; j++) {
495     bdvar = (int) bounds(j,0) - 1;
496     if (arma::accu(mss == bdvar)) {
497       nb++;
498       lb.col(bdvar) = arma::ones<arma::colvec>(nss) * bounds(j,1);
499       ub.col(bdvar) = arma::ones<arma::colvec>(nss) * bounds(j,2);
500     }
501   }
502 
503 
504   if (nb == 0) {
505     return x + imps + junk;
506   }
507 
508   //Rcpp::Rcout << ub << std::endl;
509   int samp = 0;
510   arma::colvec done = arma::zeros<arma::colvec>(nss);
511   arma::colvec left = arma::ones<arma::colvec>(nss);
512   arma::uvec finished;
513   while ((arma::accu(left) > 0) & (samp < maxsample)) {
514     samp++;
515     utest = (imps + junk) > ub;
516     ltest = (imps + junk) < lb;
517 
518     done += left % (arma::sum(utest + ltest, 1) == 0);
519     finished = arma::find(left % (arma::sum(utest + ltest, 1) == 0));
520     left -= left % (arma::sum(utest + ltest,  1) == 0);
521 
522 
523     ub.rows(finished).fill(arma::datum::inf);
524     lb.rows(finished).fill(-arma::datum::inf);
525     xp.rows(finished) = x.rows(finished) + imps.rows(finished) + junk.rows(finished);
526 
527     junk = Rcpp::rnorm(nss * k, 0, 1);
528     junk.reshape(nss, k);
529     junk = junk * ci;
530 
531   }
532 
533   if (arma::accu(left) > 0) {
534     xp.rows(arma::find(left)) = x.rows(arma::find(left)) + imps.rows(arma::find(left)) + junk.rows(arma::find(left));
535     utest = (imps + junk) > ub;
536     ltest = (imps + junk) < lb;
537     arma::uvec ufails = arma::find(utest);
538     arma::uvec lfails = arma::find(ltest);
539     xp.elem(ufails) = ub.elem(ufails);
540     xp.elem(lfails) = lb.elem(lfails);
541   }
542 
543   return xp;
544 
545 }
546