1 /*  vcfsom.c -- SOM (Self-Organizing Map) filtering.
2 
3     Copyright (C) 2013-2014, 2020 Genome Research Ltd.
4 
5     Author: Petr Danecek <pd3@sanger.ac.uk>
6 
7 Permission is hereby granted, free of charge, to any person obtaining a copy
8 of this software and associated documentation files (the "Software"), to deal
9 in the Software without restriction, including without limitation the rights
10 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 copies of the Software, and to permit persons to whom the Software is
12 furnished to do so, subject to the following conditions:
13 
14 The above copyright notice and this permission notice shall be included in
15 all copies or substantial portions of the Software.
16 
17 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 THE SOFTWARE.  */
24 
25 #include <stdio.h>
26 #include <unistd.h>
27 #include <getopt.h>
28 #include <assert.h>
29 #include <ctype.h>
30 #include <string.h>
31 #include <errno.h>
32 #include <sys/stat.h>
33 #include <sys/types.h>
34 #include <math.h>
35 #include <time.h>
36 #include <htslib/vcf.h>
37 #include <htslib/synced_bcf_reader.h>
38 #include <htslib/vcfutils.h>
39 #include <htslib/hts_os.h>
40 #include <inttypes.h>
41 #include "bcftools.h"
42 
43 #define SOM_TRAIN    1
44 #define SOM_CLASSIFY 2
45 
typedef struct
{
    int ndim;       // dimension of the map (2D, 3D, ...)
    int nbin;       // number of bins in the map (per dimension)
    int size;       // total number of map nodes, pow(nbin,ndim)
    int kdim;       // dimension of the input vectors
    int nt, t;      // total number of learning cycles and the current cycle
    double *w, *c;  // weights (size*kdim values) and counts (size values, sum of learning influence)
    double learn;   // learning rate
    double bmu_th;  // best-matching unit threshold
    int *a_idx, *b_idx; // temp arrays for traversing variable number of nested loops
    double *div;        // dtto: per-dimension divisors for flat-index <-> coordinates conversion
}
som_t;
60 
typedef struct
{
    // SOM parameters
    double bmu_th, learn;       // best-matching unit threshold and learning rate
    int ndim, nbin, ntrain, t;  // map dimension, bins per dimension, training cycles, current cycle
    int nfold;                  // n-fold cross validation = the number of SOMs
    som_t **som;                // the nfold maps

    // annots reader's data
    htsFile *file;              // reader
    kstring_t str;              // temporary string for the reader
    int dclass, mvals;          // class of the current line; number of data fields per line
    double *vals;               // data fields of the current line (mvals of them)

    // training data
    double *train_dat;          // all training vectors, ntrain*mvals values
    int *train_class, mtrain_class, mtrain_dat;     // per-vector class+SOM index; allocated sizes (hts_expand)

    int rand_seed, good_class, bad_class;
    char **argv, *fname, *prefix;
    int argc, action, train_bad, merge;     // action: SOM_TRAIN or SOM_CLASSIFY; merge: MERGE_MIN/MAX/AVG
}
args_t;
84 
85 static void usage(void);
86 FILE *open_file(char **fname, const char *mode, const char *fmt, ...);
87 void mkdir_p(const char *fmt, ...);
88 
/*
 *  msprintf() - sprintf into a newly allocated buffer.
 *  Returns a malloc'd string which the caller must free, or NULL if
 *  formatting or allocation failed (the original invoked vsnprintf on
 *  an unchecked malloc result, which is undefined behavior on OOM).
 */
