1 /* Mixture Dirichlet densities
2 *
3 * Contents:
4 * 1. <ESL_MIXDCHLET> object
5 * 2. Likelihoods, posteriors, inference
6 * 3. Maximum likelihood fitting to count data
7 * 4. Reading/writing mixture Dirichlet files
8 * 5. Debugging and development tools
9 * 6. Unit tests
10 * 7. Test driver
11 * 8. Example
12 *
13 * See also:
14 * esl_dirichlet : simple Dirichlet densities
15 * esl-mixdchlet miniapp : fitting and more
16 */
17 #include "esl_config.h"
18
19 #include <stdlib.h>
20 #include <stdio.h>
21
22 #include "easel.h"
23 #include "esl_dirichlet.h"
24 #include "esl_fileparser.h"
25 #include "esl_graph.h"
26 #include "esl_matrixops.h"
27 #include "esl_minimizer.h"
28 #include "esl_random.h"
29 #include "esl_stats.h"
30 #include "esl_vectorops.h"
31 #include "esl_mixdchlet.h"
32
33
34 /*****************************************************************
35 *# 1. <ESL_MIXDCHLET> object
36 *****************************************************************/
37
38 /* Function: esl_mixdchlet_Create()
39 *
40 * Purpose: Create a new mixture Dirichlet prior with <Q> components,
41 * each with <K> parameters.
42 *
43 * Returns: ptr to new <ESL_MIXDCHLET> on success.
44 *
45 * Throws: NULL on allocation failure.
46 */
47 ESL_MIXDCHLET *
esl_mixdchlet_Create(int Q,int K)48 esl_mixdchlet_Create(int Q, int K)
49 {
50 ESL_MIXDCHLET *dchl = NULL;
51 int status;
52
53 ESL_DASSERT1( (Q > 0) );
54 ESL_DASSERT1( (K > 0) );
55
56 ESL_ALLOC(dchl, sizeof(ESL_MIXDCHLET));
57 dchl->q = NULL;
58 dchl->alpha = NULL;
59 dchl->postq = NULL;
60
61 ESL_ALLOC(dchl->q, sizeof(double) * Q);
62 ESL_ALLOC(dchl->postq, sizeof(double) * Q);
63 if ((dchl->alpha = esl_mat_DCreate(Q,K)) == NULL) goto ERROR;
64
65 dchl->Q = Q;
66 dchl->K = K;
67 return dchl;
68
69 ERROR:
70 esl_mixdchlet_Destroy(dchl);
71 return NULL;
72 }
73
74
75
76 /* Function: esl_mixdchlet_Destroy()
77 * Synopsis: Free a mixture Dirichlet.
78 */
79 void
esl_mixdchlet_Destroy(ESL_MIXDCHLET * dchl)80 esl_mixdchlet_Destroy(ESL_MIXDCHLET *dchl)
81 {
82 if (dchl)
83 {
84 free(dchl->q);
85 esl_mat_DDestroy(dchl->alpha);
86 free(dchl->postq);
87 free(dchl);
88 }
89 }
90
91
92
93 /*****************************************************************
94 * 2. Likelihoods, posteriors, inference
95 *****************************************************************/
96
97
98 /* mixchlet_postq()
99 * Calculate P(q | c), the posterior probability of component q.
100 */
101 static void
mixdchlet_postq(ESL_MIXDCHLET * dchl,double * c)102 mixdchlet_postq(ESL_MIXDCHLET *dchl, double *c)
103 {
104 int k;
105 for (k = 0; k < dchl->Q; k++)
106 if (dchl->q[k] > 0.) dchl->postq[k] = log(dchl->q[k]) + esl_dirichlet_logpdf_c(c, dchl->alpha[k], dchl->K);
107 else dchl->postq[k] = -eslINFINITY;
108 esl_vec_DLogNorm(dchl->postq, dchl->Q);
109 }
110
111
112
113 /* Function: esl_mixdchlet_logp_c()
114 *
115 * Purpose: Given observed count vector $c[0..K-1]$ and a mixture
116 * Dirichlet <dchl>, calculate $\log P(c \mid \theta)$.
117 *
118 * Args: dchl : mixture Dirichlet
119 * c : count vector, [0..K-1]
120 *
121 * Returns: $\log P(c \mid \theta)$
122 *
123 * Note: Only because a workspace in <dchl> is used, you can't
124 * declare <dchl> to be const.
125 */
126 double
esl_mixdchlet_logp_c(ESL_MIXDCHLET * dchl,double * c)127 esl_mixdchlet_logp_c(ESL_MIXDCHLET *dchl, double *c)
128 {
129 int k;
130 for (k = 0; k < dchl->Q; k++)
131 if (dchl->q[k] > 0.) dchl->postq[k] = log(dchl->q[k]) + esl_dirichlet_logpdf_c(c, dchl->alpha[k], dchl->K);
132 else dchl->postq[k] = -eslINFINITY;
133 return esl_vec_DLogSum(dchl->postq, dchl->Q);
134 }
135
136
137
138 /* Function: esl_mixdchlet_MPParameters()
139 * Synopsis: Calculate mean posterior parameters from a count vector
140 *
141 * Purpose: Given a mixture Dirichlet prior <dchl> and observed
142 * countvector <c> of length <dchl->K>, calculate mean
143 * posterior parameter estimates <p>. Caller provides the
144 * storage <p>, allocated for at least <dchl->K> parameters.
145 *
146 * Returns: <eslOK> on success, and <p> contains mean posterior
147 * probability parameter estimates.
148 *
149 * Note: Only because a workspace in <dchl> is used, you can't
150 * declare <dchl> to be const.
151 */
152 int
esl_mixdchlet_MPParameters(ESL_MIXDCHLET * dchl,double * c,double * p)153 esl_mixdchlet_MPParameters(ESL_MIXDCHLET *dchl, double *c, double *p)
154 {
155 int k,a; // indices over components, residues
156 double totc;
157 double totalpha;
158
159 /* Calculate posterior prob P(k | c) of each component k given count vector c. */
160 mixdchlet_postq(dchl, c);
161
162 /* Compute mean posterior estimates for probability parameters */
163 totc = esl_vec_DSum(c, dchl->K);
164 esl_vec_DSet(p, dchl->K, 0.);
165 for (k = 0; k < dchl->Q; k++)
166 {
167 totalpha = esl_vec_DSum(dchl->alpha[k], dchl->K);
168 for (a = 0; a < dchl->K; a++)
169 p[a] += dchl->postq[k] * (c[a] + dchl->alpha[k][a]) / (totc + totalpha);
170 }
171 /* should be normalized already, but for good measure: */
172 esl_vec_DNorm(p, dchl->K);
173 return eslOK;
174 }
175
176
177
178 /*****************************************************************
179 * 3. Maximum likelihood fitting to count data
180 *****************************************************************/
181 /* This structure is used to shuttle the data into minimizer's generic
182 * (void *) API for all aux data
183 */
184 struct mixdchlet_data {
185 ESL_MIXDCHLET *dchl; /* dirichlet mixture parameters */
186 double **c; /* count vector array [0..N-1][0..K-1] */
187 int N; /* number of countvectors */
188 };
189
190 /*****************************************************************
191 * Parameter vector packing/unpacking
192 *
193 * The Easel conjugate gradient code is a general optimizer. It takes
194 * a single parameter vector <p>, where the values are unconstrained
195 * real numbers.
196 *
197 * We're optimizing a mixture Dirichlet with two kinds of parameters.
198 * q[k] are mixture coefficients, constrained to be >= 0 and \sum_k
199 * q[k] = 1. alpha[k][a] are the Dirichlet parameters for component
200 * k, constrained to be > 0.
201 *
202 * So we use a c.o.v. to get the coefficients and parameters in terms
203 * of unconstrained reals lambda and beta:
204 * mixture coefficients: q_k = exp(lambda_k) / \sum_j exp(lambda_j)
205 * Dirichlet parameters: alpha_a = exp(beta_a)
206 *
207 * And we pack them all in one parameter vector, lambdas first:
208 * [0 ... Q-1] [0 ... K-1] [0 ... K-1] ...
209 * lambda's beta_0 beta_1 ...
210 *
211 * The parameter vector p therefore has length Q(K+1), and is accessed as:
212 * mixture coefficient lambda[k] is at p[k]
 *   Dirichlet param beta[k][a] is at p[Q + k*K + a].
214 */
215 static void
mixdchlet_pack_paramvector(ESL_MIXDCHLET * dchl,double * p)216 mixdchlet_pack_paramvector(ESL_MIXDCHLET *dchl, double *p)
217 {
218 int j = 0; /* counter in packed parameter vector <p> */
219 int k,a; /* indices over components, residues */
220
221 /* mixture coefficients */
222 for (k = 0; k < dchl->Q; k++)
223 p[j++] = log(dchl->q[k]);
224
225 /* Dirichlet parameters */
226 for (k = 0; k < dchl->Q; k++)
227 for (a = 0; a < dchl->K; a++)
228 p[j++] = log(dchl->alpha[k][a]);
229
230 ESL_DASSERT1(( j == dchl->Q * (dchl->K + 1)) );
231 }
232
233 /* Same as above but in reverse: given parameter vector <p>,
234 * do appropriate c.o.v. back to desired parameter space, and
235 * storing in mixdchlet <dchl>.
236 */
237 static void
mixdchlet_unpack_paramvector(double * p,ESL_MIXDCHLET * dchl)238 mixdchlet_unpack_paramvector(double *p, ESL_MIXDCHLET *dchl)
239 {
240 int j = 0; /* counter in packed parameter vector <p> */
241 int k,a; /* indices over components, residues */
242
243 /* mixture coefficients */
244 for (k = 0; k < dchl->Q; k++)
245 dchl->q[k] = exp(p[j++]);
246 esl_vec_DNorm(dchl->q, dchl->Q);
247
248 /* Dirichlet parameters */
249 for (k = 0; k < dchl->Q; k++)
250 for (a = 0; a < dchl->K; a++)
251 dchl->alpha[k][a] = exp(p[j++]);
252
253 ESL_DASSERT1(( j == dchl->Q * (dchl->K + 1)) );
254 }
255
256 /* The negative log likelihood function to be minimized by ML fitting. */
257 static double
mixdchlet_nll(double * p,int np,void * dptr)258 mixdchlet_nll(double *p, int np, void *dptr)
259 {
260 ESL_UNUSED(np); // parameter number <np> must be an arg, dictated by conj gradient API.
261 struct mixdchlet_data *data = (struct mixdchlet_data *) dptr;
262 ESL_MIXDCHLET *dchl = data->dchl;
263 double nll = 0.;
264 int i;
265
266 mixdchlet_unpack_paramvector(p, dchl);
267 for (i = 0; i < data->N; i++)
268 nll -= esl_mixdchlet_logp_c(dchl, data->c[i]);
269 return nll;
270 }
271
272 /* The gradient of the NLL w.r.t. each free parameter in p. */
273 static void
mixdchlet_gradient(double * p,int np,void * dptr,double * dp)274 mixdchlet_gradient(double *p, int np, void *dptr, double *dp)
275 {
276 struct mixdchlet_data *data = (struct mixdchlet_data *) dptr;
277 ESL_MIXDCHLET *dchl = data->dchl;
278 double sum_alpha; // |alpha_k|
279 double sum_c; // |c_i|
280 double psi1; // \Psi(c_ia + alpha_ka)
281 double psi2; // \Psi( |c_i| + |alpha_k| )
282 double psi3; // \Psi( |alpha_k| )
283 double psi4; // \Psi( alpha_ka )
284 int i,j,k,a; // indices over countvectors, unconstrained parameters, components, residues
285
286 mixdchlet_unpack_paramvector(p, dchl);
287 esl_vec_DSet(dp, np, 0.);
288 for (i = 0; i < data->N; i++)
289 {
290 mixdchlet_postq(dchl, data->c[i]); // d->postq[q] is now P(q | c_i, theta)
291 sum_c = esl_vec_DSum(data->c[i], dchl->K); // |c_i|
292
293 /* mixture coefficient gradient */
294 j = 0;
295 for (k = 0; k < dchl->Q; k++)
296 dp[j++] -= dchl->postq[k] - dchl->q[k];
297
298 for (k = 0; k < dchl->Q; k++)
299 {
300 sum_alpha = esl_vec_DSum(dchl->alpha[k], dchl->K);
301 esl_stats_Psi( sum_alpha + sum_c, &psi2);
302 esl_stats_Psi( sum_alpha, &psi3);
303 for (a = 0; a < dchl->K; a++)
304 {
305 esl_stats_Psi( dchl->alpha[k][a] + data->c[i][a], &psi1);
306 esl_stats_Psi( dchl->alpha[k][a], &psi4);
307 dp[j++] -= dchl->alpha[k][a] * dchl->postq[k] * (psi1 - psi2 + psi3 - psi4);
308 }
309 }
310 }
311 }
312
313 /* Function: esl_mixdchlet_Fit()
314 *
315 * Purpose: Given many count vectors <c> (<N> of them) and an initial
316 * guess <dchl> for a mixture Dirichlet, find maximum likelihood
317 * parameters by conjugate gradient descent optimization,
318 * updating <dchl>. Optionally, return the final negative log likelihood
319 * in <*opt_nll>.
320 *
321 * Args: c : count vectors c[0..N-1][0..K-1]
322 * N : number of count vectors; N>0
323 * dchl : initial guess, updated to the fitted model upon return.
324 * opt_nll : OPTIONAL: final negative log likelihood
325 *
326 * Returns: <eslOK> on success, <dchl> contains the fitted
327 * mixture Dirichlet, and <*opt_nll> (if passed) contains the final NLL.
328 *
329 * <eslENOHALT> if the fit fails to converge in a reasonable
330 * number of iterations (in <esl_min_ConjugateGradientDescent()>,
331 * default <max_iterations> is currently 100), but an answer
332 * is still in <dchl> and <*opt_nll>.
333 *
334 * <eslEINVAL> if N < 1. (Caller needs to provide data, but
335 * it's possible that an input file might not contain any,
336 * and we want to make sure to bark about it.)
337 *
338 * Throws: <eslEMEM> on allocation error, <dchl> is left in
339 * in its initial state, and <*opt_nll> (if passed) is -inf.
340 *
341 * <eslECORRUPT> if <dchl> isn't a valid mixture Dirichlet.
342 */
343 int
esl_mixdchlet_Fit(double ** c,int N,ESL_MIXDCHLET * dchl,double * opt_nll)344 esl_mixdchlet_Fit(double **c, int N, ESL_MIXDCHLET *dchl, double *opt_nll)
345 {
346 ESL_MIN_CFG *cfg = NULL;
347 ESL_MIN_DAT *dat = NULL;
348 struct mixdchlet_data data;
349 double *p = NULL; // parameter vector [0..nparam-1], for CG descent
350 int nparam = dchl->Q * (dchl->K + 1);
351 double fx;
352 int status;
353
354 cfg = esl_min_cfg_Create(nparam);
355 if (! cfg) { status = eslEMEM; goto ERROR; }
356 cfg->cg_rtol = 3e-5;
357 cfg->brent_rtol = 1e-2;
358 esl_vec_DSet(cfg->u, nparam, 0.1);
359
360 dat = esl_min_dat_Create(cfg);
361
362 if (N < 1) return eslEINVAL;
363 #if (eslDEBUGLEVEL >= 1)
364 if ( esl_mixdchlet_Validate(dchl, NULL) != eslOK) ESL_EXCEPTION(eslECORRUPT, "initial mixture is invalid");
365 #endif
366 ESL_ALLOC(p, sizeof(double) * nparam);
367
368 /* <data> is a wrapper that shuttles count data, theta into CG */
369 data.dchl = dchl;
370 data.c = c;
371 data.N = N;
372
373 /* initialize <p> */
374 mixdchlet_pack_paramvector(dchl, p);
375
376 /* Feed it all to the mighty optimizer */
377 status = esl_min_ConjugateGradientDescent(cfg, p, nparam,
378 &mixdchlet_nll,
379 &mixdchlet_gradient,
380 (void *) (&data), &fx, dat);
381 if (status != eslENOHALT && status != eslOK) goto ERROR; // too many iterations? treat it as "good enough".
382
383 /* Convert the final parameter vector back */
384 mixdchlet_unpack_paramvector(p, dchl);
385
386 esl_min_dat_Dump(stdout, dat);
387
388 free(p);
389 esl_min_cfg_Destroy(cfg);
390 esl_min_dat_Destroy(dat);
391 if (opt_nll) *opt_nll = fx;
392 return eslOK;
393
394 ERROR:
395 free(p);
396 esl_min_cfg_Destroy(cfg);
397 esl_min_dat_Destroy(dat);
398 if (opt_nll) *opt_nll = -eslINFINITY;
399 return status;
400 }
401
402
403 /* Function: esl_mixdchlet_Sample()
404 * Synopsis: Sample a random (perhaps initial) ESL_MIXDCHLET
405 * Incept: SRE, Sun 01 Jul 2018 [Hamilton]
406 *
407 * Purpose: Use random number generator <rng> to sample a
408 * <ESL_MIXDCHLET> that's already been created for
409 * <dchl->Q> components and alphabet size <dchl->K>. The
410 * random Dirichlet parameters are sampled uniformly on a
411 * (0,2) open interval, and the mixture coefficients are
412 * sampled uniformly.
413 *
414 * Returns: <eslOK> on success, and <dchl> contains the sampled
415 * model.
416 */
417 int
esl_mixdchlet_Sample(ESL_RANDOMNESS * rng,ESL_MIXDCHLET * dchl)418 esl_mixdchlet_Sample(ESL_RANDOMNESS *rng, ESL_MIXDCHLET *dchl)
419 {
420 int k,a;
421
422 esl_dirichlet_DSampleUniform(rng, dchl->Q, dchl->q);
423 for (k = 0; k < dchl->Q; k++)
424 for (a = 0; a < dchl->K; a++)
425 dchl->alpha[k][a] = 2.0 * esl_rnd_UniformPositive(rng);
426 return eslOK;
427 }
428
429
430 /*****************************************************************
431 *# 4. Reading/writing mixture Dirichlet files
432 *****************************************************************/
433
434 /* Function: esl_mixdchlet_Read()
435 *
436 * Purpose: Reads a mixture Dirichlet from an open stream <efp>, using the
437 * <ESL_FILEPARSER> token-based parser.
438 *
439 * The first two tokens are <K>, the length of the Dirichlet parameter
440 * vector(s), and <Q>, the number of mixture components. Then for
441 * each of the <Q> mixture components <k>, it reads a mixture coefficient
442 * <q[k]> followed by <K> Dirichlet parameters <alpha[k][a=0..K-1]>.
443 *
444 * This function may be called more than once on the same open file,
445 * to read multiple different mixture Dirichlets from it (transitions,
446 * match emissions, insert emissions, for example).
447 *
448 * Note: One reason this function takes an ESL_FILEPARSER instead of
449 * a filename or an open FILE pointer is that file format errors
450 * in Easel are non-fatal "normal" errors, and we want to record
451 * an informative error message. The ESL_FILEPARSER has an error
452 * buffer for this purpose.
453 *
454 * Returns: <eslOK> on success, and <ret_dchl> contains a new <ESL_MIXDCHLET> object
455 * that the caller is responsible for free'ing.
456 *
457 * <eslEFORMAT> on 'normal' parse failure, in which case <efp->errbuf>
458 * contains an informative diagnostic message, and <efp->linenumber>
459 * contains the linenumber at which the parse failed.
460 */
461 int
esl_mixdchlet_Read(ESL_FILEPARSER * efp,ESL_MIXDCHLET ** ret_dchl)462 esl_mixdchlet_Read(ESL_FILEPARSER *efp, ESL_MIXDCHLET **ret_dchl)
463 {
464 ESL_MIXDCHLET *dchl = NULL;
465 int Q,K; /* number of components, alphabet size */
466 char *tok; /* ptr to a whitespace-delim, noncomment token */
467 int toklen; /* length of a parsed token */
468 int k,a; /* index over components, symbols */
469 int status;
470
471 if ((status = esl_fileparser_GetToken(efp, &tok, &toklen)) != eslOK) goto ERROR;
472 K = atoi(tok);
473 if (K < 1) ESL_XFAIL(eslEFORMAT, efp->errbuf, "Bad vector size %s", tok);
474
475 if ((status = esl_fileparser_GetToken(efp, &tok, &toklen)) != eslOK) goto ERROR;
476 Q = atoi(tok);
477 if (Q < 1) ESL_XFAIL(eslEFORMAT, efp->errbuf, "Bad mixture number %s", tok);
478
479 if ((dchl = esl_mixdchlet_Create(Q, K)) == NULL) goto ERROR;
480
481 for (k = 0; k < Q; k++)
482 {
483 if ((status = esl_fileparser_GetToken(efp, &tok, &toklen)) != eslOK) goto ERROR;
484 dchl->q[k] = atof(tok);
485 if (dchl->q[k] < 0.0 || dchl->q[k] > 1.0)
486 ESL_XFAIL(eslEFORMAT, efp->errbuf, "bad mixture coefficient %s", tok);
487
488 for (a = 0; a < K; a++)
489 {
490 if ((status = esl_fileparser_GetToken(efp, &tok, &toklen)) != eslOK) goto ERROR;
491 dchl->alpha[k][a] = atof(tok);
492 if (dchl->alpha[k][a] <= 0.0)
493 ESL_XFAIL(eslEFORMAT, efp->errbuf, "Dirichlet params must be positive, got %s", tok);
494 }
495 }
496 esl_vec_DNorm(dchl->q, Q);
497 *ret_dchl = dchl;
498 return eslOK;
499
500 ERROR:
501 *ret_dchl = NULL;
502 esl_mixdchlet_Destroy(dchl);
503 if (status == eslEOF) ESL_FAIL(eslEFORMAT, efp->errbuf, "Premature end of mixture dirichlet file");
504 return status;
505 }
506
507
508 /* Function: esl_mixdchlet_Write()
509 * Synopsis: Write a mixture Dirichlet to an open output stream.
510 *
511 * Purpose: Write mixture Dirichlet <dchl> to open output stream <fp>,
512 * with coefficients and parameters to four decimal places.
513 *
514 * Returns: <eslOK> on success.
515 *
516 * Throws: <eslEWRITE> on any write error, such as filled disk.
517 */
518 int
esl_mixdchlet_Write(FILE * fp,const ESL_MIXDCHLET * dchl)519 esl_mixdchlet_Write(FILE *fp, const ESL_MIXDCHLET *dchl)
520 {
521 int k,a;
522 int status;
523
524 if ((status = esl_fprintf(fp, "%d %d\n", dchl->K, dchl->Q)) != eslOK) return status;
525 for (k = 0; k < dchl->Q; k++)
526 {
527 if ((status = esl_fprintf(fp, "%.4f ", dchl->q[k])) != eslOK) return status;
528 for (a = 0; a < dchl->K; a++)
529 if ((status = esl_fprintf(fp, "%.4f ", dchl->alpha[k][a])) != eslOK) return status;
530 if ((status = esl_fprintf(fp, "\n")) != eslOK) return status;
531 }
532 return eslOK;
533 }
534
535 /* Function: esl_mixdchlet_WriteJSON()
536 * Synopsis: Write a mixture Dirichlet to an open output stream.
537 *
538 * Purpose: Write mixture Dirichlet <dchl> to open output stream <fp>,
539 * in a JSON format.
540 *
541 * Args: fp - open output stream
542 * d - mixture Dirichlet to write
543 *
544 * Returns: <eslOK> on success.
545 *
546 * Throws: <eslEWRITE> on any write error, such as filled disk.
547 */
548 int
esl_mixdchlet_WriteJSON(FILE * fp,const ESL_MIXDCHLET * dchl)549 esl_mixdchlet_WriteJSON(FILE *fp, const ESL_MIXDCHLET *dchl)
550 {
551 int k,a;
552 int status;
553
554 if ((status = esl_fprintf(fp, "{\n")) != eslOK) return status;
555 if ((status = esl_fprintf(fp, " \"Q\" : %d,\n", dchl->Q)) != eslOK) return status;
556 if ((status = esl_fprintf(fp, " \"K\" : %d,\n", dchl->K)) != eslOK) return status;
557 if ((status = esl_fprintf(fp, " \"q\" : ")) != eslOK) return status;
558 for (k = 0; k < dchl->Q; k++)
559 if ((status = esl_fprintf(fp, "%c %.4f", k==0? '[' : ',', dchl->q[k])) != eslOK) return status;
560 if ((status = esl_fprintf(fp, " ],\n")) != eslOK) return status;
561
562 for (k = 0; k < dchl->Q; k++)
563 {
564 if (k == 0) { if ((status = esl_fprintf(fp, " \"alpha\" : [ ")) != eslOK) return status; }
565 else { if ((status = esl_fprintf(fp, ",\n ")) != eslOK) return status; }
566
567 for (a = 0; a < dchl->K; a++)
568 if ((status = esl_fprintf(fp, "%c %.4f", a==0? '[' : ',', dchl->alpha[k][a])) != eslOK) return status;
569 if ((status = esl_fprintf(fp, " ]")) != eslOK) return status;
570 }
571 if ((status = esl_fprintf(fp, " ]\n}\n")) != eslOK) return status;
572 return eslOK;
573 }
574
575
576
577 /*****************************************************************
578 *# 5. Debugging and development tools
579 *****************************************************************/
580
581
582 /* Function: esl_mixdchlet_Validate()
583 * Synopsis: Validate a mixture Dirichlet structure
584 * Incept: SRE, Sun 01 Jul 2018 [World Cup, Croatia v. Denmark]
585 *
586 * Purpose: Validate the internals of an <ESL_MIXDCHLET>. If good, return <eslOK>.
587 * If bad, return <eslFAIL>, and (if optional <errmsg> is provided by
588 * caller) put an informative error message in <errmsg>.
589 *
590 * Args: dchl - ESL_MIXDCHLET to validate
591 * errmsg - OPTIONAL: error message buffer of at least <eslERRBUFSIZE>; or <NULL>
592 *
593 * Returns: <eslOK> on success, and <errmsg> (if provided) is set to
594 * an empty string.
595 *
596 * <eslFAIL> on failure, and <errmsg> (if provided) contains the reason
597 * for the failure.
598 */
599 int
esl_mixdchlet_Validate(const ESL_MIXDCHLET * dchl,char * errmsg)600 esl_mixdchlet_Validate(const ESL_MIXDCHLET *dchl, char *errmsg)
601 {
602 int k, a;
603 double sum;
604 double tol = 1e-6;
605 if (errmsg) *errmsg = 0;
606
607 if (dchl->Q < 1) ESL_FAIL(eslFAIL, errmsg, "mixture dirichlet component number Q is %d, not >= 1", dchl->Q);
608 if (dchl->K < 1) ESL_FAIL(eslFAIL, errmsg, "mixture dirichlet alphabet size K is %d, not >= 1", dchl->K);
609
610 for (k = 0; k < dchl->Q; k++)
611 {
612 if (! isfinite(dchl->q[k] ) ) ESL_FAIL(eslFAIL, errmsg, "mixture coefficient [%d] = %g, not finite", k, dchl->q[k]);
613 if ( dchl->q[k] < 0.0 || dchl->q[k] > 1.0) ESL_FAIL(eslFAIL, errmsg, "mixture coefficient [%d] = %g, not a probability >= 0 && <= 1", k, dchl->q[k]);
614 }
615 sum = esl_vec_DSum(dchl->q, dchl->Q);
616 if (esl_DCompare( sum, 1.0, tol) != eslOK)
617 ESL_FAIL(eslFAIL, errmsg, "mixture coefficients sum to %g, not 1", sum);
618
619 for (k = 0; k < dchl->Q; k++)
620 for (a = 0; a < dchl->K; a++)
621 {
622 if (! isfinite(dchl->alpha[k][a])) ESL_FAIL(eslFAIL, errmsg, "dirichlet parameter [%d][%d] = %g, not finite", k, a, dchl->alpha[k][a]);
623 if ( dchl->alpha[k][a] <= 0) ESL_FAIL(eslFAIL, errmsg, "dirichlet parameter [%d][%d] = %g, not >0", k, a, dchl->alpha[k][a]);
624 }
625 return eslOK;
626 }
627
628
629
630 /* Function: esl_mixdchlet_Compare()
631 * Synopsis: Compare two mixture Dirichlets for equality.
632 *
633 * Purpose: Compare mixture Dirichlet objects <d1> and <d2> for
634 * equality, independent of the exact order of the
635 * components. For real numbered values, equality is
636 * defined by <esl_DCompare()> with a fractional tolerance
637 * <tol>.
638 *
639 * Order-independent, because when we fit a mixture
640 * Dirichlet to data, the order of the components is
641 * arbitrary. A maximum bipartite matching algorithm is
642 * used to figure out the best matching order.
643 *
644 * Returns: <eslOK> on equality; <eslFAIL> otherwise.
645 *
646 * Throws: <eslEMEM> on allocation failure.
647 */
648 int
esl_mixdchlet_Compare(const ESL_MIXDCHLET * d1,const ESL_MIXDCHLET * d2,double tol)649 esl_mixdchlet_Compare(const ESL_MIXDCHLET *d1, const ESL_MIXDCHLET *d2, double tol)
650 {
651 int **A = NULL; // 2D matrix w/ edges ij, TRUE when d1[i] ~= d2[j]
652 int i,j;
653 int nmatch;
654 int status;
655
656 if (d1->Q != d2->Q) return eslFAIL;
657 if (d1->K != d2->K) return eslFAIL;
658
659 if ((A = esl_mat_ICreate(d1->Q, d2->Q)) == NULL) { status = eslEMEM; goto ERROR; }
660 esl_mat_ISet(A, d1->Q, d2->Q, FALSE);
661
662 for (i = 0; i < d1->Q; i++)
663 for (j = 0; j < d2->Q; j++)
664 if ( esl_DCompare (d1->q[i], d2->q[j], tol) == eslOK &&
665 esl_vec_DCompare(d1->alpha[i], d2->alpha[j], d1->K, tol) == eslOK)
666 A[i][j] = TRUE;
667
668 if ((status = esl_graph_MaxBipartiteMatch(A, d1->Q, d2->Q, NULL, &nmatch)) != eslOK) goto ERROR;
669
670 status = (nmatch == d1->Q) ? eslOK: eslFAIL;
671 /* fallthrough */
672 ERROR:
673 esl_mat_IDestroy(A);
674 return status;
675 }
676
677
678
679
680 /* Function: esl_mixdchlet_Dump()
681 *
682 * Purpose: Dump the mixture Dirichlet <d>.
683 */
684 int
esl_mixdchlet_Dump(FILE * fp,const ESL_MIXDCHLET * dchl)685 esl_mixdchlet_Dump(FILE *fp, const ESL_MIXDCHLET *dchl)
686 {
687 int k,a; /* counters over mixture components, residues */
688
689 fprintf(fp, "Mixture Dirichlet: Q=%d K=%d\n", dchl->Q, dchl->K);
690 for (k = 0; k < dchl->Q; k++)
691 {
692 fprintf(fp, "q[%d] %f\n", k, dchl->q[k]);
693 for (a = 0; a < dchl->K; a++)
694 fprintf(fp, "alpha[%d][%d] %f\n", k, a, dchl->alpha[k][a]);
695 }
696 return eslOK;
697 }
698
699
700 /*****************************************************************
701 * 6. Unit tests
702 *****************************************************************/
703 #ifdef eslMIXDCHLET_TESTDRIVE
704
705 /* utest_io
706 * Write a mixture out; read it back in; should be the same.
707 */
708 static void
utest_io(ESL_RANDOMNESS * rng)709 utest_io(ESL_RANDOMNESS *rng)
710 {
711 char msg[] = "esl_mixdchlet: io unit test failed";
712 int Q = 1 + esl_rnd_Roll(rng, 4);
713 int K = 1 + esl_rnd_Roll(rng, 4);
714 ESL_MIXDCHLET *d1 = esl_mixdchlet_Create(Q, K);
715 ESL_MIXDCHLET *d2 = NULL;
716 ESL_FILEPARSER *efp = NULL;
717 FILE *fp = NULL;
718 float tol = 1e-3;
719 char tmpfile[16] = "esltmpXXXXXX";
720 int k,a;
721
722 /* Create a random mixture Dirichlet */
723 if (esl_mixdchlet_Sample(rng, d1) != eslOK) esl_fatal(msg);
724
725 /* Truncate values to four digits after decimal;
726 * Write only saves that much
727 */
728 for (k = 0; k < d1->Q; k++)
729 {
730 d1->q[k] = ((int)(d1->q[k] * 1.e4)) / 1.e4;
731 for (a = 0; a < d1->K; a++)
732 d1->alpha[k][a] = ((int)(d1->alpha[k][a] * 1.e4)) / 1.e4;
733 }
734
735 /* Write it to a a named tmpfile. */
736 if (esl_tmpfile_named(tmpfile, &fp) != eslOK) esl_fatal(msg);
737 if (esl_mixdchlet_Write(fp, d1) != eslOK) esl_fatal(msg);
738 fclose(fp);
739
740 /* Read it back in */
741 if ((fp = fopen(tmpfile, "r")) == NULL) esl_fatal(msg);
742 if ((efp = esl_fileparser_Create(fp)) == NULL) esl_fatal(msg);
743 if (esl_mixdchlet_Read(efp, &d2) != eslOK) esl_fatal(msg);
744 esl_fileparser_Destroy(efp);
745 fclose(fp);
746
747 if (esl_mixdchlet_Compare(d1, d2, tol) != eslOK) esl_fatal(msg);
748
749 esl_mixdchlet_Destroy(d2);
750 esl_mixdchlet_Destroy(d1);
751 remove(tmpfile);
752 }
753
754 /* utest_fit
755 * Generate count data from a known mixture Dirichlet, fit a new one,
756 * and make sure they're similar.
757 *
758 * This test can fail stochastically. If <allow_badluck> is FALSE (the
759 * default), it will reseed <rng> to a predetermined seed that always
760 * works (here, 14). Because the <rng> state is changed, test driver
761 * should put any such utests last.
762 *
763 * This test is typically slow (~5-10sec). The special seed 14 is
764 * chosen to be an unusually fast one (~3s).
765 */
766 static void
utest_fit(ESL_RANDOMNESS * rng,int allow_badluck,int be_verbose)767 utest_fit(ESL_RANDOMNESS *rng, int allow_badluck, int be_verbose)
768 {
769 char msg[] = "esl_mixdchlet: utest_fit failed";
770 int K = 4; // alphabet size
771 int N = 10000; // number of count vectors to generate
772 int nct = 1000; // number of counts per vector
773 ESL_MIXDCHLET *d0 = esl_mixdchlet_Create(2, K); // true 2-component mixture Dirichlet (data generated from this)
774 ESL_MIXDCHLET *dchl = esl_mixdchlet_Create(2, K); // estimated 2-component mixture Dirichlet
775 double *p = malloc(sizeof(double) * K);
776 double **c = esl_mat_DCreate(N, K);
777 double nll0, nll;
778 int i,k,a;
779
780 /* Suppress bad luck by default by fixing the RNG seed */
781 if (! allow_badluck) esl_randomness_Init(rng, 14);
782
783 /* Create known 2-component mixture Dirichlet */
784 d0->q[0] = 0.7;
785 d0->q[1] = 0.3;
786 esl_vec_DSet(d0->alpha[0], d0->K, 1.0); // component 0 = uniform
787 esl_vec_DSet(d0->alpha[1], d0->K, 10.0); // component 1 = mode at 1/K
788
789 /* Sample <N> observed count vectors, given d0 */
790 nll0 = 0;
791 for (i = 0; i < N; i++)
792 {
793 esl_vec_DSet(c[i], d0->K, 0.);
794 k = esl_rnd_DChoose(rng, d0->q, d0->Q); // choose a mixture component
795 esl_dirichlet_DSample(rng, d0->alpha[k], d0->K, p); // sample a pvector
796 for (a = 0; a < nct; a++)
797 c[i][ esl_rnd_DChoose(rng, p, d0->K) ] += 1.0; // sample count vector
798 nll0 -= esl_mixdchlet_logp_c(d0, c[i]);
799 }
800
801 if ( esl_mixdchlet_Sample(rng, dchl) != eslOK) esl_fatal(msg);
802 if ( esl_mixdchlet_Fit(c, N, dchl, &nll) != eslOK) esl_fatal(msg);
803
804 if (be_verbose)
805 {
806 printf("True (nll=%10.4g):\n", nll0); esl_mixdchlet_Dump(stdout, d0);
807 printf("Inferred (nll=%10.4g):\n", nll); esl_mixdchlet_Dump(stdout, dchl);
808 }
809 if ( esl_mixdchlet_Compare(d0, dchl, 0.1) != eslOK) esl_fatal(msg);
810 if ( nll0 < nll ) esl_fatal(msg);
811
812 esl_mat_DDestroy(c);
813 esl_mixdchlet_Destroy(dchl);
814 esl_mixdchlet_Destroy(d0);
815 free(p);
816 }
817 #endif // eslMIXDCHLET_TESTDRIVE
818
819
820 /*****************************************************************
821 * 7. Test driver
822 *****************************************************************/
823 #ifdef eslMIXDCHLET_TESTDRIVE
824
825 #include "easel.h"
826 #include "esl_fileparser.h"
827 #include "esl_getopts.h"
828 #include "esl_random.h"
829 #include "esl_dirichlet.h"
830
831 static ESL_OPTIONS options[] = {
832 /* name type default env range toggles reqs incomp help docgroup*/
833 { "-h", eslARG_NONE, FALSE, NULL, NULL, NULL, NULL, NULL, "show brief help on version and usage", 0 },
834 { "-s", eslARG_INT, "0", NULL, NULL, NULL, NULL, NULL, "set random number seed to <n>", 0 },
835 { "-x", eslARG_NONE, FALSE, NULL, NULL, NULL, NULL, NULL, "allow bad luck (expected stochastic failures)", 0 },
836 { "-v", eslARG_NONE, FALSE, NULL, NULL, NULL, NULL, NULL, "be more verbose" , 0 },
837 { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
838 };
839 static char usage[] = "[-options]";
840 static char banner[] = "test driver for mixdchlet module";
841
842 int
main(int argc,char ** argv)843 main(int argc, char **argv)
844 {
845 ESL_GETOPTS *go = esl_getopts_CreateDefaultApp(options, 0, argc, argv, banner, usage);
846 ESL_RANDOMNESS *rng = esl_randomness_Create(esl_opt_GetInteger(go, "-s"));
847 int be_verbose = esl_opt_GetBoolean(go, "-v");
848 int allow_badluck = esl_opt_GetBoolean(go, "-x"); // if a utest can fail just by chance, let it, instead of suppressing
849
850 fprintf(stderr, "## %s\n", argv[0]);
851 fprintf(stderr, "# rng seed = %" PRIu32 "\n", esl_randomness_GetSeed(rng));
852
853 utest_io (rng);
854
855 // Tests that can fail stochastically go last, because they reset the RNG seed by default.
856 utest_fit(rng, allow_badluck, be_verbose);
857
858 fprintf(stderr, "# status = ok\n");
859
860 esl_randomness_Destroy(rng);
861 esl_getopts_Destroy(go);
862 return eslOK;
863 }
864 #endif /*eslMIXDCHLET_TESTDRIVE*/
865 /*--------------------- end, test driver ------------------------*/
866
867
868 /*****************************************************************
 * 8. Example.
870 *****************************************************************/
871 #ifdef eslMIXDCHLET_EXAMPLE
872
873 #include "easel.h"
874 #include "esl_fileparser.h"
875 #include "esl_getopts.h"
876
877 static ESL_OPTIONS options[] = {
878 /* name type default env range toggles reqs incomp help docgroup*/
879 { "-h", eslARG_NONE, FALSE, NULL, NULL, NULL, NULL, NULL, "show brief help on version and usage", 0 },
880 { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
881 };
882 static char usage[] = "[-options] <mixdchlet_file> <counts_file>";
883 static char banner[] = "example driver for mixdchlet module: log likelihood of count data";
884
885 int
main(int argc,char ** argv)886 main(int argc, char **argv)
887 {
888 ESL_GETOPTS *go = esl_getopts_CreateDefaultApp(options, 2, argc, argv, banner, usage);
889 char *dfile = esl_opt_GetArg(go, 1);
890 char *ctfile = esl_opt_GetArg(go, 2);
891 ESL_FILEPARSER *efp = NULL;
892 ESL_MIXDCHLET *dchl = NULL;
893 double *ct = NULL; // one countvector read from ctfile at a time
894 char *tok = NULL;
895 int toklen = 0;
896 int a;
897 double nll = 0;
898 int status;
899
900 /* Read mixture Dirichlet */
901 if ( esl_fileparser_Open(dfile, NULL, &efp) != eslOK) esl_fatal("failed to open %s for reading", dfile);
902 esl_fileparser_SetCommentChar(efp, '#');
903 if ( esl_mixdchlet_Read(efp, &dchl) != eslOK) esl_fatal("failed to parse %s\n %s", dfile, efp->errbuf);
904 esl_fileparser_Close(efp);
905 efp = NULL;
906
907 /* Read count vectors one at a time, increment nll */
908 if ( esl_fileparser_Open(ctfile, NULL, &efp) != eslOK) esl_fatal("failed to open %s for reading", ctfile);
909 esl_fileparser_SetCommentChar(efp, '#');
910 ct = malloc(sizeof(double) * dchl->K);
911 while ((status = esl_fileparser_NextLine(efp)) == eslOK)
912 {
913 a = 0; // counter over fields on line, ct[a=0..K-1].
914 while ((status = esl_fileparser_GetTokenOnLine(efp, &tok, &toklen)) == eslOK)
915 {
916 if (a == dchl->K) esl_fatal("parse failed, %s:%d: > K=%d fields on line", ctfile, efp->linenumber, dchl->K);
917 if (! esl_str_IsReal(tok)) esl_fatal("parse failed, %s:%d: field %d (%s) not a real number", ctfile, efp->linenumber, a+1, tok);
918 ct[a++] = atof(tok);
919 }
920
921 nll += esl_mixdchlet_logp_c(dchl, ct);
922 }
923 esl_fileparser_Close(efp);
924
925 printf("nll = %g\n", -nll);
926
927 free(ct);
928 esl_mixdchlet_Destroy(dchl);
929 esl_getopts_Destroy(go);
930 }
931 #endif /*eslMIXDCHLET_EXAMPLE*/
932
933
934
935
936