1
2 #include "em.h"
3 #include <RcppArmadillo.h>
4
5
6 using namespace Rcpp ;
7
emcore(SEXP xs,SEXP AMr1s,SEXP os,SEXP ms,SEXP ivec,SEXP thetas,SEXP tols,SEXP emburns,SEXP p2ss,SEXP empris,SEXP autos,SEXP alls,SEXP prs)8 SEXP emcore(SEXP xs, SEXP AMr1s, SEXP os, SEXP ms, SEXP ivec, SEXP thetas, SEXP tols, SEXP emburns, SEXP p2ss, SEXP empris, SEXP autos, SEXP alls, SEXP prs){
9
10 //, SEXP p2ss, SEXP prs, SEXP empris, SEXP fends, SEXP alls, SEXP autos, SEXP emburns//
11
12 NumericMatrix xr(xs);
13 NumericMatrix thetar(thetas);
14 NumericVector tol(tols);
15 NumericMatrix AMr1r(AMr1s);
16 NumericMatrix orr(os);
17 NumericMatrix mr(ms);
18 NumericVector ir(ivec);
19 NumericVector emburn(emburns);
20 NumericVector p2sr(p2ss);
21
22 NumericVector emprir(empris);
23 // NumericVector frontend(fends);
24 NumericVector allthetas(alls);
25 NumericVector autopri(autos);
26
27 int p2s = p2sr(0), empri = emprir(0);
28 int n = xr.nrow(), k = xr.ncol();
29 int const AMn = n;
30 int npatt = orr.nrow();
31 int cvalue = 1;
32
33 arma::mat x(xr.begin(), n, k, false);
34 arma::mat thetaold(thetar.begin(), k + 1, k + 1, false);
35 arma::mat AMr1(AMr1r.begin(), n, k, false);
36 arma::mat obsmat(orr.begin(), npatt, k, false);
37 arma::mat mismat(mr.begin(), npatt, k, false);
38 arma::vec ii(ir.begin(), n, false);
39 //Rcpp::Rcout << "Set up arma things. " << std::endl;
40
41 // Bring out your priors.
42 NumericMatrix prr;
43 int npr, knr;
44 arma::mat priors;
45 if (!Rf_isNull(prs)) {
46 prr = NumericMatrix(prs);
47 npr = prr.nrow();
48 knr = prr.ncol();
49 priors = arma::mat(prr.begin(), npr, knr, false);
50 }
51
52 int count = 0;
53 int is, isp;
54
55 //int nparam = arma::accu(arma::find(arma::trimatu(thetaold)));
56
57 arma::uvec upperpos = arma::find(arma::trimatu(arma::ones<arma::mat>(k+1,k+1)));
58 arma::mat xplay = arma::zeros<arma::mat>(AMn,k);
59 arma::mat hmcv(k,k);
60 arma::mat imputations(2,k);
61 arma::vec music(k);
62 arma::mat thetanew(k+1, k+1);
63 arma::mat theta(k+1, k+1);
64 arma::vec sweeppos(k+1);
65 arma::uvec mispos;
66 arma::uvec thetaleft;
67 arma::vec etest;
68 arma::mat iterHist(1,3);
69 arma::mat thetaHolder(upperpos.n_elem,1);
70 thetaHolder.col(0) = thetaold.elem(upperpos);
71
72 iterHist.zeros();
73 sweeppos.zeros();
74 hmcv.zeros();
75 music.zeros();
76 int st, ss, singFlag, monoFlag;
77 if (arma::accu(mismat.row(0)) == 0) {
78 st = 1;
79 } else {
80 st = 0;
81 }
82 // if (empri > 0) {
83 arma::mat hold = empri * arma::eye(k,k);
84 arma::mat simple(k,k);
85 //}
86
87 if (p2s > 0) Rcpp::Rcout << std::endl;
88 //Rcpp::Rcout << "Starting loop. " << std::endl;
89 while ( ( (cvalue > 0) | (count < emburn(0)) ) & ( (count < emburn(1)) | (emburn(1) < 1))) {
90 count++;
91 hmcv.zeros(k,k);
92 music.zeros(k);
93 xplay.zeros(AMn,k);
94
95 if (p2s > 0) {
96 if (count < 10) {
97 Rcpp::Rcout << " " << count;
98 } else {
99 Rcpp::Rcout << " " << count;
100 }
101 if (count % 20 == 0) {
102 Rcpp::Rcout << std::endl;
103 }
104 }
105
106 if (st == 1) {
107 xplay.rows(0,ii(1)-2) = x.rows(0,ii(1)-2);
108 }
109 if (Rf_isNull(prs)) {
110 for (ss = st; ss < obsmat.n_rows; ss++) {
111
112 is = ii(ss)-1;
113 isp = ii(ss+1)-2;
114
115 theta = thetaold;
116 sweeppos.zeros();
117 sweeppos(arma::span(1,k)) = arma::trans(obsmat.row(ss));
118
119 sweep(theta, sweeppos);
120
121 imputations.zeros();
122 imputations.set_size(isp - is, k);
123
124 imputations = x.rows(is, isp) * theta(arma::span(1,k), arma::span(1,k));
125 imputations.each_row() += theta(0, arma::span(1,k));
126 imputations = AMr1.rows(is, isp) % imputations;
127
128 xplay.rows(is, isp) = x.rows(is, isp) + imputations;
129
130 mispos = arma::find(mismat.row(ss));
131 hmcv(mispos, mispos) += (1+ isp - is) * theta(mispos+1, mispos+1);
132
133
134
135
136 }
137 } else {
138 for (ss = st; ss < obsmat.n_rows; ss++) {
139
140 is = ii(ss)-1;
141 isp = ii(ss+1)-2;
142
143 theta = thetaold;
144 sweeppos.zeros();
145 sweeppos(arma::span(1,k)) = arma::trans(obsmat.row(ss));
146
147 sweep(theta, sweeppos);
148
149 imputations.zeros();
150 imputations.set_size(isp - is, k);
151
152 imputations = x.rows(is, isp) * theta(arma::span(1,k), arma::span(1,k));
153 imputations.each_row() += theta(0, arma::span(1,k));
154 imputations = AMr1.rows(is, isp) % imputations;
155
156 mispos = arma::find(mismat.row(ss));
157 arma::mat solveSigma = arma::inv(theta(mispos + 1, mispos + 1));
158 arma::mat diagLambda = arma::zeros<arma::mat>(mispos.n_elem, mispos.n_elem);
159 for (int p = 0; p <= isp-is; p++) {
160 arma::uvec prRow = arma::find(priors.col(0) == p + is + 1);
161 if (prRow.n_elem > 0) {
162 arma::uvec pu(1);
163 pu(0) = p;
164 arma::mat thisPrior = priors.rows(prRow);
165 arma::uvec theseCols = arma::conv_to<arma::uvec>::from(thisPrior.col(1)-1);
166 arma::vec prHolder = arma::zeros<arma::vec>(k);
167 prHolder.elem(theseCols) = thisPrior.col(3);
168 diagLambda.diag() = prHolder.elem(mispos);
169 arma::mat wvar = arma::inv(diagLambda + solveSigma);
170 prHolder.elem(theseCols) = thisPrior.col(2);
171 arma::mat firstInner = solveSigma * arma::trans(imputations(pu, mispos));
172 arma::mat secondInner = prHolder.elem(mispos);
173 arma::mat muMiss = wvar * (secondInner + firstInner);
174
175 imputations(pu, mispos) = arma::trans(muMiss);
176 hmcv(mispos, mispos) += wvar;
177 } else {
178 hmcv(mispos, mispos) += theta(mispos + 1, mispos + 1);
179 }
180
181 }
182 xplay.rows(is, isp) = x.rows(is, isp) + imputations;
183 }
184
185 }
186
187 hmcv += xplay.t() * xplay;
188 music += arma::trans(arma::sum(xplay));
189 if (empri > 0) {
190 simple = (music * arma::trans(music))/AMn;
191 hmcv = (( (double)AMn/(AMn+empri+k+2)) * (hmcv - simple + hold)) + simple;
192 }
193
194 thetanew(0,0) = AMn;
195 thetanew(0, arma::span(1,k)) = arma::trans(music);
196 thetanew(arma::span(1,k), 0) = music;
197 thetanew(arma::span(1,k), arma::span(1,k)) = hmcv;
198 thetanew = thetanew/AMn;
199 thetanew = symmatu(thetanew);
200
201 sweeppos.zeros();
202 sweeppos(0) = 1;
203 sweep(thetanew, sweeppos);
204 theta = arma::abs(thetanew - thetaold);
205 thetaleft = arma::find(arma::trimatu(theta) > tol(0));
206 cvalue = thetaleft.n_elem;
207 thetaold = thetanew;
208
209 if (cvalue > iterHist(count-1,0) && count > 20) {
210 monoFlag = 1;
211 if (autopri(0) > 0) {
212 if (arma::accu(iterHist(arma::span(count - 20, count - 1), 2)) > 3) {
213 if (empri < (autopri(0) * (double)n)) {
214 empri = empri + 0.01 * (double)n;
215 }
216 }
217 }
218 } else {
219 monoFlag = 0;
220 }
221
222 etest = arma::eig_sym(thetaold(arma::span(1,k), arma::span(1,k)));
223
224 if (arma::accu(etest <= 0)) {
225 singFlag = 1;
226 } else {
227 singFlag = 0;
228 }
229 if (p2s > 1) {
230 Rcpp::Rcout << "(" << cvalue << ")";
231 if (monoFlag == 1) {
232 Rcpp::Rcout << "*";
233 }
234 if (singFlag == 1) {
235 Rcpp::Rcout << "!";
236 }
237 }
238 iterHist.resize(iterHist.n_rows+1, iterHist.n_cols);
239 iterHist(count, 0) = cvalue;
240 iterHist(count, 1) = monoFlag;
241 iterHist(count, 2) = singFlag;
242 if (allthetas(0) == 1) {
243 thetaHolder.resize(thetaHolder.n_rows, thetaHolder.n_cols + 1);
244 thetaHolder.col(count) = thetaold.elem(upperpos);
245 }
246 }
247 iterHist.shed_row(0);
248
249 if (p2s > 0) Rcpp::Rcout << std::endl;
250 List z;
251 if (allthetas(0) == 1) {
252 thetaHolder.shed_row(0);
253 z = List::create(Rcpp::Named("thetanew") = thetaHolder,
254 Rcpp::Named("iter.hist") = iterHist);
255 } else {
256 z = List::create(Rcpp::Named("theta") = thetaold,
257 Rcpp::Named("iter.hist") = iterHist);
258 }
259 return z ;
260 }
261
262 // void sweep(arma::mat& g, arma::vec m) {
263 // int p = g.n_rows, h, j, i;
264 // arma::uvec k = arma::find(m);
265
266 // if (k.n_elem == p) {
267 // g = -arma::inv(g);
268 // } else {
269 // for (h = 0; h < k.n_rows; h++) {
270 // for (j = 0; j < p; j++) {
271 // for (i = 0; i <= j; i++) {
272 // if (i == k(h)) {
273 // if (j == k(h)) {
274 // g(i,j) = -1/g(i,j);
275 // //Rcpp::Rcout << k(h) << ": " << "(i,j): (" <<i<<", "<<j<<"): " << g(i,j) << std::endl;
276
277 // } else {
278 // //Rcpp::Rcout << k(h) << ": " << "(i,j): (" <<i<<", "<<j<<"): " << g(i,j)<<"\t" << g(i,i) << std::endl;
279 // g(i,j) = g(i,j)/g(i,i);
280 // }
281 // } else {
282 // //Rcpp::Rcout << k(h) << ": " << "(i,j): (" <<i<<", "<<j<<"): " << g(i,j) << std::endl;
283 // g(i,j) = g(i,j) - g(i, k(h)) * g(k(h), j)/g(k(h), k(h));
284 // }
285 // }
286 // }
287 // }
288 // g = arma::symmatu(g);
289 // }
290
291 // }
292
sweep(arma::mat & g,arma::vec m)293 void sweep(arma::mat& g, arma::vec m) {
294 int p = g.n_rows;
295 arma::uvec k = arma::find(m);
296 arma::uvec kcompl = arma::find(1-m);
297 if (k.n_elem == p) {
298 g = -arma::inv_sympd(g);
299 } else {
300 arma::mat h = g(k, k);
301 try {
302 g(k,k) = arma::inv_sympd(h);
303 } catch (std::runtime_error &e){
304 g(k,k) = arma::pinv(h, sqrt(arma::datum::eps));
305 } catch (...) {
306 Rcpp::Rcout << "Caught an unknown exception\n";
307 }
308 g(k,kcompl) = g(k,k) * g(k,kcompl);
309 g(kcompl, kcompl) = g(kcompl, kcompl) - (g(kcompl, k) * g(k,kcompl));
310 g(kcompl,k) = arma::trans(g(k,kcompl));
311 g(k,k) = -g(k,k);
312 g = symmatu(g);
313 }
314
315 }
316
317
ameliaImpute(SEXP xs,SEXP AMr1s,SEXP os,SEXP ms,SEXP ivec,SEXP thetas,SEXP prs,SEXP bdss,SEXP maxres)318 SEXP ameliaImpute(SEXP xs, SEXP AMr1s, SEXP os, SEXP ms, SEXP ivec, SEXP thetas, SEXP prs, SEXP bdss, SEXP maxres){
319
320 //, SEXP p2ss, SEXP prs, SEXP empris, SEXP fends, SEXP alls, SEXP autos, SEXP emburns//
321
322 NumericMatrix xr(xs);
323 NumericMatrix thetar(thetas);
324 NumericMatrix AMr1r(AMr1s);
325 NumericMatrix orr(os);
326 NumericMatrix mr(ms);
327 NumericVector ir(ivec);
328 NumericMatrix bdr;
329 NumericVector maxrr;
330 int maxsamples;
331
332 int n = xr.nrow(), k = xr.ncol();
333 int const AMn = n;
334 int npatt = orr.nrow();
335
336 arma::mat x(xr.begin(), n, k, false);
337 arma::mat thetaold(thetar.begin(), k + 1, k + 1, false);
338 arma::mat AMr1(AMr1r.begin(), n, k, false);
339 arma::mat obsmat(orr.begin(), npatt, k, false);
340 arma::mat mismat(mr.begin(), npatt, k, false);
341 arma::vec ii(ir.begin(), npatt + 1, false);
342
343 // Bring out your priors.
344 NumericMatrix prr;
345 int npr, knr;
346 arma::mat priors;
347 if (!Rf_isNull(prs)) {
348 prr = NumericMatrix(prs);
349 npr = prr.nrow();
350 knr = prr.ncol();
351 priors = arma::mat(prr.begin(), npr, knr, false);
352 }
353
354 arma::mat bounds;
355 if (!Rf_isNull(bdss)) {
356 bdr = NumericMatrix(bdss);
357 bounds = arma::mat(bdr.begin(), bdr.nrow(), bdr.ncol(), false);
358 maxrr = NumericVector(maxres);
359 maxsamples = maxrr(0);
360 }
361
362 int is, isp;
363
364 //int nparam = arma::accu(arma::find(arma::trimatu(thetaold)));
365
366 //arma::uvec upperpos = arma::find(arma::trimatu(arma::abs(arma::randu<arma::mat>(k+1,k+1))));
367 arma::mat xplay = arma::zeros<arma::mat>(AMn,k);
368 arma::mat imputations(2,k);
369 arma::mat theta(k+1, k+1);
370 arma::mat junk(2,k);
371 arma::mat Ci(k, k);
372 arma::vec sweeppos(k+1);
373 arma::uvec mispos;
374
375 sweeppos.zeros();
376 int st, ss;
377 if (arma::accu(mismat.row(0)) == 0) {
378 st = 1;
379 } else {
380 st = 0;
381 }
382
383 if (st == 1) {
384 xplay.rows(0,ii(1)-2) = x.rows(0,ii(1)-2);
385 }
386
387 if (Rf_isNull(prs)) {
388 for (ss = st; ss < obsmat.n_rows; ss++) {
389 is = ii(ss)-1;
390 isp = ii(ss+1)-2;
391
392 theta = thetaold;
393 sweeppos.zeros();
394 sweeppos(arma::span(1,k)) = arma::trans(obsmat.row(ss));
395
396 sweep(theta, sweeppos);
397
398 mispos = arma::find(mismat.row(ss));
399 Ci.zeros(k, k);
400 Ci(mispos, mispos) = chol(theta(mispos+1, mispos + 1));
401 junk = Rcpp::rnorm((isp - is + 1)* k, 0, 1);
402 junk.reshape(isp - is +1, k);
403 junk = junk * Ci;
404
405 imputations.zeros();
406 imputations.set_size(isp - is, k);
407 imputations = x.rows(is, isp) * theta(arma::span(1,k), arma::span(1,k));
408 imputations.each_row() += theta(0, arma::span(1,k));
409 imputations = AMr1.rows(is, isp) % imputations;
410
411 if (Rf_isNull(bdss)) {
412 xplay.rows(is, isp) = x.rows(is, isp) + imputations + junk;
413 } else {
414 xplay.rows(is, isp) = resampler(x.rows(is, isp), Ci, imputations, mispos, bounds, maxsamples);
415 }
416
417 }
418 } else {
419 for (ss = st; ss < obsmat.n_rows; ss++) {
420 is = ii(ss)-1;
421 isp = ii(ss+1)-2;
422
423 theta = thetaold;
424 sweeppos.zeros();
425 sweeppos(arma::span(1,k)) = arma::trans(obsmat.row(ss));
426
427 sweep(theta, sweeppos);
428 junk.zeros(isp - is + 1, k);
429 junk = Rcpp::rnorm((isp - is + 1)* k, 0, 1);
430 junk.reshape(isp - is +1, k);
431 imputations.zeros();
432 imputations.set_size(isp - is, k);
433
434 imputations = x.rows(is, isp) * theta(arma::span(1,k), arma::span(1,k));
435 imputations.each_row() += theta(0, arma::span(1,k));
436 imputations = AMr1.rows(is, isp) % imputations;
437
438 mispos = arma::find(mismat.row(ss));
439 arma::mat solveSigma = arma::inv(theta(mispos + 1, mispos + 1));
440 arma::mat diagLambda = arma::zeros<arma::mat>(mispos.n_elem, mispos.n_elem);
441 for (int p = 0; p <= isp-is; p++) {
442 arma::uvec prRow = arma::find(priors.col(0) == p + is + 1);
443 Ci.zeros(k,k);
444
445 if (prRow.n_elem > 0) {
446 arma::uvec pu(1);
447 pu(0) = p;
448 arma::mat thisPrior = priors.rows(prRow);
449 arma::uvec theseCols = arma::conv_to<arma::uvec>::from(thisPrior.col(1)-1);
450 arma::vec prHolder = arma::zeros<arma::vec>(k);
451 prHolder.elem(theseCols) = thisPrior.col(3);
452 diagLambda.diag() = prHolder.elem(mispos);
453 arma::mat wvar = arma::inv(diagLambda + solveSigma);
454 prHolder.elem(theseCols) = thisPrior.col(2);
455 arma::mat muMiss = wvar * (prHolder.elem(mispos) + solveSigma * arma::trans(imputations(pu, mispos)));
456 imputations(pu, mispos) = arma::trans(muMiss);
457 Ci(mispos, mispos) = chol(wvar);
458 } else {
459 Ci(mispos, mispos) = chol(theta(mispos + 1, mispos + 1));
460 }
461 junk.row(p) = junk.row(p) * Ci;
462 if (Rf_isNull(bdss)) {
463 xplay.row(is + p) = x.row(is + p) + imputations.row(p) + junk.row(p);
464 } else {
465 xplay.row(is + p) = resampler(x.row(is + p), Ci, imputations.row(p), mispos, bounds, maxsamples);
466 }
467
468 }
469 }
470
471 }
472
473
474 return wrap(xplay);
475 }
476
resampler(arma::mat x,arma::mat ci,arma::mat imps,arma::uvec mss,arma::mat bounds,int maxsample)477 arma::mat resampler(arma::mat x, arma::mat ci, arma::mat imps, arma::uvec mss,
478 arma::mat bounds, int maxsample) {
479 int nss = x.n_rows, k = x.n_cols;
480
481 arma::mat ub(nss, k);
482 arma::mat lb(nss, k);
483 arma::umat utest;
484 arma::umat ltest;
485 ub.fill(arma::datum::inf);
486 lb.fill(-arma::datum::inf);
487 arma::mat xp = arma::zeros<arma::mat>(nss, k);
488
489 arma::mat junk = Rcpp::rnorm(nss * k, 0, 1);
490 junk.reshape(nss, k);
491 junk = junk * ci;
492
493 int nb = 0, bdvar;
494 for (int j = 0; j < bounds.n_rows; j++) {
495 bdvar = (int) bounds(j,0) - 1;
496 if (arma::accu(mss == bdvar)) {
497 nb++;
498 lb.col(bdvar) = arma::ones<arma::colvec>(nss) * bounds(j,1);
499 ub.col(bdvar) = arma::ones<arma::colvec>(nss) * bounds(j,2);
500 }
501 }
502
503
504 if (nb == 0) {
505 return x + imps + junk;
506 }
507
508 //Rcpp::Rcout << ub << std::endl;
509 int samp = 0;
510 arma::colvec done = arma::zeros<arma::colvec>(nss);
511 arma::colvec left = arma::ones<arma::colvec>(nss);
512 arma::uvec finished;
513 while ((arma::accu(left) > 0) & (samp < maxsample)) {
514 samp++;
515 utest = (imps + junk) > ub;
516 ltest = (imps + junk) < lb;
517
518 done += left % (arma::sum(utest + ltest, 1) == 0);
519 finished = arma::find(left % (arma::sum(utest + ltest, 1) == 0));
520 left -= left % (arma::sum(utest + ltest, 1) == 0);
521
522
523 ub.rows(finished).fill(arma::datum::inf);
524 lb.rows(finished).fill(-arma::datum::inf);
525 xp.rows(finished) = x.rows(finished) + imps.rows(finished) + junk.rows(finished);
526
527 junk = Rcpp::rnorm(nss * k, 0, 1);
528 junk.reshape(nss, k);
529 junk = junk * ci;
530
531 }
532
533 if (arma::accu(left) > 0) {
534 xp.rows(arma::find(left)) = x.rows(arma::find(left)) + imps.rows(arma::find(left)) + junk.rows(arma::find(left));
535 utest = (imps + junk) > ub;
536 ltest = (imps + junk) < lb;
537 arma::uvec ufails = arma::find(utest);
538 arma::uvec lfails = arma::find(ltest);
539 xp.elem(ufails) = ub.elem(ufails);
540 xp.elem(lfails) = lb.elem(lfails);
541 }
542
543 return xp;
544
545 }
546