char *msprintf(const char *fmt, ...)
{
    va_list ap;
    va_start(ap, fmt);
    int n = vsnprintf(NULL, 0, fmt, ap);    // required length, excluding the NUL
    va_end(ap);
    if ( n < 0 ) return NULL;               // encoding error in the format string

    char *str = (char*)malloc(n + 1);
    if ( !str ) return NULL;

    va_start(ap, fmt);
    vsnprintf(str, n + 1, fmt, ap);
    va_end(ap);

    return str;
}
103 
104 /*
105  *  char *t, *p = str;
106  *  t = column_next(p, '\t');
107  *  if ( strlen("<something>")==t-p && !strncmp(p,"<something>",t-p) ) printf("found!\n");
108  *
109  *  char *t;
110  *  t = column_next(str, '\t'); if ( !*t ) error("expected field\n", str);
111  *  t = column_next(t+1, '\t'); if ( !*t ) error("expected field\n", str);
112  */
// Advance to the next occurrence of delim; returns a pointer to the
// delimiter itself, or to the terminating NUL if delim was not found.
static inline char *column_next(char *start, char delim)
{
    char *p;
    for (p = start; *p && *p != delim; p++) ;
    return p;
}
119 /**
120  *  annots_reader_next() - reads next line from annots.tab.gz and sets: class, vals
121  *   Returns 1 on successful read or 0 if no further record could be read.
122  */
annots_reader_next(args_t * args)123 int annots_reader_next(args_t *args)
124 {
125     args->str.l = 0;
126     if ( hts_getline(args->file,'\n',&args->str)<=0 ) return 0;
127 
128     char *t, *line = args->str.s;
129 
130     if ( !args->mvals )
131     {
132         t = line;
133         while ( *t )
134         {
135             if ( *t=='\t' ) args->mvals++;
136             t++;
137         }
138         args->vals = (double*) malloc(args->mvals*sizeof(double));
139     }
140 
141     // class
142     args->dclass = atoi(line);
143     t = column_next(line, '\t');
144 
145     // values
146     int i;
147     for (i=0; i<args->mvals; i++)
148     {
149         if ( !*t ) error("Could not parse %d-th data field: is the line truncated?\nThe line was: [%s]\n",i+2,line);
150         args->vals[i] = atof(++t);
151         t = column_next(t,'\t');
152     }
153     return 1;
154 }
// (Re)open the annotation file from the beginning; any handle left open by a
// previous pass is closed first. Exits via error() if no file name was set.
void annots_reader_reset(args_t *args)
{
    if ( args->file )
        hts_close(args->file);
    if ( !args->fname )
        error("annots_reader_reset: no fname\n");
    args->file = hts_open(args->fname, "r");
}
// Close the reader opened by annots_reader_reset().
void annots_reader_close(args_t *args)
{
    hts_close(args->file);
}
165 
// Serialize all nsom maps into a single binary <prefix>.som file.
// Layout: "SOMv1", nsom, then per map: size, kdim, weights, counts.
static void som_write_map(char *prefix, som_t **som, int nsom)
{
    FILE *fp = open_file(NULL,"w","%s.som",prefix);
    // fwrite() returns the number of ITEMS written, not bytes; the original
    // compared against byte counts (5, sizeof(int), ...) and therefore
    // reported a failure on every single write.
    if ( fwrite("SOMv1",5,1,fp)!=1 ) error("Failed to write 5 bytes\n");
    if ( fwrite(&nsom,sizeof(int),1,fp)!=1 ) error("Failed to write %zu bytes\n",sizeof(int));
    int i;
    for (i=0; i<nsom; i++)
    {
        if ( fwrite(&som[i]->size,sizeof(int),1,fp)!=1 ) error("Failed to write %zu bytes\n",sizeof(int));
        if ( fwrite(&som[i]->kdim,sizeof(int),1,fp)!=1 ) error("Failed to write %zu bytes\n",sizeof(int));
        if ( fwrite(som[i]->w,sizeof(double),som[i]->size*som[i]->kdim,fp)!=(size_t)som[i]->size*som[i]->kdim ) error("Failed to write %zu bytes\n",sizeof(double)*som[i]->size*som[i]->kdim);
        if ( fwrite(som[i]->c,sizeof(double),som[i]->size,fp)!=(size_t)som[i]->size ) error("Failed to write %zu bytes\n",sizeof(double)*som[i]->size);
    }
    if ( fclose(fp) ) error("%s.som: fclose failed\n",prefix);
}
som_load_map(char * prefix,int * nsom)182 static som_t** som_load_map(char *prefix, int *nsom)
183 {
184     FILE *fp = open_file(NULL,"r","%s.som",prefix);
185     char buf[5];
186     if ( fread(buf,5,1,fp)!=1 || strncmp(buf,"SOMv1",5) ) error("Could not parse %s.som\n", prefix);
187 
188     if ( fread(nsom,sizeof(int),1,fp)!=1 ) error("Could not read %s.som\n", prefix);
189     som_t **som = (som_t**)malloc(*nsom*sizeof(som_t*));
190 
191     int i;
192     for (i=0; i<*nsom; i++)
193     {
194         som[i] = (som_t*) calloc(1,sizeof(som_t));
195         if ( fread(&som[i]->size,sizeof(int),1,fp) != 1 ) error("Could not read %s.som\n", prefix);
196         if ( fread(&som[i]->kdim,sizeof(int),1,fp) != 1 ) error("Could not read %s.som\n", prefix);
197         som[i]->w = (double*) malloc(sizeof(double)*som[i]->size*som[i]->kdim);
198         som[i]->c = (double*) malloc(sizeof(double)*som[i]->size);
199         if ( fread(som[i]->w,sizeof(double),som[i]->size*som[i]->kdim,fp) != som[i]->size*som[i]->kdim ) error("Could not read from %s.som\n", prefix);
200         if ( fread(som[i]->c,sizeof(double),som[i]->size,fp) != som[i]->size ) error("Could not read from %s.som\n", prefix);
201     }
202     if ( fclose(fp) ) error("%s.som: fclose failed\n",prefix);
203     return som;
204 }
// Write a matplotlib script <prefix>.py visualising the counts of a 2D map
// as a heatmap; a no-op for maps of any other dimension.
static void som_create_plot(som_t *som, char *prefix)
{
    if ( som->ndim!=2 ) return;

    char *fname;
    FILE *fp = open_file(&fname,"w","%s.py",prefix);
    fprintf(fp,
            "import matplotlib as mpl\n"
            "mpl.use('Agg')\n"
            "import matplotlib.pyplot as plt\n"
            "\n"
            "dat = [\n"
           );
    // dump the counts as an nbin x nbin python matrix
    int row, col;
    const double *cnt = som->c;
    for (row=0; row<som->nbin; row++)
    {
        fprintf(fp,"[");
        for (col=0; col<som->nbin; col++, cnt++)
            fprintf(fp, col ? ",%e" : "%e", *cnt);
        fprintf(fp,"],\n");
    }
    fprintf(fp,
            "]\n"
            "fig = plt.figure()\n"
            "ax1 = plt.subplot(111)\n"
            "im1 = ax1.imshow(dat)\n"
            "fig.colorbar(im1)\n"
            "plt.savefig('%s.png')\n"
            "plt.close()\n"
            "\n", prefix
           );
    fclose(fp);
    free(fname);
}
244 // Find the best matching unit: the node with minimum distance from the input vector
som_find_bmu(som_t * som,double * vec,double * dist)245 static inline int som_find_bmu(som_t *som, double *vec, double *dist)
246 {
247     double *ptr = som->w;
248     double min_dist = HUGE_VAL;
249     int min_idx = 0;
250 
251     int i, k;
252     for (i=0; i<som->size; i++)
253     {
254         double dist = 0;
255         for (k=0; k<som->kdim; k++)
256             dist += (vec[k] - ptr[k]) * (vec[k] - ptr[k]);
257         if ( dist < min_dist )
258         {
259             min_dist = dist;
260             min_idx  = i;
261         }
262         ptr += som->kdim;
263     }
264 
265     if ( dist ) *dist = min_dist;
266     return min_idx;
267 }
// Distance from vec to the nearest sufficiently trained node, i.e. a node
// whose normalized count is at least bmu_th. Returns sqrt(HUGE_VAL) when no
// node passes the threshold.
static inline double som_get_score(som_t *som, double *vec, double bmu_th)
{
    double best = HUGE_VAL;
    int node, k;
    for (node=0; node<som->size; node++)
    {
        double *w = som->w + node*som->kdim;
        if ( som->c[node] < bmu_th ) continue;
        double d2 = 0;
        for (k=0; k<som->kdim; k++)
        {
            double diff = vec[k] - w[k];
            d2 += diff*diff;
        }
        if ( d2 < best ) best = d2;
    }
    return sqrt(best);
}
287 // Convert flat index to that of a k-dimensional cube
som_idx_to_ndim(som_t * som,int idx,int * ndim)288 static inline void som_idx_to_ndim(som_t *som, int idx, int *ndim)
289 {
290     int i;
291     double sub = 0;
292 
293     ndim[0] = idx/som->div[0];
294     for (i=1; i<som->ndim; i++)
295     {
296         sub += ndim[i-1] * som->div[i-1];
297         ndim[i] = (idx - sub)/som->div[i];
298     }
299 }
// Present one training vector to the map, pulling all nodes within the
// (shrinking) neighbourhood of the best-matching unit towards it.
// With update_counts=0 the site shapes the weights but does not mark
// nodes as trained (used for bad sites).
static void som_train_site(som_t *som, double *vec, int update_counts)
{
    // update learning rate and learning radius
    som->t++;
    // Cast to double: with the original integer division -t/nt truncates to 0
    // for all t < nt, so the decay factor stayed exactly 1 for the whole
    // nominal training period instead of decaying smoothly.
    double dt = exp(-som->t/(double)som->nt);
    double learning_rate = som->learn * dt;
    double radius = som->nbin * dt; radius *= radius;   // squared neighbourhood radius

    // find the best matching unit and its indexes
    int min_idx = som_find_bmu(som, vec, NULL);
    som_idx_to_ndim(som, min_idx, som->a_idx);

    // update the weights: traverse the map and make all nodes within the
    // radius more similar to the input vector
    double *ptr = som->w;
    double *cnt = som->c;
    int i, j, k;
    for (i=0; i<som->size; i++)
    {
        som_idx_to_ndim(som, i, som->b_idx);
        double dist = 0;    // squared grid distance from the BMU
        for (j=0; j<som->ndim; j++)
            dist += (som->a_idx[j] - som->b_idx[j]) * (som->a_idx[j] - som->b_idx[j]);
        if ( dist <= radius )
        {
            double influence = exp(-dist*dist*0.5/radius) * learning_rate;
            for (k=0; k<som->kdim; k++)
                ptr[k] += influence * (vec[k] - ptr[k]);

            // Bad sites may help to shape the map, but only nodes with big enough
            // influence will be used for classification.
            if ( update_counts ) *cnt += influence;
        }
        ptr += som->kdim;
        cnt++;
    }
}
// Normalize the influence counts so that the highest one equals 1.
static void som_norm_counts(som_t *som)
{
    int i;
    double max = 0;
    for (i=0; i<som->size; i++)
        if ( max < som->c[i] ) max = som->c[i];
    // an untrained map has all counts zero; dividing would fill it with NaNs
    if ( !max ) return;
    for (i=0; i<som->size; i++)
        som->c[i] /= max;
}
som_init(args_t * args)346 static som_t *som_init(args_t *args)
347 {
348     som_t *som  = (som_t*) calloc(1,sizeof(som_t));
349     som->ndim   = args->ndim;
350     som->nbin   = args->nbin;
351     som->kdim   = args->mvals;
352     som->nt     = args->ntrain;
353     som->learn  = args->learn;
354     som->bmu_th = args->bmu_th;
355     som->size   = pow(som->nbin,som->ndim);
356     som->w = (double*) malloc(sizeof(double)*som->size*som->kdim);
357     if ( !som->w ) error("Could not alloc %"PRIu64" bytes [nbin=%d ndim=%d kdim=%d]\n", (uint64_t)(sizeof(double)*som->size*som->kdim),som->nbin,som->ndim,som->kdim);
358     som->c = (double*) calloc(som->size,sizeof(double));
359     if ( !som->w ) error("Could not alloc %"PRIu64" bytes [nbin=%d ndim=%d]\n", (uint64_t)(sizeof(double)*som->size),som->nbin,som->ndim);
360     int i;
361     for (i=0; i<som->size*som->kdim; i++)
362         som->w[i] = random();
363     som->a_idx = (int*) malloc(sizeof(int)*som->ndim);
364     som->b_idx = (int*) malloc(sizeof(int)*som->ndim);
365     som->div   = (double*) malloc(sizeof(double)*som->ndim);
366     for (i=0; i<som->ndim; i++)
367         som->div[i] = pow(som->nbin,som->ndim-i-1);
368     return som;
369 }
// Release all memory owned by the map, including the struct itself.
static void som_destroy(som_t *som)
{
    free(som->w);
    free(som->c);
    free(som->a_idx);
    free(som->b_idx);
    free(som->div);
    free(som);
}
376 
// Prime the reader (the first line teaches it the vector size mvals) and,
// when classifying, load the previously trained maps.
static void init_data(args_t *args)
{
    annots_reader_reset(args);
    annots_reader_next(args);

    if ( args->action==SOM_CLASSIFY )
        args->som = som_load_map(args->prefix,&args->nfold);
}
// Free everything owned by args (the struct itself is freed by the caller).
static void destroy_data(args_t *args)
{
    if ( args->som )
    {
        int i;
        for (i=0; i<args->nfold; i++)
            som_destroy(args->som[i]);
    }
    free(args->som);
    free(args->train_dat);
    free(args->train_class);
    free(args->vals);
    free(args->str.s);
}
399 
400 #define MERGE_MIN 0
401 #define MERGE_MAX 1
402 #define MERGE_AVG 2
// Smallest score over all SOMs, skipping the iskip-th one (-1 to use all).
static double get_min_score(args_t *args, int iskip)
{
    int isom;
    double best = HUGE_VAL;
    for (isom=0; isom<args->nfold; isom++)
    {
        if ( isom==iskip ) continue;
        double val = som_get_score(args->som[isom], args->vals, args->bmu_th);
        if ( isom==0 || val < best ) best = val;
    }
    return best;
}
// Largest score over all SOMs, skipping the iskip-th one (-1 to use all).
static double get_max_score(args_t *args, int iskip)
{
    int isom;
    double best = -HUGE_VAL;
    for (isom=0; isom<args->nfold; isom++)
    {
        if ( isom==iskip ) continue;
        double val = som_get_score(args->som[isom], args->vals, args->bmu_th);
        if ( isom==0 || best < val ) best = val;
    }
    return best;
}
// Average score over all SOMs, skipping the iskip-th one (-1 to use all).
static double get_avg_score(args_t *args, int iskip)
{
    int isom, nused = 0;
    double sum = 0;
    for (isom=0; isom<args->nfold; isom++)
    {
        if ( isom==iskip ) continue;
        sum += som_get_score(args->som[isom], args->vals, args->bmu_th);
        nused++;
    }
    return sum/nused;
}
// qsort() comparator: orders floats in descending order.
static int cmpfloat_desc(const void *a, const void *b)
{
    float x = *(const float*)a;
    float y = *(const float*)b;
    return x < y ? 1 : (x > y ? -1 : 0);
}
447 
// Write a python script <prefix>.eval.py which plots the cumulative good/bad
// site counts produced in the .eval file.
static void create_eval_plot(args_t *args)
{
    FILE *fp = open_file(NULL,"w","%s.eval.py", args->prefix);
    // fixed text and the two prefix-dependent lines are emitted separately
    fputs("import matplotlib as mpl\n"
          "mpl.use('Agg')\n"
          "import matplotlib.pyplot as plt\n"
          "\n"
          "import csv\n"
          "csv.register_dialect('tab', delimiter='\\t', quoting=csv.QUOTE_NONE)\n"
          "dat = []\n", fp);
    fprintf(fp, "with open('%s.eval', 'r') as f:\n", args->prefix);
    fputs("\treader = csv.reader(f, 'tab')\n"
          "\tfor row in reader:\n"
          "\t\tif row[0][0]!='#': dat.append(row)\n"
          "\n"
          "fig = plt.figure()\n"
          "ax1 = plt.subplot(111)\n"
          "ax1.plot([x[0] for x in dat],[x[1] for x in dat],'g',label='Good')\n"
          "ax1.plot([x[0] for x in dat],[x[2] for x in dat],'r',label='Bad')\n"
          "ax1.set_xlabel('SOM score')\n"
          "ax1.set_ylabel('Number of training sites')\n"
          "ax1.legend(loc='best',prop={'size':8},frameon=False)\n", fp);
    fprintf(fp, "plt.savefig('%s.eval.png')\n", args->prefix);
    fputs("plt.close()\n"
          "\n", fp);
    fclose(fp);
}
477 
// Read all training sites, train nfold SOMs in a round-robin fashion, then
// evaluate every site against the SOMs it was NOT trained on and print the
// cutoff at which >90% of the good sites are accepted. With a prefix set,
// also writes the .eval table, plot scripts and the serialized maps.
static void do_train(args_t *args)
{
    // read training sites
    int i, igood = 0, ibad = 0, ngood = 0, nbad = 0, ntrain = 0;
    annots_reader_reset(args);
    while ( annots_reader_next(args) )
    {
        // determine which of the nfold's SOMs to train
        int isom = 0;
        if ( args->dclass == args->good_class )
        {
            if ( ++igood >= args->nfold ) igood = 0;    // round-robin over the maps
            isom = igood;
            ngood++;
        }
        else if ( args->dclass == args->bad_class )
        {
            if ( ++ibad >= args->nfold ) ibad = 0;
            isom = ibad;
            nbad++;
        }
        else
            error("Could not determine the class: %d (vs %d and %d)\n", args->dclass,args->good_class,args->bad_class);

        // save the values for evaluation
        ntrain++;
        hts_expand(double, ntrain*args->mvals, args->mtrain_dat, args->train_dat);
        hts_expand(int, ntrain, args->mtrain_class, args->train_class);
        memcpy(args->train_dat+(ntrain-1)*args->mvals, args->vals, args->mvals*sizeof(double));
        args->train_class[ntrain-1] = (args->dclass==args->good_class ? 1 : 0) | isom<<1;  // store class + chunk used for training
    }
    annots_reader_close(args);

    // init maps; by default the nominal training length is the number of good sites per map
    if ( !args->ntrain ) args->ntrain = ngood/args->nfold;
    srandom(args->rand_seed);
    args->som = (som_t**) malloc(sizeof(som_t*)*args->nfold);
    for (i=0; i<args->nfold; i++) args->som[i] = som_init(args);

    // train; bad sites shape the weights only (counts updated for good sites)
    for (i=0; i<ntrain; i++)
    {
        int is_good = args->train_class[i] & 1;     // bit 0 .. good/bad class
        int isom    = args->train_class[i] >> 1;    // remaining bits .. index of the SOM this site trained
        if ( is_good || args->train_bad )
            som_train_site(args->som[isom], args->train_dat+i*args->mvals, is_good);
    }

    // norm and create plots
    for (i=0; i<args->nfold; i++)
    {
        som_norm_counts(args->som[i]);
        if ( args->prefix )
        {
            char *bname = msprintf("%s.som.%d", args->prefix,i);
            som_create_plot(args->som[i], bname);
            free(bname);
        }
    }

    // evaluate
    // NOTE(review): assumes at least one good and one bad site were read;
    // otherwise good[0]/bad[0] below read past a zero-sized allocation and
    // the percentages divide by zero - confirm the input guarantees this.
    float *good = (float*) malloc(sizeof(float)*ngood); assert(good);
    float *bad  = (float*) malloc(sizeof(float)*nbad); assert(bad);
    igood = ibad = 0;
    double max_score = sqrt(args->som[0]->kdim);    // presumably the maximum possible distance, assumes annotations in [0,1] - TODO confirm
    for (i=0; i<ntrain; i++)
    {
        double score = 0;
        int is_good = args->train_class[i] & 1;
        int isom    = args->train_class[i] >> 1;    // this vector was used for training isom-th SOM, skip
        if ( args->nfold==1 ) isom = -1;            // a single map: nothing to hold out
        memcpy(args->vals, args->train_dat+i*args->mvals, args->mvals*sizeof(double));
        switch (args->merge)
        {
            case MERGE_MIN: score = get_min_score(args, isom); break;
            case MERGE_MAX: score = get_max_score(args, isom); break;
            case MERGE_AVG: score = get_avg_score(args, isom); break;
        }
        score = 1.0 - score/max_score;      // rescale so that 1 is best
        if ( is_good )
            good[igood++] = score;
        else
            bad[ibad++] = score;
    }
    // sweep the score (descending) as a cutoff, counting accepted sites
    qsort(good, ngood, sizeof(float), cmpfloat_desc);
    qsort(bad, nbad, sizeof(float), cmpfloat_desc);
    FILE *fp = NULL;
    if ( args->prefix ) fp = open_file(NULL,"w","%s.eval", args->prefix);
    igood = 0;
    ibad  = 0;
    float prev_score = good[0]>bad[0] ? good[0] : bad[0];
    int printed = 0;
    while ( igood<ngood || ibad<nbad )
    {
        // consume all sites sharing the current score before emitting a line
        if ( igood<ngood && good[igood]==prev_score ) { igood++; continue; }
        if ( ibad<nbad && bad[ibad]==prev_score ) { ibad++; continue; }
        if ( fp )
            fprintf(fp,"%e\t%f\t%f\n", prev_score, (float)igood/ngood, (float)ibad/nbad);
        // report the first cutoff which accepts more than 90% of the good sites
        if ( !printed && (float)igood/ngood > 0.9 )
        {
            printf("%.2f\t%.2f\t%e\t# %% of bad [1] and good [2] sites at a cutoff [3]\n", 100.*ibad/nbad,100.*igood/ngood,prev_score);
            printed = 1;
        }

        // advance to the next (smaller) distinct score
        if ( igood<ngood && ibad<nbad ) prev_score = good[igood]>bad[ibad] ? good[igood] : bad[ibad];
        else if ( igood<ngood ) prev_score = good[igood];
        else prev_score = bad[ibad];
    }
    if ( !printed ) printf("%.2f\t%.2f\t%e\t# %% of bad [1] and good [2] sites at a cutoff [3]\n", 100.*ibad/nbad,100.*igood/ngood,prev_score);
    if ( fp )
    {
        if ( fclose(fp) ) error("%s.eval: fclose failed: %s\n",args->prefix,strerror(errno));
        create_eval_plot(args);
        som_write_map(args->prefix, args->som, args->nfold);
    }

    free(good);
    free(bad);
}
597 
// Stream the annotation file and print one score per line, rescaled so that
// 1 is best and 0 worst.
static void do_classify(args_t *args)
{
    annots_reader_reset(args);
    // NOTE(review): sqrt(kdim) as the maximum distance assumes the annotation
    // values are normalized to [0,1] - confirm with the preparation step.
    double norm = sqrt(args->som[0]->kdim);
    while ( annots_reader_next(args) )
    {
        double dist = 0;
        if ( args->merge==MERGE_MIN ) dist = get_min_score(args, -1);
        else if ( args->merge==MERGE_MAX ) dist = get_max_score(args, -1);
        else if ( args->merge==MERGE_AVG ) dist = get_avg_score(args, -1);
        printf("%e\n", 1.0 - dist/norm);
    }
    annots_reader_close(args);
}
615 
// Print the usage text to stderr and exit with a non-zero status.
// (None of the lines contain conversion specifiers, so a single fputs
// produces byte-identical output to the original fprintf calls.)
static void usage(void)
{
    fputs(
        "\n"
        "About:   SOM (Self-Organizing Map) filtering.\n"
        "Usage:   bcftools som --train    [options] <annots.tab.gz>\n"
        "         bcftools som --classify [options]\n"
        "\n"
        "Model training options:\n"
        "    -f, --nfold <int>                  n-fold cross-validation (number of maps) [5]\n"
        "    -p, --prefix <string>              prefix of output files\n"
        "    -s, --size <int>                   map size [20]\n"
        "    -t, --train                        \n"
        "\n"
        "Classifying options:\n"
        "    -c, --classify                     \n"
        "\n"
        "Experimental training options (no reason to change):\n"
        "    -b, --bmu-threshold <float>        threshold for selection of best-matching unit [0.9]\n"
        "    -d, --som-dimension <int>          SOM dimension [2]\n"
        "    -e, --exclude-bad                  exclude bad sites from training, use for evaluation only\n"
        "    -l, --learning-rate <float>        learning rate [1.0]\n"
        "    -m, --merge <min|max|avg>          -f merge algorithm [avg]\n"
        "    -n, --ntrain-sites <int>           effective number of training sites [number of good sites]\n"
        "    -r, --random-seed <int>            random seed, 0 for time() [1]\n"
        "\n", stderr);
    exit(1);
}
643 
// Entry point of `bcftools som`: parse options, then run training or
// classification on the single positional annotation file argument.
int main_vcfsom(int argc, char *argv[])
{
    int c;
    args_t *args     = (args_t*) calloc(1,sizeof(args_t));
    args->argc       = argc; args->argv = argv;
    args->nbin       = 20;      // -s, map size
    args->learn      = 1.0;     // -l, learning rate
    args->bmu_th     = 0.9;     // -b, best-matching unit threshold
    args->nfold      = 5;       // -f, number of SOMs for cross-validation
    args->rand_seed  = 1;       // -r, 0 means seed from time()
    args->ndim       = 2;       // -d, map dimension
    args->bad_class  = 1;       // expected class labels in the first column
    args->good_class = 2;
    args->merge      = MERGE_AVG;   // -m, how to combine scores of the nfold maps
    args->train_bad  = 1;       // cleared by -e, whether bad sites shape the map

    static struct option loptions[] =
    {
        {"help",0,0,'h'},
        {"prefix",1,0,'p'},
        {"ntrain-sites",1,0,'n'},
        {"random-seed",1,0,'r'},
        {"bmu-threshold",1,0,'b'},
        {"exclude-bad",0,0,'e'},
        {"learning-rate",1,0,'l'},
        {"size",1,0,'s'},
        {"som-dimension",1,0,'d'},
        {"nfold",1,0,'f'},
        {"merge",1,0,'m'},
        {"train",0,0,'t'},
        {"classify",0,0,'c'},
        {0,0,0,0}
    };
    while ((c = getopt_long(argc, argv, "htcp:n:r:b:l:s:f:d:m:e",loptions,NULL)) >= 0) {
        switch (c) {
            case 'e': args->train_bad = 0; break;
            case 'm':
                if ( !strcmp(optarg,"min") ) args->merge = MERGE_MIN;
                else if ( !strcmp(optarg,"max") ) args->merge = MERGE_MAX;
                else if ( !strcmp(optarg,"avg") ) args->merge = MERGE_AVG;
                else error("The -m method not recognised: %s\n", optarg);
                break;
            case 'p': args->prefix = optarg; break;
            case 'n': args->ntrain = atoi(optarg); break;
            case 'r': args->rand_seed = atoi(optarg); break;
            case 'b': args->bmu_th = atof(optarg); break;
            case 'l': args->learn = atof(optarg); break;
            case 's': args->nbin = atoi(optarg); break;
            case 'f': args->nfold = atoi(optarg); break;
            case 'd':
                args->ndim = atoi(optarg);
                if ( args->ndim<2 ) error("Expected -d >=2, got %d\n", args->ndim);
                if ( args->ndim>3 ) fprintf(stderr,"Warning: This will take a long time and is not going to make the results better: -d %d\n", args->ndim);
                break;
            case 't': args->action = SOM_TRAIN; break;
            case 'c': args->action = SOM_CLASSIFY; break;
            case 'h':
            case '?': usage(); break;
            default: error("Unknown argument: %s\n", optarg);
        }
    }

    if ( !args->rand_seed ) args->rand_seed = time(NULL);
    // exactly one positional argument (the annotation file) is required for
    // both modes; classify also reads its records from this file
    if ( argc!=optind+1 ) usage();
    args->fname = argv[optind];
    init_data(args);

    // note: if neither -t nor -c was given, nothing is done
    if ( args->action == SOM_TRAIN ) do_train(args);
    else if ( args->action == SOM_CLASSIFY ) do_classify(args);

    destroy_data(args);
    free(args);
    return 0;
}
718 
719