1 /* Mixture Dirichlet densities
2  *
3  * Contents:
4  *   1. <ESL_MIXDCHLET> object
5  *   2. Likelihoods, posteriors, inference
6  *   3. Maximum likelihood fitting to count data
7  *   4. Reading/writing mixture Dirichlet files
8  *   5. Debugging and development tools
9  *   6. Unit tests
10  *   7. Test driver
11  *   8. Example
12  *
13  * See also:
14  *   esl_dirichlet : simple Dirichlet densities
15  *   esl-mixdchlet miniapp : fitting and more
16  */
17 #include "esl_config.h"
18 
19 #include <stdlib.h>
20 #include <stdio.h>
21 
22 #include "easel.h"
23 #include "esl_dirichlet.h"
24 #include "esl_fileparser.h"
25 #include "esl_graph.h"
26 #include "esl_matrixops.h"
27 #include "esl_minimizer.h"
28 #include "esl_random.h"
29 #include "esl_stats.h"
30 #include "esl_vectorops.h"
31 #include "esl_mixdchlet.h"
32 
33 
34 /*****************************************************************
35  *# 1. <ESL_MIXDCHLET> object
36  *****************************************************************/
37 
38 /* Function:  esl_mixdchlet_Create()
39  *
40  * Purpose:   Create a new mixture Dirichlet prior with <Q> components,
41  *            each with <K> parameters.
42  *
43  * Returns:   ptr to new <ESL_MIXDCHLET> on success.
44  *
45  * Throws:    NULL on allocation failure.
46  */
47 ESL_MIXDCHLET *
esl_mixdchlet_Create(int Q,int K)48 esl_mixdchlet_Create(int Q, int K)
49 {
50   ESL_MIXDCHLET *dchl = NULL;
51   int            status;
52 
53   ESL_DASSERT1( (Q > 0) );
54   ESL_DASSERT1( (K > 0) );
55 
56   ESL_ALLOC(dchl, sizeof(ESL_MIXDCHLET));
57   dchl->q      = NULL;
58   dchl->alpha  = NULL;
59   dchl->postq  = NULL;
60 
61   ESL_ALLOC(dchl->q,      sizeof(double)   * Q);
62   ESL_ALLOC(dchl->postq,  sizeof(double)   * Q);
63   if ((dchl->alpha = esl_mat_DCreate(Q,K)) == NULL) goto ERROR;
64 
65   dchl->Q = Q;
66   dchl->K = K;
67   return dchl;
68 
69  ERROR:
70   esl_mixdchlet_Destroy(dchl);
71   return NULL;
72 }
73 
74 
75 
76 /* Function:  esl_mixdchlet_Destroy()
77  * Synopsis:  Free a mixture Dirichlet.
78  */
79 void
esl_mixdchlet_Destroy(ESL_MIXDCHLET * dchl)80 esl_mixdchlet_Destroy(ESL_MIXDCHLET *dchl)
81 {
82   if (dchl)
83     {
84       free(dchl->q);
85       esl_mat_DDestroy(dchl->alpha);
86       free(dchl->postq);
87       free(dchl);
88     }
89 }
90 
91 
92 
93 /*****************************************************************
94  * 2. Likelihoods, posteriors, inference
95  *****************************************************************/
96 
97 
98 /* mixchlet_postq()
99  * Calculate P(q | c), the posterior probability of component q.
100  */
101 static void
mixdchlet_postq(ESL_MIXDCHLET * dchl,double * c)102 mixdchlet_postq(ESL_MIXDCHLET *dchl, double *c)
103 {
104   int k;
105   for (k = 0; k < dchl->Q; k++)
106     if (dchl->q[k] > 0.) dchl->postq[k] = log(dchl->q[k]) + esl_dirichlet_logpdf_c(c, dchl->alpha[k], dchl->K);
107     else                 dchl->postq[k] = -eslINFINITY;
108   esl_vec_DLogNorm(dchl->postq, dchl->Q);
109 }
110 
111 
112 
113 /* Function:  esl_mixdchlet_logp_c()
114  *
115  * Purpose:   Given observed count vector $c[0..K-1]$ and a mixture
116  *            Dirichlet <dchl>, calculate $\log P(c \mid \theta)$.
117  *
118  * Args:      dchl     : mixture Dirichlet
119  *            c        : count vector, [0..K-1]
120  *
121  * Returns:   $\log P(c \mid \theta)$
122  *
123  * Note:      Only because a workspace in <dchl> is used, you can't
124  *            declare <dchl> to be const.
125  */
126 double
esl_mixdchlet_logp_c(ESL_MIXDCHLET * dchl,double * c)127 esl_mixdchlet_logp_c(ESL_MIXDCHLET *dchl, double *c)
128 {
129   int k;
130   for (k = 0; k < dchl->Q; k++)
131     if (dchl->q[k] > 0.) dchl->postq[k] = log(dchl->q[k]) + esl_dirichlet_logpdf_c(c, dchl->alpha[k], dchl->K);
132     else                 dchl->postq[k] = -eslINFINITY;
133   return esl_vec_DLogSum(dchl->postq, dchl->Q);
134 }
135 
136 
137 
138 /* Function:  esl_mixdchlet_MPParameters()
139  * Synopsis:  Calculate mean posterior parameters from a count vector
140  *
141  * Purpose:   Given a mixture Dirichlet prior <dchl> and observed
142  *            countvector <c> of length <dchl->K>, calculate mean
143  *            posterior parameter estimates <p>. Caller provides the
144  *            storage <p>, allocated for at least <dchl->K> parameters.
145  *
146  * Returns:   <eslOK> on success, and <p> contains mean posterior
147  *            probability parameter estimates.
148  *
149  * Note:      Only because a workspace in <dchl> is used, you can't
150  *            declare <dchl> to be const.
151  */
152 int
esl_mixdchlet_MPParameters(ESL_MIXDCHLET * dchl,double * c,double * p)153 esl_mixdchlet_MPParameters(ESL_MIXDCHLET *dchl, double *c, double *p)
154 {
155   int    k,a;		// indices over components, residues
156   double totc;
157   double totalpha;
158 
159   /* Calculate posterior prob P(k | c) of each component k given count vector c. */
160   mixdchlet_postq(dchl, c);
161 
162   /* Compute mean posterior estimates for probability parameters */
163   totc = esl_vec_DSum(c, dchl->K);
164   esl_vec_DSet(p, dchl->K, 0.);
165   for (k = 0; k < dchl->Q; k++)
166     {
167       totalpha = esl_vec_DSum(dchl->alpha[k], dchl->K);
168       for (a = 0; a < dchl->K; a++)
169 	p[a] += dchl->postq[k] * (c[a] + dchl->alpha[k][a]) / (totc + totalpha);
170     }
171   /* should be normalized already, but for good measure: */
172   esl_vec_DNorm(p, dchl->K);
173   return eslOK;
174 }
175 
176 
177 
178 /*****************************************************************
179  * 3. Maximum likelihood fitting to count data
180  *****************************************************************/
181 /* This structure is used to shuttle the data into minimizer's generic
182  * (void *) API for all aux data
183  */
184 struct mixdchlet_data {
185   ESL_MIXDCHLET  *dchl;   /* dirichlet mixture parameters */
186   double        **c;      /* count vector array [0..N-1][0..K-1] */
187   int             N;      /* number of countvectors */
188 };
189 
190 /*****************************************************************
191  * Parameter vector packing/unpacking
192  *
193  * The Easel conjugate gradient code is a general optimizer. It takes
194  * a single parameter vector <p>, where the values are unconstrained
195  * real numbers.
196  *
197  * We're optimizing a mixture Dirichlet with two kinds of parameters.
198  * q[k] are mixture coefficients, constrained to be >= 0 and \sum_k
199  * q[k] = 1.  alpha[k][a] are the Dirichlet parameters for component
200  * k, constrained to be > 0.
201  *
202  * So we use a c.o.v. to get the coefficients and parameters in terms
203  * of unconstrained reals lambda and beta:
204  *   mixture coefficients:      q_k  = exp(lambda_k) / \sum_j exp(lambda_j)
205  *   Dirichlet parameters:   alpha_a = exp(beta_a)
206  *
207  * And we pack them all in one parameter vector, lambdas first:
208  * [0 ... Q-1] [0 ... K-1] [0 ... K-1]  ...
209  *    lambda's   beta_0      beta_1     ...
210  *
211  * The parameter vector p therefore has length Q(K+1), and is accessed as:
212  *    mixture coefficient lambda[k] is at p[k]
 *    Dirichlet param beta[k][a] is at p[Q + k*K + a].
214  */
215 static void
mixdchlet_pack_paramvector(ESL_MIXDCHLET * dchl,double * p)216 mixdchlet_pack_paramvector(ESL_MIXDCHLET *dchl, double *p)
217 {
218   int j = 0;     /* counter in packed parameter vector <p> */
219   int k,a;	 /* indices over components, residues */
220 
221   /* mixture coefficients */
222   for (k = 0; k < dchl->Q; k++)
223     p[j++] = log(dchl->q[k]);
224 
225   /* Dirichlet parameters */
226   for (k = 0; k < dchl->Q; k++)
227     for (a = 0; a < dchl->K; a++)
228       p[j++] = log(dchl->alpha[k][a]);
229 
230   ESL_DASSERT1(( j == dchl->Q *  (dchl->K + 1)) );
231 }
232 
233 /* Same as above but in reverse: given parameter vector <p>,
234  * do appropriate c.o.v. back to desired parameter space, and
235  * storing in mixdchlet <dchl>.
236  */
237 static void
mixdchlet_unpack_paramvector(double * p,ESL_MIXDCHLET * dchl)238 mixdchlet_unpack_paramvector(double *p, ESL_MIXDCHLET *dchl)
239 {
240   int j = 0;     /* counter in packed parameter vector <p> */
241   int k,a;	 /* indices over components, residues */
242 
243   /* mixture coefficients */
244   for (k = 0; k < dchl->Q; k++)
245     dchl->q[k] = exp(p[j++]);
246   esl_vec_DNorm(dchl->q, dchl->Q);
247 
248   /* Dirichlet parameters */
249   for (k = 0; k < dchl->Q; k++)
250     for (a = 0; a < dchl->K; a++)
251       dchl->alpha[k][a] = exp(p[j++]);
252 
253   ESL_DASSERT1(( j == dchl->Q *  (dchl->K + 1)) );
254 }
255 
256 /* The negative log likelihood function to be minimized by ML fitting. */
257 static double
mixdchlet_nll(double * p,int np,void * dptr)258 mixdchlet_nll(double *p, int np, void *dptr)
259 {
260   ESL_UNUSED(np);  // parameter number <np> must be an arg, dictated by conj gradient API.
261   struct mixdchlet_data *data = (struct mixdchlet_data *) dptr;
262   ESL_MIXDCHLET         *dchl = data->dchl;
263   double                 nll  = 0.;
264   int  i;
265 
266   mixdchlet_unpack_paramvector(p, dchl);
267   for (i = 0; i < data->N; i++)
268     nll -= esl_mixdchlet_logp_c(dchl, data->c[i]);
269   return nll;
270 }
271 
272 /* The gradient of the NLL w.r.t. each free parameter in p. */
273 static void
mixdchlet_gradient(double * p,int np,void * dptr,double * dp)274 mixdchlet_gradient(double *p, int np, void *dptr, double *dp)
275 {
276   struct mixdchlet_data *data = (struct mixdchlet_data *) dptr;
277   ESL_MIXDCHLET         *dchl = data->dchl;
278   double  sum_alpha;     //  |alpha_k|
279   double  sum_c;         //  |c_i|
280   double  psi1;          //  \Psi(c_ia + alpha_ka)
281   double  psi2;          //  \Psi( |c_i| + |alpha_k| )
282   double  psi3;          //  \Psi( |alpha_k| )
283   double  psi4;          //  \Psi( alpha_ka )
284   int     i,j,k,a;	 // indices over countvectors, unconstrained parameters, components, residues
285 
286   mixdchlet_unpack_paramvector(p, dchl);
287   esl_vec_DSet(dp, np, 0.);
288   for (i = 0; i < data->N; i++)
289     {
290       mixdchlet_postq(dchl, data->c[i]);           // d->postq[q] is now P(q | c_i, theta)
291       sum_c = esl_vec_DSum(data->c[i], dchl->K);   // |c_i|
292 
293       /* mixture coefficient gradient */
294       j = 0;
295       for (k = 0; k < dchl->Q; k++)
296 	dp[j++] -= dchl->postq[k] - dchl->q[k];
297 
298       for (k = 0; k < dchl->Q; k++)
299 	{
300 	  sum_alpha = esl_vec_DSum(dchl->alpha[k], dchl->K);
301 	  esl_stats_Psi( sum_alpha + sum_c, &psi2);
302 	  esl_stats_Psi( sum_alpha,         &psi3);
303 	  for (a = 0; a < dchl->K; a++)
304 	    {
305 	      esl_stats_Psi( dchl->alpha[k][a] + data->c[i][a], &psi1);
306 	      esl_stats_Psi( dchl->alpha[k][a],                 &psi4);
307 	      dp[j++] -= dchl->alpha[k][a] * dchl->postq[k] * (psi1 - psi2 + psi3 - psi4);
308 	    }
309 	}
310     }
311  }
312 
313 /* Function:  esl_mixdchlet_Fit()
314  *
315  * Purpose:   Given many count vectors <c> (<N> of them) and an initial
316  *            guess <dchl> for a mixture Dirichlet, find maximum likelihood
317  *            parameters by conjugate gradient descent optimization,
318  *            updating <dchl>. Optionally, return the final negative log likelihood
319  *            in <*opt_nll>.
320  *
321  * Args:      c       : count vectors c[0..N-1][0..K-1]
322  *            N       : number of count vectors; N>0
323  *            dchl    : initial guess, updated to the fitted model upon return.
324  *            opt_nll : OPTIONAL: final negative log likelihood
325  *
326  * Returns:   <eslOK> on success, <dchl> contains the fitted
327  *            mixture Dirichlet, and <*opt_nll> (if passed) contains the final NLL.
328  *
329  *            <eslENOHALT> if the fit fails to converge in a reasonable
330  *            number of iterations (in <esl_min_ConjugateGradientDescent()>,
331  *            default <max_iterations> is currently 100), but an answer
332  *            is still in <dchl> and <*opt_nll>.
333  *
334  *            <eslEINVAL> if N < 1. (Caller needs to provide data, but
335  *            it's possible that an input file might not contain any,
336  *            and we want to make sure to bark about it.)
337  *
338  * Throws:    <eslEMEM> on allocation error, <dchl> is left in
339  *            in its initial state, and <*opt_nll> (if passed) is -inf.
340  *
341  *            <eslECORRUPT> if <dchl> isn't a valid mixture Dirichlet.
342  */
343 int
esl_mixdchlet_Fit(double ** c,int N,ESL_MIXDCHLET * dchl,double * opt_nll)344 esl_mixdchlet_Fit(double **c, int N, ESL_MIXDCHLET *dchl, double *opt_nll)
345 {
346   ESL_MIN_CFG *cfg = NULL;
347   ESL_MIN_DAT *dat = NULL;
348   struct mixdchlet_data data;
349   double *p        = NULL;  // parameter vector [0..nparam-1], for CG descent
350   int     nparam   =  dchl->Q * (dchl->K + 1);
351   double  fx;
352   int     status;
353 
354   cfg = esl_min_cfg_Create(nparam);
355   if (! cfg) { status = eslEMEM; goto ERROR; }
356   cfg->cg_rtol    = 3e-5;
357   cfg->brent_rtol = 1e-2;
358   esl_vec_DSet(cfg->u, nparam, 0.1);
359 
360   dat = esl_min_dat_Create(cfg);
361 
362   if (N < 1) return eslEINVAL;
363 #if (eslDEBUGLEVEL >= 1)
364   if ( esl_mixdchlet_Validate(dchl, NULL) != eslOK) ESL_EXCEPTION(eslECORRUPT, "initial mixture is invalid");
365 #endif
366   ESL_ALLOC(p,   sizeof(double) * nparam);
367 
368   /* <data> is a wrapper that shuttles count data, theta into CG  */
369   data.dchl = dchl;
370   data.c    = c;
371   data.N    = N;
372 
373   /* initialize <p> */
374   mixdchlet_pack_paramvector(dchl, p);
375 
376   /* Feed it all to the mighty optimizer */
377   status = esl_min_ConjugateGradientDescent(cfg, p, nparam,
378 					    &mixdchlet_nll,
379 					    &mixdchlet_gradient,
380 					    (void *) (&data), &fx, dat);
381   if      (status != eslENOHALT && status != eslOK) goto ERROR; // too many iterations? treat it as "good enough".
382 
383   /* Convert the final parameter vector back */
384   mixdchlet_unpack_paramvector(p, dchl);
385 
386   esl_min_dat_Dump(stdout, dat);
387 
388   free(p);
389   esl_min_cfg_Destroy(cfg);
390   esl_min_dat_Destroy(dat);
391   if (opt_nll) *opt_nll = fx;
392   return eslOK;
393 
394  ERROR:
395   free(p);
396   esl_min_cfg_Destroy(cfg);
397   esl_min_dat_Destroy(dat);
398   if (opt_nll) *opt_nll = -eslINFINITY;
399   return status;
400 }
401 
402 
403 /* Function:  esl_mixdchlet_Sample()
404  * Synopsis:  Sample a random (perhaps initial) ESL_MIXDCHLET
405  * Incept:    SRE, Sun 01 Jul 2018 [Hamilton]
406  *
407  * Purpose:   Use random number generator <rng> to sample a
408  *            <ESL_MIXDCHLET> that's already been created for
409  *            <dchl->Q> components and alphabet size <dchl->K>.  The
410  *            random Dirichlet parameters are sampled uniformly on a
411  *            (0,2) open interval, and the mixture coefficients are
412  *            sampled uniformly.
413  *
414  * Returns:   <eslOK> on success, and <dchl> contains the sampled
415  *            model.
416  */
417 int
esl_mixdchlet_Sample(ESL_RANDOMNESS * rng,ESL_MIXDCHLET * dchl)418 esl_mixdchlet_Sample(ESL_RANDOMNESS *rng, ESL_MIXDCHLET *dchl)
419 {
420   int k,a;
421 
422   esl_dirichlet_DSampleUniform(rng, dchl->Q, dchl->q);
423   for (k = 0; k < dchl->Q; k++)
424     for (a = 0; a < dchl->K; a++)
425       dchl->alpha[k][a] = 2.0 * esl_rnd_UniformPositive(rng);
426   return eslOK;
427 }
428 
429 
430 /*****************************************************************
431  *# 4. Reading/writing mixture Dirichlet files
432  *****************************************************************/
433 
434 /* Function:  esl_mixdchlet_Read()
435  *
436  * Purpose:   Reads a mixture Dirichlet from an open stream <efp>, using the
437  *            <ESL_FILEPARSER> token-based parser.
438  *
439  *            The first two tokens are <K>, the length of the Dirichlet parameter
440  *            vector(s), and <Q>, the number of mixture components. Then for
441  *            each of the <Q> mixture components <k>, it reads a mixture coefficient
442  *            <q[k]> followed by <K> Dirichlet parameters <alpha[k][a=0..K-1]>.
443  *
444  *            This function may be called more than once on the same open file,
445  *            to read multiple different mixture Dirichlets from it (transitions,
446  *            match emissions, insert emissions, for example).
447  *
448  * Note:      One reason this function takes an ESL_FILEPARSER instead of
449  *            a filename or an open FILE pointer is that file format errors
450  *            in Easel are non-fatal "normal" errors, and we want to record
451  *            an informative error message. The ESL_FILEPARSER has an error
452  *            buffer for this purpose.
453  *
454  * Returns:   <eslOK> on success, and <ret_dchl> contains a new <ESL_MIXDCHLET> object
455  *            that the caller is responsible for free'ing.
456  *
457  *            <eslEFORMAT> on 'normal' parse failure, in which case <efp->errbuf>
458  *            contains an informative diagnostic message, and <efp->linenumber>
459  *            contains the linenumber at which the parse failed.
460  */
461 int
esl_mixdchlet_Read(ESL_FILEPARSER * efp,ESL_MIXDCHLET ** ret_dchl)462 esl_mixdchlet_Read(ESL_FILEPARSER *efp,  ESL_MIXDCHLET **ret_dchl)
463 {
464   ESL_MIXDCHLET *dchl = NULL;
465   int   Q,K;			/* number of components, alphabet size */
466   char *tok;			/* ptr to a whitespace-delim, noncomment token */
467   int   toklen;			/* length of a parsed token */
468   int   k,a;			/* index over components, symbols */
469   int   status;
470 
471   if ((status = esl_fileparser_GetToken(efp, &tok, &toklen)) != eslOK) goto ERROR;
472   K = atoi(tok);
473   if (K < 1) ESL_XFAIL(eslEFORMAT, efp->errbuf, "Bad vector size %s", tok);
474 
475   if ((status = esl_fileparser_GetToken(efp, &tok, &toklen)) != eslOK) goto ERROR;
476   Q = atoi(tok);
477   if (Q < 1) ESL_XFAIL(eslEFORMAT, efp->errbuf, "Bad mixture number %s", tok);
478 
479   if ((dchl = esl_mixdchlet_Create(Q, K)) == NULL) goto ERROR;
480 
481   for (k = 0; k < Q; k++)
482     {
483       if ((status = esl_fileparser_GetToken(efp, &tok, &toklen)) != eslOK) goto ERROR;
484       dchl->q[k] = atof(tok);
485       if (dchl->q[k] < 0.0 || dchl->q[k] > 1.0)
486 	ESL_XFAIL(eslEFORMAT, efp->errbuf, "bad mixture coefficient %s", tok);
487 
488       for (a = 0; a < K; a++)
489 	{
490 	  if ((status = esl_fileparser_GetToken(efp, &tok, &toklen)) != eslOK) goto ERROR;
491 	  dchl->alpha[k][a] = atof(tok);
492 	  if (dchl->alpha[k][a] <= 0.0)
493 	    ESL_XFAIL(eslEFORMAT, efp->errbuf, "Dirichlet params must be positive, got %s", tok);
494 	}
495     }
496   esl_vec_DNorm(dchl->q, Q);
497   *ret_dchl = dchl;
498   return eslOK;
499 
500  ERROR:
501   *ret_dchl = NULL;
502   esl_mixdchlet_Destroy(dchl);
503   if (status == eslEOF) ESL_FAIL(eslEFORMAT, efp->errbuf, "Premature end of mixture dirichlet file");
504   return status;
505 }
506 
507 
508 /* Function:  esl_mixdchlet_Write()
509  * Synopsis:  Write a mixture Dirichlet to an open output stream.
510  *
511  * Purpose:   Write mixture Dirichlet <dchl> to open output stream <fp>,
512  *            with coefficients and parameters to four decimal places.
513  *
514  * Returns:   <eslOK> on success.
515  *
516  * Throws:    <eslEWRITE> on any write error, such as filled disk.
517  */
518 int
esl_mixdchlet_Write(FILE * fp,const ESL_MIXDCHLET * dchl)519 esl_mixdchlet_Write(FILE *fp, const ESL_MIXDCHLET *dchl)
520 {
521   int k,a;
522   int status;
523 
524   if ((status = esl_fprintf(fp, "%d %d\n", dchl->K, dchl->Q))      != eslOK) return status;
525   for (k = 0; k < dchl->Q; k++)
526     {
527       if ((status = esl_fprintf(fp, "%.4f ", dchl->q[k]))          != eslOK) return status;
528       for (a = 0; a < dchl->K; a++)
529 	if ((status = esl_fprintf(fp, "%.4f ", dchl->alpha[k][a])) != eslOK) return status;
530       if ((status = esl_fprintf(fp, "\n"))                         != eslOK) return status;
531     }
532   return eslOK;
533 }
534 
535 /* Function:  esl_mixdchlet_WriteJSON()
536  * Synopsis:  Write a mixture Dirichlet to an open output stream.
537  *
538  * Purpose:   Write mixture Dirichlet <dchl> to open output stream <fp>,
539  *            in a JSON format.
540  *
541  * Args:      fp   - open output stream
542  *            d    - mixture Dirichlet to write
543  *
544  * Returns:   <eslOK> on success.
545  *
546  * Throws:    <eslEWRITE> on any write error, such as filled disk.
547  */
548 int
esl_mixdchlet_WriteJSON(FILE * fp,const ESL_MIXDCHLET * dchl)549 esl_mixdchlet_WriteJSON(FILE *fp, const ESL_MIXDCHLET *dchl)
550 {
551   int k,a;
552   int status;
553 
554   if ((status = esl_fprintf(fp, "{\n"))                                     != eslOK) return status;
555   if ((status = esl_fprintf(fp, "      \"Q\" : %d,\n", dchl->Q))            != eslOK) return status;
556   if ((status = esl_fprintf(fp, "      \"K\" : %d,\n", dchl->K))            != eslOK) return status;
557   if ((status = esl_fprintf(fp, "      \"q\" : "))                          != eslOK) return status;
558   for (k = 0; k < dchl->Q; k++)
559     if ((status = esl_fprintf(fp, "%c %.4f", k==0? '[' : ',', dchl->q[k])) != eslOK) return status;
560   if ((status = esl_fprintf(fp, " ],\n"))                                   != eslOK) return status;
561 
562   for (k = 0; k < dchl->Q; k++)
563     {
564       if (k == 0) { if ((status = esl_fprintf(fp, "  \"alpha\" : [ "))      != eslOK) return status; }
565       else        { if ((status = esl_fprintf(fp, ",\n              "))     != eslOK) return status; }
566 
567       for (a = 0; a < dchl->K; a++)
568 	if ((status = esl_fprintf(fp, "%c %.4f", a==0? '[' : ',', dchl->alpha[k][a])) != eslOK) return status;
569       if ((status = esl_fprintf(fp, " ]"))                                  != eslOK) return status;
570     }
571   if ((status = esl_fprintf(fp, " ]\n}\n"))                                 != eslOK) return status;
572   return eslOK;
573 }
574 
575 
576 
577 /*****************************************************************
578  *# 5. Debugging and development tools
579  *****************************************************************/
580 
581 
582 /* Function:  esl_mixdchlet_Validate()
583  * Synopsis:  Validate a mixture Dirichlet structure
584  * Incept:    SRE, Sun 01 Jul 2018 [World Cup, Croatia v. Denmark]
585  *
586  * Purpose:   Validate the internals of an <ESL_MIXDCHLET>. If good, return <eslOK>.
587  *            If bad, return <eslFAIL>, and (if optional <errmsg> is provided by
588  *            caller) put an informative error message in <errmsg>.
589  *
590  * Args:      dchl   - ESL_MIXDCHLET to validate
591  *            errmsg - OPTIONAL: error message buffer of at least <eslERRBUFSIZE>; or <NULL>
592  *
593  * Returns:   <eslOK> on success, and <errmsg> (if provided) is set to
594  *            an empty string.
595  *
596  *            <eslFAIL> on failure, and <errmsg> (if provided) contains the reason
597  *            for the failure.
598  */
599 int
esl_mixdchlet_Validate(const ESL_MIXDCHLET * dchl,char * errmsg)600 esl_mixdchlet_Validate(const ESL_MIXDCHLET *dchl, char *errmsg)
601 {
602   int    k, a;
603   double sum;
604   double tol = 1e-6;
605   if (errmsg) *errmsg = 0;
606 
607   if (dchl->Q < 1) ESL_FAIL(eslFAIL, errmsg, "mixture dirichlet component number Q is %d, not >= 1", dchl->Q);
608   if (dchl->K < 1) ESL_FAIL(eslFAIL, errmsg, "mixture dirichlet alphabet size K is %d, not >= 1",    dchl->K);
609 
610   for (k = 0; k < dchl->Q; k++)
611     {
612       if (! isfinite(dchl->q[k] ) )              ESL_FAIL(eslFAIL, errmsg, "mixture coefficient [%d] = %g, not finite", k, dchl->q[k]);
613       if ( dchl->q[k] < 0.0 || dchl->q[k] > 1.0) ESL_FAIL(eslFAIL, errmsg, "mixture coefficient [%d] = %g, not a probability >= 0 && <= 1", k, dchl->q[k]);
614     }
615   sum = esl_vec_DSum(dchl->q, dchl->Q);
616   if (esl_DCompare( sum, 1.0, tol) != eslOK)
617     ESL_FAIL(eslFAIL, errmsg, "mixture coefficients sum to %g, not 1", sum);
618 
619   for (k = 0; k < dchl->Q; k++)
620     for (a = 0; a < dchl->K; a++)
621       {
622 	if (! isfinite(dchl->alpha[k][a])) ESL_FAIL(eslFAIL, errmsg, "dirichlet parameter [%d][%d] = %g, not finite", k, a, dchl->alpha[k][a]);
623 	if ( dchl->alpha[k][a] <= 0)       ESL_FAIL(eslFAIL, errmsg, "dirichlet parameter [%d][%d] = %g, not >0",     k, a, dchl->alpha[k][a]);
624       }
625   return eslOK;
626 }
627 
628 
629 
630 /* Function:  esl_mixdchlet_Compare()
631  * Synopsis:  Compare two mixture Dirichlets for equality.
632  *
633  * Purpose:   Compare mixture Dirichlet objects <d1> and <d2> for
634  *            equality, independent of the exact order of the
635  *            components. For real numbered values, equality is
636  *            defined by <esl_DCompare()> with a fractional tolerance
637  *            <tol>.
638  *
639  *            Order-independent, because when we fit a mixture
640  *            Dirichlet to data, the order of the components is
641  *            arbitrary. A maximum bipartite matching algorithm is
642  *            used to figure out the best matching order.
643  *
644  * Returns:   <eslOK> on equality; <eslFAIL> otherwise.
645  *
646  * Throws:    <eslEMEM> on allocation failure.
647  */
648 int
esl_mixdchlet_Compare(const ESL_MIXDCHLET * d1,const ESL_MIXDCHLET * d2,double tol)649 esl_mixdchlet_Compare(const ESL_MIXDCHLET *d1, const ESL_MIXDCHLET *d2, double tol)
650 {
651   int **A = NULL;   // 2D matrix w/ edges ij, TRUE when d1[i] ~= d2[j]
652   int   i,j;
653   int   nmatch;
654   int   status;
655 
656   if (d1->Q != d2->Q) return eslFAIL;
657   if (d1->K != d2->K) return eslFAIL;
658 
659   if ((A = esl_mat_ICreate(d1->Q, d2->Q)) == NULL) { status = eslEMEM; goto ERROR; }
660   esl_mat_ISet(A, d1->Q, d2->Q, FALSE);
661 
662   for (i = 0; i < d1->Q; i++)
663     for (j = 0; j < d2->Q; j++)
664       if ( esl_DCompare    (d1->q[i],     d2->q[j],            tol) == eslOK &&
665 	   esl_vec_DCompare(d1->alpha[i], d2->alpha[j], d1->K, tol) == eslOK)
666 	A[i][j] = TRUE;
667 
668   if ((status = esl_graph_MaxBipartiteMatch(A, d1->Q, d2->Q, NULL, &nmatch)) != eslOK) goto ERROR;
669 
670   status = (nmatch == d1->Q) ? eslOK: eslFAIL;
671   /* fallthrough */
672  ERROR:
673   esl_mat_IDestroy(A);
674   return status;
675 }
676 
677 
678 
679 
680 /* Function:  esl_mixdchlet_Dump()
681  *
682  * Purpose:   Dump the mixture Dirichlet <d>.
683  */
684 int
esl_mixdchlet_Dump(FILE * fp,const ESL_MIXDCHLET * dchl)685 esl_mixdchlet_Dump(FILE *fp, const ESL_MIXDCHLET *dchl)
686 {
687   int  k,a;  /* counters over mixture components, residues */
688 
689   fprintf(fp, "Mixture Dirichlet: Q=%d K=%d\n", dchl->Q, dchl->K);
690   for (k = 0; k < dchl->Q; k++)
691     {
692       fprintf(fp, "q[%d] %f\n", k, dchl->q[k]);
693       for (a = 0; a < dchl->K; a++)
694 	fprintf(fp, "alpha[%d][%d] %f\n", k, a, dchl->alpha[k][a]);
695     }
696   return eslOK;
697 }
698 
699 
700 /*****************************************************************
701  * 6. Unit tests
702  *****************************************************************/
703 #ifdef eslMIXDCHLET_TESTDRIVE
704 
705 /* utest_io
706  * Write a mixture out; read it back in; should be the same.
707  */
708 static void
utest_io(ESL_RANDOMNESS * rng)709 utest_io(ESL_RANDOMNESS *rng)
710 {
711   char            msg[]       = "esl_mixdchlet: io unit test failed";
712   int             Q           = 1 + esl_rnd_Roll(rng, 4);
713   int             K           = 1 + esl_rnd_Roll(rng, 4);
714   ESL_MIXDCHLET  *d1          = esl_mixdchlet_Create(Q, K);
715   ESL_MIXDCHLET  *d2          = NULL;
716   ESL_FILEPARSER *efp         = NULL;
717   FILE           *fp          = NULL;
718   float           tol         = 1e-3;
719   char            tmpfile[16] = "esltmpXXXXXX";
720   int             k,a;
721 
722   /* Create a random mixture Dirichlet */
723   if (esl_mixdchlet_Sample(rng, d1)   != eslOK) esl_fatal(msg);
724 
725   /* Truncate values to four digits after decimal;
726    * Write only saves that much
727    */
728   for (k = 0; k < d1->Q; k++)
729     {
730       d1->q[k] = ((int)(d1->q[k] * 1.e4)) / 1.e4;
731       for (a = 0; a < d1->K; a++)
732 	d1->alpha[k][a] = ((int)(d1->alpha[k][a] * 1.e4)) / 1.e4;
733     }
734 
735   /* Write it to a a named tmpfile.  */
736   if (esl_tmpfile_named(tmpfile, &fp) != eslOK) esl_fatal(msg);
737   if (esl_mixdchlet_Write(fp, d1)     != eslOK) esl_fatal(msg);
738   fclose(fp);
739 
740   /* Read it back in */
741   if ((fp = fopen(tmpfile, "r")) == NULL)        esl_fatal(msg);
742   if ((efp = esl_fileparser_Create(fp)) == NULL) esl_fatal(msg);
743   if (esl_mixdchlet_Read(efp, &d2) != eslOK)     esl_fatal(msg);
744   esl_fileparser_Destroy(efp);
745   fclose(fp);
746 
747   if (esl_mixdchlet_Compare(d1, d2, tol) != eslOK) esl_fatal(msg);
748 
749   esl_mixdchlet_Destroy(d2);
750   esl_mixdchlet_Destroy(d1);
751   remove(tmpfile);
752 }
753 
754 /* utest_fit
755  * Generate count data from a known mixture Dirichlet, fit a new one,
756  * and make sure they're similar.
757  *
758  * This test can fail stochastically. If <allow_badluck> is FALSE (the
759  * default), it will reseed <rng> to a predetermined seed that always
760  * works (here, 14). Because the <rng> state is changed, test driver
761  * should put any such utests last.
762  *
763  * This test is typically slow (~5-10sec). The special seed 14 is
764  * chosen to be an unusually fast one (~3s).
765  */
766 static void
utest_fit(ESL_RANDOMNESS * rng,int allow_badluck,int be_verbose)767 utest_fit(ESL_RANDOMNESS *rng, int allow_badluck, int be_verbose)
768 {
769   char            msg[]       = "esl_mixdchlet: utest_fit failed";
770   int             K           = 4;                            // alphabet size
771   int             N           = 10000;                        // number of count vectors to generate
772   int             nct         = 1000;                         // number of counts per vector
773   ESL_MIXDCHLET  *d0          = esl_mixdchlet_Create(2, K);   // true 2-component mixture Dirichlet (data generated from this)
774   ESL_MIXDCHLET  *dchl        = esl_mixdchlet_Create(2, K);   // estimated 2-component mixture Dirichlet
775   double         *p           = malloc(sizeof(double) * K);
776   double        **c           = esl_mat_DCreate(N, K);
777   double          nll0, nll;
778   int i,k,a;
779 
780   /* Suppress bad luck by default by fixing the RNG seed */
781   if (! allow_badluck) esl_randomness_Init(rng, 14);
782 
783   /* Create known 2-component mixture Dirichlet */
784   d0->q[0] = 0.7;
785   d0->q[1] = 0.3;
786   esl_vec_DSet(d0->alpha[0], d0->K, 1.0);   // component 0 = uniform
787   esl_vec_DSet(d0->alpha[1], d0->K, 10.0);  // component 1 = mode at 1/K
788 
789   /* Sample <N> observed count vectors, given d0 */
790   nll0 = 0;
791   for (i = 0; i < N; i++)
792     {
793       esl_vec_DSet(c[i], d0->K, 0.);
794       k = esl_rnd_DChoose(rng, d0->q, d0->Q);             // choose a mixture component
795       esl_dirichlet_DSample(rng, d0->alpha[k], d0->K, p); // sample a pvector
796       for (a = 0; a < nct; a++)
797 	c[i][ esl_rnd_DChoose(rng, p, d0->K) ] += 1.0;    // sample count vector
798       nll0 -= esl_mixdchlet_logp_c(d0, c[i]);
799     }
800 
801   if ( esl_mixdchlet_Sample(rng, dchl)        != eslOK) esl_fatal(msg);
802   if ( esl_mixdchlet_Fit(c, N, dchl, &nll)    != eslOK) esl_fatal(msg);
803 
804   if (be_verbose)
805     {
806       printf("True     (nll=%10.4g):\n", nll0);  esl_mixdchlet_Dump(stdout, d0);
807       printf("Inferred (nll=%10.4g):\n", nll);   esl_mixdchlet_Dump(stdout, dchl);
808     }
809   if ( esl_mixdchlet_Compare(d0, dchl, 0.1)   != eslOK) esl_fatal(msg);
810   if ( nll0 < nll )                                     esl_fatal(msg);
811 
812   esl_mat_DDestroy(c);
813   esl_mixdchlet_Destroy(dchl);
814   esl_mixdchlet_Destroy(d0);
815   free(p);
816 }
817 #endif // eslMIXDCHLET_TESTDRIVE
818 
819 
820 /*****************************************************************
821  * 7. Test driver
822  *****************************************************************/
823 #ifdef eslMIXDCHLET_TESTDRIVE
824 
825 #include "easel.h"
826 #include "esl_fileparser.h"
827 #include "esl_getopts.h"
828 #include "esl_random.h"
829 #include "esl_dirichlet.h"
830 
831 static ESL_OPTIONS options[] = {
832   /* name           type      default  env  range toggles reqs incomp  help                                       docgroup*/
833   { "-h",        eslARG_NONE,   FALSE,  NULL, NULL,  NULL,  NULL, NULL, "show brief help on version and usage",             0 },
834   { "-s",        eslARG_INT,      "0",  NULL, NULL,  NULL,  NULL, NULL, "set random number seed to <n>",                    0 },
835   { "-x",        eslARG_NONE,   FALSE,  NULL, NULL,  NULL,  NULL, NULL, "allow bad luck (expected stochastic failures)",    0 },
836   { "-v",        eslARG_NONE,   FALSE,  NULL, NULL,  NULL,  NULL, NULL, "be more verbose"              ,                    0 },
837   {  0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
838 };
839 static char usage[]  = "[-options]";
840 static char banner[] = "test driver for mixdchlet module";
841 
842 int
main(int argc,char ** argv)843 main(int argc, char **argv)
844 {
845   ESL_GETOPTS    *go   = esl_getopts_CreateDefaultApp(options, 0, argc, argv, banner, usage);
846   ESL_RANDOMNESS *rng  = esl_randomness_Create(esl_opt_GetInteger(go, "-s"));
847   int      be_verbose  = esl_opt_GetBoolean(go, "-v");
848   int   allow_badluck  = esl_opt_GetBoolean(go, "-x");  // if a utest can fail just by chance, let it, instead of suppressing
849 
850   fprintf(stderr, "## %s\n", argv[0]);
851   fprintf(stderr, "#  rng seed = %" PRIu32 "\n", esl_randomness_GetSeed(rng));
852 
853   utest_io (rng);
854 
855   // Tests that can fail stochastically go last, because they reset the RNG seed by default.
856   utest_fit(rng, allow_badluck, be_verbose);
857 
858   fprintf(stderr, "#  status = ok\n");
859 
860   esl_randomness_Destroy(rng);
861   esl_getopts_Destroy(go);
862   return eslOK;
863 }
864 #endif /*eslMIXDCHLET_TESTDRIVE*/
865 /*--------------------- end, test driver ------------------------*/
866 
867 
868 /*****************************************************************
 * 8. Example
870  *****************************************************************/
871 #ifdef eslMIXDCHLET_EXAMPLE
872 
873 #include "easel.h"
874 #include "esl_fileparser.h"
875 #include "esl_getopts.h"
876 
877 static ESL_OPTIONS options[] = {
878   /* name           type      default  env  range toggles reqs incomp  help                                       docgroup*/
879   { "-h",        eslARG_NONE,   FALSE,  NULL, NULL,  NULL,  NULL, NULL, "show brief help on version and usage",             0 },
880   {  0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
881 };
882 static char usage[]  = "[-options] <mixdchlet_file> <counts_file>";
883 static char banner[] = "example driver for mixdchlet module: log likelihood of count data";
884 
885 int
main(int argc,char ** argv)886 main(int argc, char **argv)
887 {
888   ESL_GETOPTS    *go     = esl_getopts_CreateDefaultApp(options, 2, argc, argv, banner, usage);
889   char           *dfile  = esl_opt_GetArg(go, 1);
890   char           *ctfile = esl_opt_GetArg(go, 2);
891   ESL_FILEPARSER *efp    = NULL;
892   ESL_MIXDCHLET  *dchl   = NULL;
893   double         *ct     = NULL;   // one countvector read from ctfile at a time
894   char           *tok    = NULL;
895   int             toklen = 0;
896   int             a;
897   double          nll    = 0;
898   int             status;
899 
900   /* Read mixture Dirichlet */
901   if ( esl_fileparser_Open(dfile, NULL, &efp) != eslOK) esl_fatal("failed to open %s for reading", dfile);
902   esl_fileparser_SetCommentChar(efp, '#');
903   if ( esl_mixdchlet_Read(efp, &dchl)         != eslOK) esl_fatal("failed to parse %s\n  %s", dfile, efp->errbuf);
904   esl_fileparser_Close(efp);
905   efp = NULL;
906 
907   /* Read count vectors one at a time, increment nll */
908   if ( esl_fileparser_Open(ctfile, NULL, &efp) != eslOK) esl_fatal("failed to open %s for reading", ctfile);
909   esl_fileparser_SetCommentChar(efp, '#');
910   ct = malloc(sizeof(double) * dchl->K);
911   while ((status = esl_fileparser_NextLine(efp)) == eslOK)
912     {
913       a = 0; // counter over fields on line, ct[a=0..K-1].
914       while ((status = esl_fileparser_GetTokenOnLine(efp, &tok, &toklen)) == eslOK)
915        {
916 	 if (a == dchl->K)          esl_fatal("parse failed, %s:%d: > K=%d fields on line", ctfile, efp->linenumber, dchl->K);
917 	 if (! esl_str_IsReal(tok)) esl_fatal("parse failed, %s:%d: field %d (%s) not a real number", ctfile, efp->linenumber, a+1, tok);
918 	 ct[a++] = atof(tok);
919        }
920 
921       nll += esl_mixdchlet_logp_c(dchl, ct);
922     }
923   esl_fileparser_Close(efp);
924 
925   printf("nll = %g\n", -nll);
926 
927   free(ct);
928   esl_mixdchlet_Destroy(dchl);
929   esl_getopts_Destroy(go);
930 }
931 #endif /*eslMIXDCHLET_EXAMPLE*/
932 
933 
934 
935 
936