1 //////////////////////////////////////////////////////////////////////////
2 // cHDPHSMMnegbin.cc is C++ code to estimate a HDP-HSMM with Negative
3 // Binomial emission distribution (Johnson & Willsky, 2013)
4 //
5 // Matthew Blackwell
6 // Department of Government
7 // Harvard University
8 // mblackwell@gov.harvard.edu
9 
10 // Written 05/19/2017
11 //////////////////////////////////////////////////////////////////////////
12 
13 #ifndef CHDPHSMMNEGBIN_CC
14 #define CHDPHSMMNEGBIN_CC
15 
16 #include<vector>
17 #include<algorithm>
18 
19 #include "MCMCrng.h"
20 #include "MCMCfcds.h"
21 #include "matrix.h"
22 #include "distributions.h"
23 #include "stat.h"
24 #include "la.h"
25 #include "ide.h"
26 #include "smath.h"
27 #include "rng.h"
28 #include "mersenne.h"
29 #include "lecuyer.h"
30 
31 #include "MCMCnbutil.h"
32 
33 #include <R.h>
34 #include <R_ext/Utils.h>
35 
36 using namespace std;
37 using namespace scythe;
38 
39 
40 
41 struct hsmm_state {
42   Matrix<> s;
43   Matrix<> s_norep;
44   Matrix<> ps;
45   Matrix<> durs;
46   Matrix<> trans;
47   Matrix<> nstate;
48 };
49 
50 // define cumsumc
51 template <matrix_order RO, matrix_style RS, typename T,
52           matrix_order PO, matrix_style PS>
53 Matrix<T,RO,RS>
cumsumc(const Matrix<T,PO,PS> & A)54 cumsumc (const Matrix<T,PO,PS>& A)
55 {
56   Matrix<T,RO,RS> res (A.rows(), A.cols(), false);
57 
58   for (unsigned int j = 0; j < A.cols(); ++j) {
59     res(0,j) = A(0,j);
60     for (unsigned int i = 1; i < A.rows(); ++i) {
61       res(i,j) = A(i,j) + res(i-1,j);
62     }
63   }
64 
65 
66   return res;
67 }
68 
69 template <typename T, matrix_order O, matrix_style S>
70 Matrix<T,O,Concrete>
cumsumc(const Matrix<T,O,S> & A)71 cumsumc (const Matrix<T,O,S>& A)
72 {
73   return cumsumc<O,Concrete>(A);
74 }
75 
76 
77 
78 template <typename RNGTYPE>
negbin_hdphsmm_reg_state_sampler(rng<RNGTYPE> & stream,const int ns,const Matrix<> & Y,const Matrix<> & X,const Matrix<> & beta,const Matrix<> & P,const Matrix<> & nu,const Matrix<> & rho,const Matrix<> & omega,const double r)79 hsmm_state negbin_hdphsmm_reg_state_sampler(rng<RNGTYPE>& stream,
80                                             const int ns,
81                                             const Matrix<>& Y,
82                                             const Matrix<>& X,
83                                             const Matrix<>& beta,
84                                             const Matrix<>& P,
85                                             const Matrix<>& nu,
86                                             const Matrix<>& rho,
87                                             const Matrix<>& omega,
88                                             const double r){
89 
90   const int n = Y.rows();
91   Matrix<int> trans(ns, ns);
92   Matrix<int> nstate(ns, 1);
93   Matrix<> B(n, ns);
94   Matrix<> Bstar(n, ns);
95   Matrix<> pD(n, ns);
96   Matrix<> pDs(n, ns);
97   Matrix<> pr1 = ones(ns, 1);
98   Matrix<> py(n, ns);
99   Matrix<> cy(n, ns);
100   Matrix<> pstyt1(ns, 1);
101 
102   // no self-transitions
103   Matrix<> Pbar = P - P % eye(ns);
104   Pbar = log(Pbar) - log(Pbar * ones(ns, ns));
105   for (int t = 0; t < n; t++) {
106     int yt = (int) Y[t];
107     Matrix<> lambda = exp(X(t,_) * ::t(beta));
108     for (int j = 0; j< ns; ++j){
109       pD(t, j) = log(dnbinom(t+1, r, omega[j]));
110       pDs(t, j) = log(1 - pnbinom(t+1, r, omega[j]));
111       py(t, j) += lngammafn(rho[j] + yt) - lngammafn(rho[j]) - lngammafn(yt +1);
112       py(t, j) += rho[j] * log(rho[j]) + yt*log(lambda[j])- (rho[j] + yt) * log(rho[j] + lambda[j]);
113     }
114   }
115   B(n-1,_) = 0.0;
116   for (int t = (n-1); t >= 0; --t){
117     Matrix<> unnorm_pstyt(ns, 1);
118     Matrix<> pstyt(ns, 1);
119     Matrix<> censored(1,ns);
120     Matrix<> result((n-t),ns);
121     Matrix<> Bhold(1,ns);
122     Matrix<> cy = cumsumc(py(t, 0, n-1, ns-1));
123     result = B(t,0,n-1,ns-1) + cy + pD(0, 0, (n-t-1), ns-1);
124     Matrix<> maxes = maxc(result);
125     for (int j = 0; j < ns; j++) {
126       result(_,j) -= maxes[j];
127     }
128     Bstar(t,_) = log(sumc(exp(result))) + maxes;
129     censored = pDs(n-t-1,_) + cy(cy.rows()-1, _);
130     for (int j = 0; j < ns; j++) {
131       maxes[j] = std::max(Bstar(t,j), censored[j]);
132       Bstar(t,j) -= maxes[j];
133       censored[j] -= maxes[j];
134     }
135     Bstar(t,_) = log(exp(Bstar(t,_)) + exp(censored)) + maxes;
136     if (t > 0) {
137       for (int j = 0; j < ns; j++) {
138         Bhold = Pbar(j,_) + Bstar(t,_);
139         B(t-1,j) = log(sum(exp(Bhold - max(Bhold)))) + max(Bhold);
140       }
141     }
142   }
143 
144   int st;
145   Matrix<int> s(n, 1);
146   Matrix<bool> ch(n,1,true,false);
147   Matrix<int> durs(n,1);
148   Matrix<> pstyn(ns, 1);
149   Matrix<> Pst_1(ns,1);
150   Matrix<> unnorm_pstyn(ns,1);
151   Matrix<> cump(ns,1);
152   int t = 0;
153   int dur = 0;
154   double durprob, this_pd, prior_pd;
155   while (t < n) {
156     ch[t] = true;
157     if (t == 0) {
158       unnorm_pstyn = Bstar(t,_);
159     } else {
160       st = s(t-1);
161       unnorm_pstyn = Bstar(t,_) + Pbar(st-1,_);
162     }
163     pstyn = unnorm_pstyn - log(sum(exp(unnorm_pstyn - max(unnorm_pstyn)))) - max(unnorm_pstyn);
164     pstyn = exp(pstyn);
165     cump(0) = pstyn(0);
166     for (int j = 1; j < ns; j++) {
167       cump(j) = cump(j-1) + pstyn(j);
168     }
169     double UU = stream.runif();
170     s(t) = 1;
171     for (int j = 1; j < ns; j++) {
172       if (UU >= cump(j-1) && UU < cump(j)) {
173         s(t) = j+1;
174       }
175     }
176     if (t > 0) trans(st-1, s(t) - 1) += 1;
177 
178     durprob = stream.runif();
179     for (dur = 0; durprob > 0.0 && t+dur < n; dur++) {
180       prior_pd = pD(dur, s(t) - 1);
181       this_pd = prior_pd + sum(py(t, s(t) - 1, t+dur, s(t)-1));
182       this_pd += B(t+dur, s(t) - 1) - Bstar(t, s(t) - 1);
183       durprob -= exp(this_pd);
184       s(t+dur) = s(t);
185     }
186     durs[t] = dur;
187     nstate[s(t) - 1] += dur;
188 
189     t += dur;
190   }
191   Matrix<> s_norep = selif(s, durs > 0);
192   Matrix<> durs_norep = selif(durs, durs > 0);
193   hsmm_state result;
194   result.s = s;
195   result.s_norep = s_norep;
196   result.durs = durs_norep;
197   result.trans = trans;
198   result.nstate = nstate;
199   return result;
200 }
201 
202 
203 
204 
205 ////////////////////////////////////////////
206 // HDPHSMMnegbinRegChange implementation.
207 ////////////////////////////////////////////
208 template <typename RNGTYPE>
HDPHSMMnegbinReg_impl(rng<RNGTYPE> & stream,double * betaout,double * Pout,double * omegaout,double * sout,double * nuout,double * rhoout,double * tau1out,double * tau2out,int * comp1out,int * comp2out,double * sr1out,double * sr2out,double * mr1out,double * mr2out,double * gammaout,double * alphaout,double * rhosizes,const double * Ydata,const int * Yrow,const int * Ycol,const double * Xdata,const int * Xrow,const int * Xcol,const int * K,const int * burnin,const int * mcmc,const int * thin,const int * verbose,const double * betastart,const double * Pstart,const double * nustart,const double * rhostart,const double * tau1start,const double * tau2start,const double * component1start,const double * alphastart,const double * gammastart,const double * omegastart,const double * a_alpha,const double * b_alpha,const double * a_gamma,const double * b_gamma,const double * a_omega,const double * b_omega,const double * e,const double * f,const double * g,const double * r,const double * rhostepdata,const double * b0data,const double * B0data)209 void HDPHSMMnegbinReg_impl(rng<RNGTYPE>& stream,
210                            double *betaout,
211                            double *Pout,
212                            double *omegaout,
213                            double *sout,
214                            double *nuout,
215                            double *rhoout,
216                            double *tau1out,
217                            double *tau2out,
218                            int *comp1out,
219                            int *comp2out,
220                            double *sr1out,
221                            double *sr2out,
222                            double *mr1out,
223                            double *mr2out,
224                            double *gammaout,
225                            double *alphaout,
226                            double *rhosizes,
227                            const double *Ydata,
228                            const int *Yrow,
229                            const int *Ycol,
230                            const double *Xdata,
231                            const int *Xrow,
232                            const int *Xcol,
233                            const int *K,
234                            const int *burnin,
235                            const int *mcmc,
236                            const int *thin,
237                            const int *verbose,
238                            const double *betastart,
239                            const double *Pstart,
240                            const double *nustart,
241                            const double *rhostart,
242                            const double *tau1start,
243                            const double *tau2start,
244                            const double *component1start,
245                            const double *alphastart,
246                            const double *gammastart,
247                            const double *omegastart,
248                            const double *a_alpha,
249                            const double *b_alpha,
250                            const double *a_gamma,
251                            const double *b_gamma,
252                            const double *a_omega,
253                            const double *b_omega,
254                            const double *e,
255                            const double *f,
256                            const double *g,
257                            const double *r,
258                            const double *rhostepdata,
259                            const double *b0data,
260                            const double *B0data)
261 {
262   const Matrix <> Y(*Yrow, *Ycol, Ydata);
263   const Matrix <> X(*Xrow, *Xcol, Xdata);
264   const int tot_iter = *burnin + *mcmc;
265   const int nstore = *mcmc / *thin;
266   const int n = Y.rows();
267   const int k = X.cols();
268   const int ns = *K;
269   const int max_comp = 10;
270   const Matrix <> b0(k, 1, b0data);
271   const Matrix <> B0(k, k, B0data);
272   const Matrix <> B0inv = invpd(B0);
273   Matrix<> wr1(max_comp, 1);
274   Matrix<> mr1(max_comp, 1);
275   Matrix<> sr1(max_comp, 1);
276   Matrix<> wr2(n, max_comp);
277   Matrix<> mr2(n, max_comp);
278   Matrix<> sr2(n, max_comp);
279   Matrix<> nr2(n, 1);
280 
281   Matrix <> nu(n, 1, nustart);
282   Matrix <> rho(ns, 1, rhostart);
283   Matrix <> rho_slice(ns, 1, true, 0.0);
284   Matrix <> step_out(ns, 1, rhostepdata);
285   Matrix <> beta(ns, k, betastart);
286   Matrix <> tau1(n, 1, tau1start);
287   Matrix <> tau2(n, 1, tau2start);
288   Matrix <> component1(n, 1, component1start);
289   Matrix <> component2(n, 1);
290   Matrix <> P(ns, ns, Pstart);
291   Matrix <> omega(ns, 1, omegastart);
292   double gamma = *gammastart;
293   double alpha = *alphastart;
294   Matrix <> gamma_prime(ns, 1, true, gamma/ns);
295 
296   Matrix<> beta_store(nstore, ns*k);
297   Matrix<> P_store(nstore, ns*ns);
298   Matrix<> s_store(nstore, n);
299   Matrix<int> nstate(ns, 1);
300   Matrix<int> component1_store(nstore, n);
301   Matrix<int> component2_store(nstore, n);
302   Matrix<> tau1_store(nstore, n);
303   Matrix<> tau2_store(nstore, n);
304   Matrix<> sr1_store(nstore, n);
305   Matrix<> sr2_store(nstore, n);
306   Matrix<> mr1_store(nstore, n);
307   Matrix<> mr2_store(nstore, n);
308   Matrix<> nu_store(nstore, n);
309   Matrix<> rho_store(nstore, ns);
310   Matrix<> omega_store(nstore, ns);
311   Matrix<> gamma_store(nstore, 1);
312   Matrix<> alpha_store(nstore, 1);
313 
314   hsmm_state Sout;
315   Matrix<> s(n,1);
316   Matrix<> (n,ns);
317   Matrix<int> trans_counts(ns,ns);
318   Matrix<> V(ns,ns);
319   Matrix<> rhostep;
320 
321   init_aux(stream, Y, wr1, mr1, sr1, wr2, mr2, sr2, nr2, component2);
322   int nplus = 0;
323   for (int t = 0; t < n; t++) {
324     int yt = (int) Y[t];
325     if (yt > 0) {
326       nplus = nplus + 1;
327     }
328   }
329 
330   Matrix<> Xplus(nplus, k);
331   int xp_count = 0;
332   for (int t = 0; t < nplus; t++) {
333     int yt = (int) Y[t];
334     Matrix<> xt = X(t, _);
335     if (yt > 0) {
336       Xplus(xp_count, _) = xt;
337       xp_count++;
338     }
339   }
340 
341   //MCMC loop
342   int count = 0;
343   int debug = 0;
344   for (int iter = 0; iter < tot_iter; ++iter){
345     if (debug > 0) Rprintf("\niter: %i\n-----\n", iter);
346     if (debug > 0) Rprintf("before s \n");
347     //////////////////////
348     // 1. Sample s
349     //////////////////////
350     Sout = negbin_hdphsmm_reg_state_sampler(stream, ns, Y, X, beta, P, nu, rho, omega, *r);
351     s = Sout.s;
352     Matrix<> s_norep = Sout.s_norep;
353     Matrix<> durs = Sout.durs;
354     nstate = Sout.nstate;
355     trans_counts = Sout.trans;
356 
357     if (debug > 0) Rprintf("before rho \n");
358     //////////////////////
359     // 5. Sample rho
360     //////////////////////
361     // We need to evaluate the density in the first iteration for the slice sampler
362     for (int j = 0; j < ns; j++) {
363       if (nstate[j] > 0) {
364         Matrix<> yj = selif(Y, s == (j + 1));
365         Matrix<> xj = selif(X, s == (j + 1));
366         Matrix<> lam = exp(xj * ::t(beta(j,_)));
367         rhostep = rho_slice_sampler(stream, yj, lam, rho[j], step_out[j], *g, *e, *f);
368       } else {
369         rhostep = rho_prior_sampler(stream, rho[j], step_out[j], *g, *e, *f);
370       }
371       rho[j] = rhostep[0];
372       if (iter > 10) {
373         step_out[j] = (1/((double) iter)) * rhostep[2] + ((((double) iter)-1)/((double) iter)) * step_out[j];
374       }
375     }
376 
377     if (debug > 0) Rprintf("before nu \n");
378     //////////////////////
379     // 6. Sample nu
380     //////////////////////
381     Matrix<> lambda(n, 1);
382     for (int t = 0; t < n; t++) {
383       int st = (int) s[t];
384       int yt = (int) Y[t];
385       Matrix<> xt = X(t, _);
386       Matrix<> mu_t = exp(xt* ::t(beta(st-1,_)));
387       lambda[t] = mu_t[0];
388       nu[t] = stream.rgamma(rho[st-1] + yt, rho[st-1] + lambda[t]);
389     }
390 
391     if (debug > 0) Rprintf("before tau \n");
392     //////////////////////
393     // 4. Sample tau
394     //////////////////////
395     Matrix<> TAUout;
396     for (int t = 0; t < n; t++) {
397       int yt = (int) Y[t];
398       double mut = exp(log(nu[t]) + log(lambda[t]));
399 
400       TAUout = tau_comp_sampler(stream, yt, mut, wr1, mr1, sr1,
401                                 wr2(t,_), mr2(t,_), sr2(t,_), nr2[t]);
402       tau1[t] = TAUout[0];
403       tau2[t] = TAUout[1];
404       component1[t] = TAUout[2];
405       component2[t] = TAUout[3];
406     }
407 
408 
409     if (debug > 0) Rprintf("before beta \n");
410     //////////////////////
411     // 2. Sample beta
412     //////////////////////
413     Matrix<> y_tilde(n,1);
414     Matrix<> Sigma_inv_sum(n, 1);
415     Matrix<> yp_tilde(n,1);
416     Matrix<> Sigma_plus_inv_sum(n, 1);
417     Matrix<> sr1_hold(n,1);
418     Matrix<> sr2_hold(n,1);
419     Matrix<> mr1_hold(n,1);
420     Matrix<> mr2_hold(n,1);
421     for (int t = 0; t<n ; ++t) {
422       int yt = (int) Y[t];
423       int comp1 = (int) component1[t];
424         if (yt > 0) {
425           int comp2 = (int) component2[t];
426           sr2_hold[t] = sr2(t, comp2 - 1);
427           mr2_hold[t] = mr2(t, comp2 - 1);
428           Sigma_plus_inv_sum[t] = 1/sqrt(sr2(t, comp2 - 1));
429           yp_tilde[t] = (-log(tau2[t]) - log(nu[t]) - mr2(t, comp2 - 1))/sqrt(sr2(t, comp2-1));
430         }
431         sr1_hold[t] = sr1[comp1 - 1];
432         mr1_hold[t] = mr1[comp1 - 1];
433         Sigma_inv_sum[t] = 1/sqrt(sr1[comp1 - 1]);
434         y_tilde[t] = (-log(tau1[t]) - log(nu[t])  - mr1[comp1 - 1])/sqrt(sr1[comp1 - 1]);
435     }
436 
437     int beta_count = 0;
438     int pcount = 0;
439     Matrix<int> pnstate(ns, 1);
440 
441     for (int j = 0; j <ns ; ++j){
442       for (int i = 0; i<n; ++i){
443 	if (s[i] == (j+1)) {
444           int yt = (int) Y[i];
445           if (yt > 0) {
446             pnstate[j] = pnstate[j] + 1;
447           }
448 	}
449       }
450       beta_count = beta_count + nstate[j];
451       pcount = pcount + pnstate[j];
452       if (nstate[j] == 0) {
453         if (k == 1) {
454           beta(j, _) = stream.rnorm(b0[0], sqrt(1/B0[0]));
455         } else {
456           beta(j,_) = stream.rmvnorm(b0, invpd(B0));
457         }
458       } else {
459         int tot_rows = nstate[j] + pnstate[j];
460         Matrix<> yjp(tot_rows, 1);
461         Matrix<> Xjp(tot_rows, k);
462         Matrix<> wip(tot_rows, 1);
463 
464         yjp(0, 0, nstate[j] - 1, 0) = selif(y_tilde, s == (j + 1));
465         Xjp(0, 0, nstate[j] - 1, k-1) = selif(X, s == (j + 1));
466         wip(0, 0, nstate[j] - 1, 0) = selif(Sigma_inv_sum, s == (j + 1));
467         if (pnstate[j] > 0) {
468           yjp(nstate[j], 0, tot_rows - 1, 0) = selif(yp_tilde, (s == (j+1)) & (Y > 0));
469           Xjp(nstate[j], 0, tot_rows - 1, k - 1) = selif(X, (s == (j + 1)) & (Y > 0));
470           wip(nstate[j], 0, tot_rows - 1, 0) = selif(Sigma_plus_inv_sum, (s == (j + 1)) & (Y > 0));
471         }
472 
473         Matrix<> Xwj(Xjp.rows(), k);
474         for (unsigned int h = 0; h<Xjp.rows(); ++h){
475           Xwj(h, _) = Xjp(h,_)*wip[h];
476         }
477 
478         Matrix<> Bn = invpd(B0 + ::t(Xwj)*Xwj);
479         Matrix<> bn = Bn*gaxpy(B0, b0, ::t(Xwj)*yjp);
480 
481         if (k == 1) {
482           beta(j, _) = stream.rnorm(bn[0], sqrt(Bn[0]));
483         } else {
484           beta(j,_) = stream.rmvnorm(bn, Bn);
485         }
486       }
487     }
488 
489 
490 
491 
492     if (debug > 0) Rprintf("before DA for P \n");
493     // Augmented variables to make sampling P easier
494     Matrix<int> froms = sumc(::t(Sout.trans));
495     for (int j = 0; j < ns; j++) {
496       if (debug > 1) Rprintf("j = %i, froms = %i P = %10.5f tc = %i\n", j, froms[j], P(j,j), trans_counts(j,j));
497       int draw;
498       for (int nj = 0; nj < froms[j]; nj++) {
499         draw = std::floor(stream.rexp(-log(P(j,j))));
500         // sometimes we get wonky draws here
501         if (draw > 0) trans_counts(j,j) += draw;
502       }
503       if (debug > 1) Rprintf("DA tc[j,j] = %i\n", trans_counts(j,j));
504     }
505 
506     //////////////////////
507     // 3. Sample dish counts
508     //////////////////////
509     if (debug > 0) Rprintf("before dish counts \n");
510     Matrix<int> rest_dishes(ns, ns);
511     int n_jk;
512     int m_jk;
513     double m_num;
514     for (int j = 0; j < ns; ++j) {
515       for (int k = 0; k < ns; ++k) {
516         n_jk = trans_counts(j,k);
517         if (n_jk == 0) {
518           rest_dishes(j,k) = 0;
519         } else {
520           m_jk = 0;
521           for (int h = 0; h < n_jk; ++h) {
522             m_num = alpha * gamma_prime(k);
523             if (stream.runif() < (m_num/(m_num + h))) {
524               m_jk++;
525             }
526           }
527           rest_dishes(j,k) = m_jk;
528         }
529       }
530     }
531 
532     //////////////////////
533     // 3. Sample P
534     //////////////////////
535     if (debug > 0) Rprintf("before concentration params \n");
536     // Gamma prime is beta in all HDP-H(S)MM notation
537     Matrix<int> Nkdot = ::t(sumc(::t(trans_counts)));
538     Matrix<int> Mkdot = ::t(sumc(::t(rest_dishes)));
539     Matrix<int> Mdotk = ::t(sumc(rest_dishes));
540     double alpha0 = alpha;
541     double gamma0 = gamma;
542     int Mbar_tot = sum(rest_dishes);
543 
544     if (debug > 1) {
545       for (int j = 0; j < ns; j++) {
546         Rprintf("Nkdot[%i] = %i  Mkdot[%i] = %i  Mdotk[%i] = %i  g_prime[%i] = %10.5f\n", j, Nkdot[j], j, Mkdot[j], j, Mdotk[j], j, gamma_prime[j]);
547       }
548     }
549     // only sample these if the parameters make sense
550     if (debug > 0) Rprintf("before alpha \n");
551     if (*a_alpha > 0 && *b_alpha > 0) {
552       alpha = sample_conparam(stream, alpha0, Nkdot, sum(Mkdot), *a_alpha, *b_alpha, 50);
553     }
554     if (debug > 0) Rprintf("alpha = %10.5f\nbefore gamma\n", alpha);
555     if (*a_gamma > 0 && *b_gamma > 0) {
556       gamma = sample_conparam(stream, gamma0, Mbar_tot, ns, *a_gamma, *b_gamma, 50);
557     }
558     if (debug > 0) Rprintf("gamma = %10.5f\nbefore gamma prime\n", gamma);
559     Matrix<double> gamma_prime_dir(ns,1);
560     for (int j = 0; j < ns; j++) {
561       gamma_prime_dir[j] = gamma/ns + Mdotk[j];
562     }
563     if (min(gamma_prime_dir + 1) <= 1) gamma_prime_dir += std::numeric_limits<double>::epsilon();
564 
565     gamma_prime = stream.rdirich(gamma_prime_dir);
566     while (std::isnan(gamma_prime[0])) {
567       gamma_prime = stream.rdirich(gamma_prime_dir);
568     }
569 
570     if (debug > 0) Rprintf("before P\n", alpha);
571     for (int j = 0; j < ns; ++j) {
572       Matrix<double> p_dirich_params(ns,1);
573       for (int i = 0; i < ns; ++i) {
574         p_dirich_params(i) = alpha * gamma_prime(i) + trans_counts(j,i);
575       }
576       if (min(p_dirich_params + 1) <= 1) p_dirich_params += std::numeric_limits<double>::epsilon();
577       P(j,_) = ::t(stream.rdirich(p_dirich_params));
578 
579       // the rdirichlet function has numerical stability problems so sometimes we have to redraw
580       while (std::isnan(P(j,0))) {
581         if (debug > 0) Rprintf("\nRedrawing P due to numerical problems (iter %i)...\n", iter);
582         P(j,_) = ::t(stream.rdirich(p_dirich_params));
583       }
584       // again, due to numerical issues we can get numerically close to 1, just use the mean instead
585       if (P(j,j) == 1.0) {
586         P(j,_) = p_dirich_params/sum(p_dirich_params);
587       }
588 
589     }
590 
591     if (debug > 0) Rprintf("before durs \n");
592     // //////////////////////
593     // // 3. Sample duration parameters
594     // //////////////////////
595     int last_state = s(n-1) - 1;
596     int cen_dur = durs[durs.rows()-1];
597 
598     for (int j = 0; j < ns; j++) {
599       int times_visited = sum(s_norep == (j + 1));
600       int uncen_total_dur = nstate[j];
601 
602       // augment the last censored duration with acceptance sampling
603       if (j == last_state) {
604         int uncen_dur = 0;
605         double tail_prob;
606         tail_prob = log(1.0 - pnbinom(cen_dur, *r, omega[j]));
607         if (exp(tail_prob) > 0.1) {
608            while (uncen_dur < cen_dur) {
609             uncen_dur = stream.rnbinom(*r, omega[j]);
610             R_CheckUserInterrupt();
611           }
612         } else {
613           if (debug > 0) Rprintf("direct sampling \n");
614           double u = stream.runif();
615           uncen_dur = cen_dur;
616           while (u > 0) {
617             u -= exp(log(dnbinom(uncen_dur, *r, omega[j])) - tail_prob);
618             uncen_dur++;
619           }
620         }
621         uncen_total_dur += uncen_dur - cen_dur;
622       }
623       omega[j] = stream.rbeta(*a_omega + *r * times_visited, *b_omega + uncen_total_dur);
624     }
625 
626 
627 
628     if (iter >= *burnin && ((iter % *thin)==0)){
629       Matrix<> tbeta = ::t(beta);
630       for (int i=0; i<(ns*k); ++i){
631 	beta_store(count,i) = tbeta[i];
632       }
633       for (int j=0; j<ns*ns; ++j){
634 	P_store(count,j)= P[j];
635       }
636       s_store(count,_) = s(_, 0);
637       tau1_store(count,_) = tau1(_, 0);
638       tau2_store(count,_) = tau2(_, 0);
639       component1_store(count,_) = component1(_, 0);
640       component2_store(count,_) = TAUout(_, 3);
641       sr1_store(count,_) = sr1_hold;
642       sr2_store(count,_) = sr2_hold;
643       mr1_store(count,_) = mr1_hold;
644       mr2_store(count,_) = mr2_hold;
645       nu_store(count,_) = nu(_, 0);
646       rho_store(count, _) = rho(_, 0);
647       omega_store(count, _) = omega(_, 0);
648       gamma_store(count) = gamma;
649       alpha_store(count) = alpha;
650       ++count;
651     }
652 
653     if(*verbose > 0 && iter % *verbose == 0){
654       Rprintf("\n\n HDPHSMMnegbinChange iteration %i of %i \n\n", (iter+1), tot_iter);
655       for (int j = 0;j<ns; ++j){
656 	Rprintf("The number of observations in state %i is %10.5f\n", j+1, static_cast<double>(nstate[j]));
657       }
658       for (int i = 0; i<ns; ++i){
659         if (nstate[i] > 0) {
660           Rprintf("rho in state %i is %10.5f\n", i+1, rho[i]);
661           for (int j = 0; j<k; ++j){
662             Rprintf("beta(%i) in state %i is %10.5f\n", j+1, i+1, beta(i, j));
663           }
664         }
665       }
666     }
667 
668   }// end MCMC loop
669 
670 
671 
672   R_CheckUserInterrupt();
673 
674   for (int i = 0; i<(nstore*ns*k); ++i){
675     betaout[i] = beta_store[i];
676   }
677   for (int i = 0; i<(nstore*ns*ns); ++i){
678     Pout[i] = P_store[i];
679   }
680   for (int i = 0; i<(nstore*n); ++i){
681     sout[i] = s_store[i];
682     nuout[i] = nu_store[i];
683     tau1out[i] = tau1_store[i];
684     tau2out[i] = tau2_store[i];
685     comp1out[i] = component1_store[i];
686     comp2out[i] = component2_store[i];
687     sr1out[i] = sr1_store[i];
688     sr2out[i] = sr2_store[i];
689     mr1out[i] = mr1_store[i];
690     mr2out[i] = mr2_store[i];
691 
692   }
693   for (int i = 0; i<(nstore*ns); ++i){
694     rhoout[i] = rho_store[i];
695     omegaout[i] = omega_store[i];
696   }
697   for (int i = 0; i<ns; ++i){
698     rhosizes[i] = step_out[i];
699   }
700   for (int i = 0; i < nstore; ++i) {
701     gammaout[i] = gamma_store(i);
702     alphaout[i] = alpha_store(i);
703   }
704 }
705 
706 extern "C" {
cHDPHSMMnegbin(double * betaout,double * Pout,double * omegaout,double * sout,double * nuout,double * rhoout,double * tau1out,double * tau2out,int * comp1out,int * comp2out,double * sr1out,double * sr2out,double * mr1out,double * mr2out,double * gammaout,double * alphaout,double * rhosizes,const double * Ydata,const int * Yrow,const int * Ycol,const double * Xdata,const int * Xrow,const int * Xcol,const int * K,const int * burnin,const int * mcmc,const int * thin,const int * verbose,const double * betastart,const double * Pstart,const double * nustart,const double * rhostart,const double * tau1start,const double * tau2start,const double * component1start,const double * alphastart,const double * gammastart,const double * omegastart,const double * a_alpha,const double * b_alpha,const double * a_gamma,const double * b_gamma,const double * a_omega,const double * b_omega,const double * e,const double * f,const double * g,const double * r,const double * rhostepdata,const int * uselecuyer,const int * seedarray,const int * lecuyerstream,const double * b0data,const double * B0data)707   void cHDPHSMMnegbin(double *betaout,
708                       double *Pout,
709                       double *omegaout,
710                       double *sout,
711                       double *nuout,
712                       double *rhoout,
713                       double *tau1out,
714                       double *tau2out,
715                       int *comp1out,
716                       int *comp2out,
717                       double *sr1out,
718                       double *sr2out,
719                       double *mr1out,
720                       double *mr2out,
721                       double *gammaout,
722                       double *alphaout,
723                       double *rhosizes,
724                       const double *Ydata,
725                       const int *Yrow,
726                       const int *Ycol,
727                       const double *Xdata,
728                       const int *Xrow,
729                       const int *Xcol,
730                       const int *K,
731                       const int *burnin,
732                       const int *mcmc,
733                       const int *thin,
734                       const int *verbose,
735                       const double *betastart,
736                       const double *Pstart,
737                       const double *nustart,
738                       const double *rhostart,
739                       const double *tau1start,
740                       const double *tau2start,
741                       const double *component1start,
742                       const double *alphastart,
743                       const double *gammastart,
744                       const double *omegastart,
745                       const double *a_alpha,
746                       const double *b_alpha,
747                       const double *a_gamma,
748                       const double *b_gamma,
749                       const double *a_omega,
750                       const double *b_omega,
751                       const double *e,
752                       const double *f,
753                       const double *g,
754                       const double *r,
755                       const double *rhostepdata,
756                       const int* uselecuyer,
757                       const int* seedarray,
758                       const int* lecuyerstream,
759                       const double *b0data,
760                       const double *B0data) {
761 
762     MCMCPACK_PASSRNG2MODEL(HDPHSMMnegbinReg_impl,
763                            betaout, Pout, omegaout, sout, nuout, rhoout,
764                            tau1out, tau2out, comp1out, comp2out,
765                            sr1out, sr2out, mr1out, mr2out,
766                            gammaout, alphaout, rhosizes,
767                            Ydata, Yrow, Ycol,
768                            Xdata, Xrow, Xcol,
769                            K, burnin, mcmc, thin, verbose,
770                            betastart, Pstart, nustart, rhostart,
771                            tau1start, tau2start, component1start,
772                            alphastart, gammastart, omegastart,
773                            a_alpha, b_alpha, a_gamma, b_gamma,
774                            a_omega, b_omega, e, f, g, r, rhostepdata,
775                            b0data, B0data);
776   }//end MCMC
777 } // end extern "C"
778 
779 
780 #endif
781