1 /***************************************************************************
2  *   Copyright (C) 2009 by BUI Quang Minh   *
3  *   minh.bui@univie.ac.at   *
4  *                                                                         *
5  *   This program is free software; you can redistribute it and/or modify  *
6  *   it under the terms of the GNU General Public License as published by  *
7  *   the Free Software Foundation; either version 2 of the License, or     *
8  *   (at your option) any later version.                                   *
9  *                                                                         *
10  *   This program is distributed in the hope that it will be useful,       *
11  *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
12  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
13  *   GNU General Public License for more details.                          *
14  *                                                                         *
15  *   You should have received a copy of the GNU General Public License     *
16  *   along with this program; if not, write to the                         *
17  *   Free Software Foundation, Inc.,                                       *
18  *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
19  ***************************************************************************/
20 #include "rateinvar.h"
21 #include "modelfactory.h"
22 #include "rategamma.h"
23 #include "rategammainvar.h"
24 #include "modelmarkov.h"
25 #include "modelliemarkov.h"
26 #include "modeldna.h"
27 #include "modelprotein.h"
28 #include "modelbin.h"
29 #include "modelcodon.h"
30 #include "modelmorphology.h"
31 #include "modelpomo.h"
32 #include "modelset.h"
33 #include "modelmixture.h"
34 #include "ratemeyerhaeseler.h"
35 #include "ratemeyerdiscrete.h"
36 #include "ratekategory.h"
37 #include "ratefree.h"
38 #include "ratefreeinvar.h"
39 #include "rateheterotachy.h"
40 #include "rateheterotachyinvar.h"
41 //#include "ngs.h"
42 #include <string>
43 #include "utils/timeutil.h"
44 #include "nclextra/myreader.h"
45 #include <sstream>
46 
/**
 * Locate the first occurrence of a flag inside a model name that is not
 * immediately followed by a letter.
 *
 * An occurrence is accepted when it ends the string or when the character
 * right after it is not alphabetic; e.g. "+H" is accepted in "GTR+H4" but
 * rejected in "GTR+HKY". Rejected occurrences are skipped and the search
 * resumes one position further.
 *
 * @param name string to search in
 * @param sub  flag to search for
 * @return index of the first accepted occurrence, or std::string::npos
 */
static std::string::size_type findFlagNotFollowedByLetter(const std::string &name, const std::string &sub) {
    std::string::size_type pos = name.find(sub);
    while (pos != std::string::npos) {
        // index of the character right after the flag (the flag's own length
        // is used, so flags longer than two characters are handled as well)
        const std::string::size_type next = pos + sub.length();
        // isalpha() is only defined for unsigned char values and EOF,
        // hence the cast before the call
        if (next >= name.length() || !isalpha(static_cast<unsigned char>(name[next])))
            return pos;
        pos = name.find(sub, pos + 1);
    }
    return std::string::npos;
}

/**
 * Find the earliest position in `name` at which either `sub1` or `sub2`
 * occurs without being directly followed by an alphabetic character.
 *
 * @param name model string to search (not modified)
 * @param sub1 first flag, e.g. "+H"
 * @param sub2 second flag, e.g. "*H"
 * @return the smaller of the two accepted positions; the position of the one
 *         flag that was found if the other is absent; std::string::npos if
 *         neither flag is found
 */
std::string::size_type findSubStr(std::string &name, std::string sub1, std::string sub2) {
    const std::string::size_type pos1 = findFlagNotFollowedByLetter(name, sub1);
    const std::string::size_type pos2 = findFlagNotFollowedByLetter(name, sub2);
    // npos is the largest size_type value, so taking the minimum covers all
    // cases: both found, only one found, or none found.
    return (pos1 < pos2) ? pos1 : pos2;
}
73 
posRateHeterotachy(string model_name)74 string::size_type posRateHeterotachy(string model_name) {
75     return findSubStr(model_name, "+H", "*H");
76 }
77 
posRateFree(string & model_name)78 string::size_type posRateFree(string &model_name) {
79     return findSubStr(model_name, "+R", "*R");
80 }
81 
posPOMO(string & model_name)82 string::size_type posPOMO(string &model_name) {
83     return findSubStr(model_name, "+P", "*P");
84 }
85 
readModelsDefinition(Params & params)86 ModelsBlock *readModelsDefinition(Params &params) {
87 
88     ModelsBlock *models_block = new ModelsBlock;
89 
90     try
91     {
92         // loading internal model definitions
93         stringstream in(builtin_mixmodels_definition);
94         ASSERT(in && "stringstream is OK");
95         NxsReader nexus;
96         nexus.Add(models_block);
97         MyToken token(in);
98         nexus.Execute(token);
99     } catch (...) {
100         ASSERT(0 && "predefined mixture models not initialized");
101     }
102 
103     try
104     {
105         // loading internal protei model definitions
106         stringstream in(builtin_prot_models);
107         ASSERT(in && "stringstream is OK");
108         NxsReader nexus;
109         nexus.Add(models_block);
110         MyToken token(in);
111         nexus.Execute(token);
112     } catch (...) {
113         ASSERT(0 && "predefined protein models not initialized");
114     }
115 
116     if (params.model_def_file) {
117         cout << "Reading model definition file " << params.model_def_file << " ... ";
118         MyReader nexus(params.model_def_file);
119         nexus.Add(models_block);
120         MyToken token(nexus.inf);
121         nexus.Execute(token);
122         int num_model = 0, num_freq = 0;
123         for (ModelsBlock::iterator it = models_block->begin(); it != models_block->end(); it++)
124             if (it->second.flag & NM_FREQ) num_freq++; else num_model++;
125         cout << num_model << " models and " << num_freq << " frequency vectors loaded" << endl;
126     }
127     return models_block;
128 }
129 
ModelFactory()130 ModelFactory::ModelFactory() : CheckpointFactory() {
131     model = NULL;
132     site_rate = NULL;
133     store_trans_matrix = false;
134     is_storing = false;
135     joint_optimize = false;
136     fused_mix_rate = false;
137     ASC_type = ASC_NONE;
138 }
139 
/**
 * Find the '}' that closes the bracket opened at `start_pos`.
 *
 * Scanning starts at the character after `start_pos`. Each '{' met on the
 * way opens a nested level that has to be closed by its own '}' before a
 * closing bracket can be taken as the match.
 *
 * @param str       string to scan (not modified)
 * @param start_pos index of the opening bracket; the character at this index
 *                  itself is never examined
 * @return index of the matching '}', or string::npos if there is none
 */
size_t findCloseBracket(string &str, size_t start_pos) {
    const size_t len = str.length();
    int depth = 0; // number of nested '{' still waiting for their '}'
    size_t idx = start_pos + 1;
    while (idx < len) {
        const char ch = str[idx];
        if (ch == '{') {
            ++depth;
        } else if (ch == '}') {
            if (depth == 0)
                return idx;
            --depth;
        }
        ++idx;
    }
    return string::npos;
}
150 
ModelFactory(Params & params,string & model_name,PhyloTree * tree,ModelsBlock * models_block)151 ModelFactory::ModelFactory(Params &params, string &model_name, PhyloTree *tree, ModelsBlock *models_block) : CheckpointFactory() {
152     store_trans_matrix = params.store_trans_matrix;
153     is_storing = false;
154     joint_optimize = params.optimize_model_rate_joint;
155     fused_mix_rate = false;
156     ASC_type = ASC_NONE;
157     string model_str = model_name;
158     string rate_str;
159 
160     try {
161 
162 
163     if (model_str == "") {
164         if (tree->aln->seq_type == SEQ_DNA) model_str = "HKY";
165         else if (tree->aln->seq_type == SEQ_PROTEIN) model_str = "LG";
166         else if (tree->aln->seq_type == SEQ_BINARY) model_str = "GTR2";
167         else if (tree->aln->seq_type == SEQ_CODON) model_str = "GY";
168         else if (tree->aln->seq_type == SEQ_MORPH) model_str = "MK";
169         else if (tree->aln->seq_type == SEQ_POMO) model_str = "HKY+P";
170         else model_str = "JC";
171         if (tree->aln->seq_type != SEQ_POMO && !params.model_joint)
172             outWarning("Default model "+model_str + " may be under-fitting. Use option '-m TEST' to determine the best-fit model.");
173     }
174 
175     /********* preprocessing model string ****************/
176     NxsModel *nxsmodel  = NULL;
177 
178     string new_model_str = "";
179     size_t mix_pos;
180     for (mix_pos = 0; mix_pos < model_str.length(); mix_pos++) {
181         size_t next_mix_pos = model_str.find_first_of("+*", mix_pos);
182         string sub_model_str = model_str.substr(mix_pos, next_mix_pos-mix_pos);
183         nxsmodel = models_block->findMixModel(sub_model_str);
184         if (nxsmodel) sub_model_str = nxsmodel->description;
185         new_model_str += sub_model_str;
186         if (next_mix_pos != string::npos)
187             new_model_str += model_str[next_mix_pos];
188         else
189             break;
190         mix_pos = next_mix_pos;
191     }
192     if (new_model_str != model_str)
193         cout << "Model " << model_str << " is alias for " << new_model_str << endl;
194     model_str = new_model_str;
195 
196     //    nxsmodel = models_block->findModel(model_str);
197     //    if (nxsmodel && nxsmodel->description.find_first_of("+*") != string::npos) {
198     //        cout << "Model " << model_str << " is alias for " << nxsmodel->description << endl;
199     //        model_str = nxsmodel->description;
200     //    }
201 
202     // Detect PoMo and throw error if sequence type is PoMo but +P is
203     // not given.  This makes the model string cleaner and
204     // compareable.
205     string::size_type p_pos = posPOMO(model_str);
206     bool pomo = (p_pos != string::npos);
207 
208     if ((p_pos == string::npos) &&
209         (tree->aln->seq_type == SEQ_POMO))
210         outError("Provided alignment is exclusively used by PoMo but model string does not contain, e.g., \"+P\".");
211 
212     // Decompose model string into model_str and rate_str string.
213     size_t spec_pos = model_str.find_first_of("{+*");
214     if (spec_pos != string::npos) {
215         if (model_str[spec_pos] == '{') {
216             // Scan for the corresponding '}'.
217             size_t pos = findCloseBracket(model_str, spec_pos);
218             if (pos == string::npos)
219                 outError("Model name has wrong bracket notation '{...}'");
220             rate_str = model_str.substr(pos+1);
221             model_str = model_str.substr(0, pos+1);
222         } else {
223             rate_str = model_str.substr(spec_pos);
224             model_str = model_str.substr(0, spec_pos);
225         }
226     }
227 
228     // decompose +F from rate_str
229     string freq_str = "";
230     while ((spec_pos = rate_str.find("+F")) != string::npos) {
231         size_t end_pos = rate_str.find_first_of("+*", spec_pos+1);
232         if (end_pos == string::npos) {
233             freq_str += rate_str.substr(spec_pos);
234             rate_str = rate_str.substr(0, spec_pos);
235         } else {
236             freq_str += rate_str.substr(spec_pos, end_pos - spec_pos);
237             rate_str = rate_str.substr(0, spec_pos) + rate_str.substr(end_pos);
238         }
239     }
240 
241     // set to model_joint if set
242     if (Params::getInstance().model_joint) {
243         model_str = Params::getInstance().model_joint;
244         freq_str = "";
245         while ((spec_pos = model_str.find("+F")) != string::npos) {
246             size_t end_pos = model_str.find_first_of("+*", spec_pos+1);
247             if (end_pos == string::npos) {
248                 freq_str += model_str.substr(spec_pos);
249                 model_str = model_str.substr(0, spec_pos);
250             } else {
251                 freq_str += model_str.substr(spec_pos, end_pos - spec_pos);
252                 model_str = model_str.substr(0, spec_pos) + model_str.substr(end_pos);
253             }
254         }
255     }
256 
257     // PoMo; +NXX and +W or +S because those flags are handled when
258     // reading in the data.  Set PoMo parameters (heterozygosity).
259     size_t n_pos_start = rate_str.find("+N");
260     size_t n_pos_end   = rate_str.find_first_of("+", n_pos_start+1);
261     if (n_pos_start != string::npos) {
262         if (!pomo)
263             outError("Virtual population size can only be set with PoMo.");
264         if (n_pos_end != string::npos)
265             rate_str = rate_str.substr(0, n_pos_start)
266                 + rate_str.substr(n_pos_end);
267         else
268             rate_str = rate_str.substr(0, n_pos_start);
269     }
270 
271     size_t wb_pos = rate_str.find("+WB");
272     if (wb_pos != string::npos) {
273       if (!pomo)
274         outError("Weighted binomial sampling can only be used with PoMo.");
275       rate_str = rate_str.substr(0, wb_pos)
276         + rate_str.substr(wb_pos+3);
277     }
278     size_t wh_pos = rate_str.find("+WH");
279     if (wh_pos != string::npos) {
280         if (!pomo)
281             outError("Weighted hypergeometric sampling can only be used with PoMo.");
282         rate_str = rate_str.substr(0, wh_pos)
283             + rate_str.substr(wh_pos+3);
284     }
285     size_t s_pos = rate_str.find("+S");
286     if ( s_pos != string::npos) {
287         if (!pomo)
288             outError("Binomial sampling can only be used with PoMo.");
289         rate_str = rate_str.substr(0, s_pos)
290             + rate_str.substr(s_pos+2);
291     }
292 
293     // In case of PoMo, check that only supported flags are given.
294     if (pomo) {
295         if (rate_str.find("+ASC") != string::npos)
296             // TODO DS: This is an important feature, because then,
297             // PoMo can be applied to SNP data only.
298             outError("PoMo does not yet support ascertainment bias correction (+ASC).");
299         if (posRateFree(rate_str) != string::npos)
300             outError("PoMo does not yet support free rate models (+R).");
301         if (rate_str.find("+FMIX") != string::npos)
302             outError("PoMo does not yet support frequency mixture models (+FMIX).");
303         if (posRateHeterotachy(rate_str) != string::npos)
304             outError("PoMo does not yet support heterotachy models (+H).");
305     }
306 
307     // PoMo. The +P{}, +GXX and +I flags are interpreted during model creation.
308     // This is necessary for compatibility with mixture models. If there is no
309     // mixture model, move +P{}, +GXX and +I flags to model string. For mixture
310     // models, the heterozygosity can be set separately for each model and the
311     // +P{}, +GXX and +I flags should already be inside the model definition.
312     if (model_str.substr(0, 3) != "MIX" && pomo) {
313       // +P{} flag.
314       p_pos = posPOMO(rate_str);
315       if (p_pos != string::npos) {
316         if (rate_str[p_pos+2] == '{') {
317           string::size_type close_bracket = rate_str.find("}");
318           if (close_bracket == string::npos)
319             outError("No closing bracket in PoMo parameters.");
320           else {
321             string pomo_heterozygosity = rate_str.substr(p_pos+3,close_bracket-p_pos-3);
322             rate_str = rate_str.substr(0, p_pos) + rate_str.substr(close_bracket+1);
323             model_str += "+P{" + pomo_heterozygosity + "}";
324           }
325         }
326         else {
327           rate_str = rate_str.substr(0, p_pos) + rate_str.substr(p_pos + 2);
328           model_str += "+P";
329         }
330       }
331 
332       // +G flag.
333       size_t pomo_rate_start_pos;
334       if ((pomo_rate_start_pos = rate_str.find("+G")) != string::npos) {
335         string pomo_rate_str = "";
336         size_t pomo_rate_end_pos = rate_str.find_first_of("+*", pomo_rate_start_pos+1);
337         if (pomo_rate_end_pos == string::npos) {
338           pomo_rate_str = rate_str.substr(pomo_rate_start_pos, rate_str.length() - pomo_rate_start_pos);
339           rate_str = rate_str.substr(0, pomo_rate_start_pos);
340           model_str += pomo_rate_str;
341         } else {
342           pomo_rate_str = rate_str.substr(pomo_rate_start_pos, pomo_rate_end_pos - pomo_rate_start_pos);
343           rate_str = rate_str.substr(0, pomo_rate_start_pos) + rate_str.substr(pomo_rate_end_pos);
344           model_str += pomo_rate_str;
345         }
346       }
347 
348       // // +I flag.
349       // size_t pomo_irate_start_pos;
350       // if ((pomo_irate_start_pos = rate_str.find("+I")) != string::npos) {
351       //   string pomo_irate_str = "";
352       //   size_t pomo_irate_end_pos = rate_str.find_first_of("+*", pomo_irate_start_pos+1);
353       //   if (pomo_irate_end_pos == string::npos) {
354       //     pomo_irate_str = rate_str.substr(pomo_irate_start_pos, rate_str.length() - pomo_irate_start_pos);
355       //     rate_str = rate_str.substr(0, pomo_irate_start_pos);
356       //     model_str += pomo_irate_str;
357       //   } else {
358       //     pomo_irate_str = rate_str.substr(pomo_irate_start_pos, pomo_irate_end_pos - pomo_irate_start_pos);
359       //     rate_str = rate_str.substr(0, pomo_irate_start_pos) + rate_str.substr(pomo_irate_end_pos);
360       //     model_str += pomo_irate_str;
361       //   }
362     }
363 
364     //    nxsmodel = models_block->findModel(model_str);
365     //    if (nxsmodel && nxsmodel->description.find("MIX") != string::npos) {
366     //        cout << "Model " << model_str << " is alias for " << nxsmodel->description << endl;
367     //        model_str = nxsmodel->description;
368     //    }
369 
370     /******************** initialize state frequency ****************************/
371 
372     StateFreqType freq_type = params.freq_type;
373 
374     if (freq_type == FREQ_UNKNOWN) {
375         switch (tree->aln->seq_type) {
376         case SEQ_BINARY: freq_type = FREQ_ESTIMATE; break; // default for binary: optimized frequencies
377         case SEQ_PROTEIN: break; // let ModelProtein decide by itself
378         case SEQ_MORPH: freq_type = FREQ_EQUAL; break;
379         case SEQ_CODON: freq_type = FREQ_UNKNOWN; break;
380             break;
381         default: freq_type = FREQ_EMPIRICAL; break; // default for DNA, PoMo and others: counted frequencies from alignment
382         }
383     }
384 
385     // first handle mixture frequency
386     string::size_type posfreq = freq_str.find("+FMIX");
387     string freq_params;
388     size_t close_bracket;
389 
390     if (posfreq != string::npos) {
391         string fmix_str;
392         size_t last_pos = freq_str.find_first_of("+*", posfreq+1);
393 
394         if (last_pos == string::npos) {
395             fmix_str = freq_str.substr(posfreq);
396             freq_str = freq_str.substr(0, posfreq);
397         } else {
398             fmix_str = freq_str.substr(posfreq, last_pos-posfreq);
399             freq_str = freq_str.substr(0, posfreq) + freq_str.substr(last_pos);
400         }
401 
402         if (fmix_str[5] != OPEN_BRACKET)
403             outError("Mixture-frequency must start with +FMIX{");
404         close_bracket = fmix_str.find(CLOSE_BRACKET);
405         if (close_bracket == string::npos)
406             outError("Close bracket not found in ", fmix_str);
407         if (close_bracket != fmix_str.length()-1)
408             outError("Wrong close bracket position ", fmix_str);
409         freq_type = FREQ_MIXTURE;
410         freq_params = fmix_str.substr(6, close_bracket-6);
411     }
412 
413     // then normal frequency
414     if (freq_str.find("+FO") != string::npos)
415         posfreq = freq_str.find("+FO");
416     else if (freq_str.find("+Fo") != string::npos)
417         posfreq = freq_str.find("+Fo");
418     else
419         posfreq = freq_str.find("+F");
420 
421     bool optimize_mixmodel_weight = params.optimize_mixmodel_weight;
422 
423     if (posfreq != string::npos) {
424         string fstr;
425         size_t last_pos = freq_str.find_first_of("+*", posfreq+1);
426         if (last_pos == string::npos) {
427             fstr = freq_str.substr(posfreq);
428             freq_str = freq_str.substr(0, posfreq);
429         } else {
430             fstr = freq_str.substr(posfreq, last_pos-posfreq);
431             freq_str = freq_str.substr(0, posfreq) + freq_str.substr(last_pos);
432         }
433 
434         if (fstr.length() > 2 && fstr[2] == OPEN_BRACKET) {
435             if (freq_type == FREQ_MIXTURE)
436                 outError("Mixture frequency with user-defined frequency is not allowed");
437             close_bracket = fstr.find(CLOSE_BRACKET);
438             if (close_bracket == string::npos)
439                 outError("Close bracket not found in ", fstr);
440             if (close_bracket != fstr.length()-1)
441                 outError("Wrong close bracket position ", fstr);
442             freq_type = FREQ_USER_DEFINED;
443             freq_params = fstr.substr(3, close_bracket-3);
444         } else if (fstr == "+FC" || fstr == "+Fc" || fstr == "+F") {
445             if (freq_type == FREQ_MIXTURE) {
446                 freq_params = "empirical," + freq_params;
447                 optimize_mixmodel_weight = true;
448             } else
449                 freq_type = FREQ_EMPIRICAL;
450         } else if (fstr == "+FU" || fstr == "+Fu") {
451             if (freq_type == FREQ_MIXTURE)
452                 outError("Mixture frequency with user-defined frequency is not allowed");
453             else
454                 freq_type = FREQ_USER_DEFINED;
455         } else if (fstr == "+FQ" || fstr == "+Fq") {
456             if (freq_type == FREQ_MIXTURE)
457                 outError("Mixture frequency with equal frequency is not allowed");
458             else
459             freq_type = FREQ_EQUAL;
460         } else if (fstr == "+FO" || fstr == "+Fo") {
461             if (freq_type == FREQ_MIXTURE) {
462                 freq_params = "optimize," + freq_params;
463                 optimize_mixmodel_weight = true;
464             } else
465                 freq_type = FREQ_ESTIMATE;
466     } else if (fstr == "+F1x4" || fstr == "+F1X4") {
467             if (freq_type == FREQ_MIXTURE)
468                 outError("Mixture frequency with " + fstr + " is not allowed");
469             else
470                 freq_type = FREQ_CODON_1x4;
471         } else if (fstr == "+F3x4" || fstr == "+F3X4") {
472             if (freq_type == FREQ_MIXTURE)
473                 outError("Mixture frequency with " + fstr + " is not allowed");
474             else
475                 freq_type = FREQ_CODON_3x4;
476         } else if (fstr == "+F3x4C" || fstr == "+F3x4c" || fstr == "+F3X4C" || fstr == "+F3X4c") {
477             if (freq_type == FREQ_MIXTURE)
478                 outError("Mixture frequency with " + fstr + " is not allowed");
479             else
480                 freq_type = FREQ_CODON_3x4C;
481         } else if (fstr == "+FRY") {
482         // MDW to Minh: I don't know how these should interact with FREQ_MIXTURE,
483         // so as nearly everything else treats it as an error, I do too.
484         // BQM answer: that's fine
485             if (freq_type == FREQ_MIXTURE)
486                 outError("Mixture frequency with " + fstr + " is not allowed");
487             else
488                 freq_type = FREQ_DNA_RY;
489         } else if (fstr == "+FWS") {
490             if (freq_type == FREQ_MIXTURE)
491                 outError("Mixture frequency with " + fstr + " is not allowed");
492             else
493                 freq_type = FREQ_DNA_WS;
494         } else if (fstr == "+FMK") {
495             if (freq_type == FREQ_MIXTURE)
496                 outError("Mixture frequency with " + fstr + " is not allowed");
497             else
498                 freq_type = FREQ_DNA_MK;
499         } else {
500             // might be "+F####" where # are digits
501             try {
502                 freq_type = parseStateFreqDigits(fstr.substr(2)); // throws an error if not in +F#### format
503             } catch (...) {
504                 outError("Unknown state frequency type ",fstr);
505             }
506         }
507 //          model_str = model_str.substr(0, posfreq);
508         }
509 
510     /******************** initialize model ****************************/
511 
512     if (tree->aln->site_state_freq.empty()) {
513         if (model_str.substr(0, 3) == "MIX" || freq_type == FREQ_MIXTURE) {
514             string model_list;
515             if (model_str.substr(0, 3) == "MIX") {
516                 if (model_str[3] != OPEN_BRACKET)
517                     outError("Mixture model name must start with 'MIX{'");
518                 if (model_str.rfind(CLOSE_BRACKET) != model_str.length()-1)
519                     outError("Close bracket not found at the end of ", model_str);
520                 model_list = model_str.substr(4, model_str.length()-5);
521                 model_str = model_str.substr(0, 3);
522             }
523             model = new ModelMixture(model_name, model_str, model_list, models_block, freq_type, freq_params, tree, optimize_mixmodel_weight);
524         } else {
525             //            string model_desc;
526             //            NxsModel *nxsmodel = models_block->findModel(model_str);
527             //            if (nxsmodel) model_desc = nxsmodel->description;
528             model = createModel(model_str, models_block, freq_type, freq_params, tree);
529         }
530 //        fused_mix_rate &= model->isMixture() && site_rate->getNRate() > 1;
531     } else {
532         // site-specific model
533         if (model_str == "JC" || model_str == "POISSON")
534             outError("JC is not suitable for site-specific model");
535         model = new ModelSet(model_str.c_str(), tree);
536         ModelSet *models = (ModelSet*)model; // assign pointer for convenience
537         models->init((params.freq_type != FREQ_UNKNOWN) ? params.freq_type : FREQ_EMPIRICAL);
538         int i;
539         models->pattern_model_map.resize(tree->aln->getNPattern(), -1);
540         for (i = 0; i < tree->aln->getNSite(); i++) {
541             models->pattern_model_map[tree->aln->getPatternID(i)] = tree->aln->site_model[i];
542             //cout << "site " << i << " ptn " << tree->aln->getPatternID(i) << " -> model " << site_model[i] << endl;
543         }
544         double *state_freq = new double[model->num_states];
545         double *rates = new double[model->getNumRateEntries()];
546         for (i = 0; i < tree->aln->site_state_freq.size(); i++) {
547             ModelMarkov *modeli;
548             if (i == 0) {
549                 modeli = (ModelMarkov*)createModel(model_str, models_block, (params.freq_type != FREQ_UNKNOWN) ? params.freq_type : FREQ_EMPIRICAL, "", tree);
550                 modeli->getStateFrequency(state_freq);
551                 modeli->getRateMatrix(rates);
552             } else {
553                 modeli = (ModelMarkov*)createModel(model_str, models_block, FREQ_EQUAL, "", tree);
554                 modeli->setStateFrequency(state_freq);
555                 modeli->setRateMatrix(rates);
556             }
557             if (tree->aln->site_state_freq[i])
558                 modeli->setStateFrequency (tree->aln->site_state_freq[i]);
559 
560             modeli->init(FREQ_USER_DEFINED);
561             models->push_back(modeli);
562         }
563         delete [] rates;
564         delete [] state_freq;
565 
566         models->joinEigenMemory();
567         models->decomposeRateMatrix();
568 
569         // delete information of the old alignment
570 //        tree->aln->ordered_pattern.clear();
571 //        tree->deleteAllPartialLh();
572     }
573 
574 //    if (model->isMixture())
575 //        cout << "Mixture model with " << model->getNMixtures() << " components!" << endl;
576 
577     /******************** initialize ascertainment bias correction model ****************************/
578 
579     string::size_type posasc;
580 
581     if ((posasc = rate_str.find("+ASC_INF")) != string::npos) {
582         // ascertainment bias correction
583         ASC_type = ASC_INFORMATIVE;
584         tree->aln->getUnobservedConstPatterns(ASC_type, unobserved_ptns);
585 
586         // rebuild the seq_states to contain states of unobserved constant patterns
587         //tree->aln->buildSeqStates(model->seq_states, true);
588         if (tree->aln->num_informative_sites != tree->getAlnNSite()) {
589             if (!params.partition_file) {
590                 string infsites_file = ((string)params.out_prefix + ".infsites.phy");
591                 tree->aln->printAlignment(params.aln_output_format, infsites_file.c_str(), false, NULL, EXCLUDE_UNINF);
592                 cerr << "For your convenience alignment with parsimony-informative sites printed to " << infsites_file << endl;
593             }
594             outError("Invalid use of +ASC_INF because of " + convertIntToString(tree->getAlnNSite() - tree->aln->num_informative_sites) +
595                      " parsimony-uninformative sites in the alignment");
596         }
597         if (verbose_mode >= VB_MED)
598             cout << "Ascertainment bias correction: " << unobserved_ptns.size() << " unobservable uninformative patterns"<< endl;
599         rate_str = rate_str.substr(0, posasc) + rate_str.substr(posasc+8);
600     } else if ((posasc = rate_str.find("+ASC_MIS")) != string::npos) {
601         // initialize Holder's ascertainment bias correction model
602         ASC_type = ASC_VARIANT_MISSING;
603         tree->aln->getUnobservedConstPatterns(ASC_type, unobserved_ptns);
604         // rebuild the seq_states to contain states of unobserved constant patterns
605         //tree->aln->buildSeqStates(model->seq_states, true);
606         if (tree->aln->frac_invariant_sites > 0) {
607             if (!params.partition_file) {
608                 string varsites_file = ((string)params.out_prefix + ".varsites.phy");
609                 tree->aln->printAlignment(params.aln_output_format, varsites_file.c_str(), false, NULL, EXCLUDE_INVAR);
610                 cerr << "For your convenience alignment with variable sites printed to " << varsites_file << endl;
611             }
612             outError("Invalid use of +ASC_MIS because of " + convertIntToString(tree->aln->frac_invariant_sites*tree->aln->getNSite()) +
613                      " invariant sites in the alignment");
614         }
615         if (verbose_mode >= VB_MED)
616             cout << "Holder's ascertainment bias correction: " << unobserved_ptns.size() << " unobservable constant patterns" << endl;
617         rate_str = rate_str.substr(0, posasc) + rate_str.substr(posasc+8);
618     } else if ((posasc = rate_str.find("+ASC")) != string::npos) {
619         // ascertainment bias correction
620         ASC_type = ASC_VARIANT;
621         tree->aln->getUnobservedConstPatterns(ASC_type, unobserved_ptns);
622 
623         // delete rarely observed state
624         for (int i = unobserved_ptns.size()-1; i >= 0; i--)
625             if (model->state_freq[(int)unobserved_ptns[i][0]] < 1e-8)
626                 unobserved_ptns.erase(unobserved_ptns.begin() + i);
627 
628         // rebuild the seq_states to contain states of unobserved constant patterns
629         //tree->aln->buildSeqStates(model->seq_states, true);
630 //        if (unobserved_ptns.size() <= 0)
631 //            outError("Invalid use of +ASC because all constant patterns are observed in the alignment");
632         if (tree->aln->frac_invariant_sites > 0) {
633 //            cerr << tree->aln->frac_invariant_sites*tree->aln->getNSite() << " invariant sites observed in the alignment" << endl;
634 //            for (Alignment::iterator pit = tree->aln->begin(); pit != tree->aln->end(); pit++)
635 //                if (pit->isInvariant()) {
636 //                    string pat_str = "";
637 //                    for (Pattern::iterator it = pit->begin(); it != pit->end(); it++)
638 //                        pat_str += tree->aln->convertStateBackStr(*it);
639 //                    cerr << pat_str << " is invariant site pattern" << endl;
640 //                }
641             if (!params.partition_file) {
642                 string varsites_file = ((string)params.out_prefix + ".varsites.phy");
643                 tree->aln->printAlignment(params.aln_output_format, varsites_file.c_str(), false, NULL, EXCLUDE_INVAR);
644                 cerr << "For your convenience alignment with variable sites printed to " << varsites_file << endl;
645             }
646             outError("Invalid use of +ASC because of " + convertIntToString(tree->aln->frac_invariant_sites*tree->aln->getNSite()) +
647                 " invariant sites in the alignment");
648         }
649         if (verbose_mode >= VB_MED)
650             cout << "Ascertainment bias correction: " << unobserved_ptns.size() << " unobservable constant patterns"<< endl;
651 		rate_str = rate_str.substr(0, posasc) + rate_str.substr(posasc+4);
652     } else {
653         //tree->aln->buildSeqStates(model->seq_states, false);
654     }
655 
656     /******************** initialize site rate heterogeneity ****************************/
657 
658     string::size_type posI = rate_str.find("+I");
659     string::size_type posG = rate_str.find("+G");
660     string::size_type posG2 = rate_str.find("*G");
661     if (posG != string::npos && posG2 != string::npos) {
662         cout << "NOTE: both +G and *G were specified, continue with "
663             << ((posG < posG2)? rate_str.substr(posG,2) : rate_str.substr(posG2,2)) << endl;
664     }
665     if (posG2 != string::npos && posG2 < posG) {
666         posG = posG2;
667         fused_mix_rate = true;
668     }
669 
670     string::size_type posR = rate_str.find("+R"); // FreeRate model
671     string::size_type posR2 = rate_str.find("*R"); // FreeRate model
672 
673     if (posG != string::npos && (posR != string::npos || posR2 != string::npos)) {
674         outWarning("Both Gamma and FreeRate models were specified, continue with FreeRate model");
675         posG = string::npos;
676         fused_mix_rate = false;
677     }
678 
679     if (posR != string::npos && posR2 != string::npos) {
680         cout << "NOTE: both +R and *R were specified, continue with "
681             << ((posR < posR2)? rate_str.substr(posR,2) : rate_str.substr(posR2,2)) << endl;
682     }
683 
684     if (posR2 != string::npos && posR2 < posR) {
685         posR = posR2;
686         fused_mix_rate = true;
687     }
688 
689     string::size_type posH = rate_str.find("+H"); // heterotachy model
690     string::size_type posH2 = rate_str.find("*H"); // heterotachy model
691 
692     if (posG != string::npos && (posH != string::npos || posH2 != string::npos)) {
693         outWarning("Both Gamma and heterotachy models were specified, continue with heterotachy model");
694         posG = string::npos;
695         fused_mix_rate = false;
696     }
697 
698     if (posR != string::npos && (posH != string::npos || posH2 != string::npos)) {
699         outWarning("Both FreeRate and heterotachy models were specified, continue with heterotachy model");
700         posR = string::npos;
701         fused_mix_rate = false;
702     }
703 
704     if (posH != string::npos && posH2 != string::npos) {
705         cout << "NOTE: both +H and *H were specified, continue with "
706             << ((posH < posH2)? rate_str.substr(posH,2) : rate_str.substr(posH2,2)) << endl;
707     }
708     if (posH2 != string::npos && posH2 < posH) {
709         posH = posH2;
710         fused_mix_rate = true;
711     }
712 
713     string::size_type posX;
714     /* create site-rate heterogeneity */
715     int num_rate_cats = params.num_rate_cats;
716     if (fused_mix_rate && model->isMixture()) num_rate_cats = model->getNMixtures();
717     double gamma_shape = params.gamma_shape;
718     double p_invar_sites = params.p_invar_sites;
719     string freerate_params = "";
720     if (posI != string::npos) {
721         // invariable site model
722         if (rate_str.length() > posI+2 && rate_str[posI+2] == OPEN_BRACKET) {
723             close_bracket = rate_str.find(CLOSE_BRACKET, posI);
724             if (close_bracket == string::npos)
725                 outError("Close bracket not found in ", rate_str);
726             p_invar_sites = convert_double(rate_str.substr(posI+3, close_bracket-posI-3).c_str());
727             if (p_invar_sites < 0 || p_invar_sites >= 1)
728                 outError("p_invar must be in [0,1)");
729         } else if (rate_str.length() > posI+2 && rate_str[posI+2] != '+' && rate_str[posI+2] != '*')
730             outError("Wrong model name ", rate_str);
731     }
732     if (posG != string::npos) {
733         // Gamma rate model
734         int end_pos = 0;
735         if (rate_str.length() > posG+2 && isdigit(rate_str[posG+2])) {
736             num_rate_cats = convert_int(rate_str.substr(posG+2).c_str(), end_pos);
737             if (num_rate_cats < 1) outError("Wrong number of rate categories");
738         }
739         if (rate_str.length() > posG+2+end_pos && rate_str[posG+2+end_pos] == OPEN_BRACKET) {
740             close_bracket = rate_str.find(CLOSE_BRACKET, posG);
741             if (close_bracket == string::npos)
742                 outError("Close bracket not found in ", rate_str);
743             gamma_shape = convert_double(rate_str.substr(posG+3+end_pos, close_bracket-posG-3-end_pos).c_str());
744 //            if (gamma_shape < MIN_GAMMA_SHAPE || gamma_shape > MAX_GAMMA_SHAPE) {
745 //                stringstream str;
746 //                str << "Gamma shape parameter " << gamma_shape << "out of range ["
747 //                        << MIN_GAMMA_SHAPE << ',' << MAX_GAMMA_SHAPE << "]" << endl;
748 //                outError(str.str());
749 //            }
750         } else if (rate_str.length() > posG+2+end_pos && rate_str[posG+2+end_pos] != '+')
751             outError("Wrong model name ", rate_str);
752     }
753     if (posR != string::npos) {
754         // FreeRate model
755         int end_pos = 0;
756         if (rate_str.length() > posR+2 && isdigit(rate_str[posR+2])) {
757             num_rate_cats = convert_int(rate_str.substr(posR+2).c_str(), end_pos);
758                 if (num_rate_cats < 1) outError("Wrong number of rate categories");
759             }
760         if (rate_str.length() > posR+2+end_pos && rate_str[posR+2+end_pos] == OPEN_BRACKET) {
761             close_bracket = rate_str.find(CLOSE_BRACKET, posR);
762             if (close_bracket == string::npos)
763                 outError("Close bracket not found in ", rate_str);
764             freerate_params = rate_str.substr(posR+3+end_pos, close_bracket-posR-3-end_pos).c_str();
765         } else if (rate_str.length() > posR+2+end_pos && rate_str[posR+2+end_pos] != '+')
766             outError("Wrong model name ", rate_str);
767     }
768 
769     string heterotachy_params = "";
770     if (posH != string::npos) {
771         // Heterotachy model
772         int end_pos = 0;
773         if (rate_str.length() > posH+2 && isdigit(rate_str[posH+2])) {
774             num_rate_cats = convert_int(rate_str.substr(posH+2).c_str(), end_pos);
775                 if (num_rate_cats < 1) outError("Wrong number of rate categories");
776         } else {
777             if (!model->isMixture() || !fused_mix_rate)
778                 outError("Please specify number of heterotachy classes (e.g., +H2)");
779         }
780         if (rate_str.length() > posH+2+end_pos && rate_str[posH+2+end_pos] == OPEN_BRACKET) {
781             close_bracket = rate_str.find(CLOSE_BRACKET, posH);
782             if (close_bracket == string::npos)
783                 outError("Close bracket not found in ", rate_str);
784             heterotachy_params = rate_str.substr(posH+3+end_pos, close_bracket-posH-3-end_pos).c_str();
785         } else if (rate_str.length() > posH+2+end_pos && rate_str[posH+2+end_pos] != '+')
786             outError("Wrong model name ", rate_str);
787     }
788 
789 
790     if (rate_str.find('+') != string::npos || rate_str.find('*') != string::npos) {
791         //string rate_str = model_str.substr(pos);
792         if (posI != string::npos && posH != string::npos) {
793             site_rate = new RateHeterotachyInvar(num_rate_cats, heterotachy_params, p_invar_sites, tree);
794         } else if (posH != string::npos) {
795             site_rate = new RateHeterotachy(num_rate_cats, heterotachy_params, tree);
796         } else if (posI != string::npos && posG != string::npos) {
797             site_rate = new RateGammaInvar(num_rate_cats, gamma_shape, params.gamma_median,
798                     p_invar_sites, params.optimize_alg_gammai, tree, false);
799         } else if (posI != string::npos && posR != string::npos) {
800             site_rate = new RateFreeInvar(num_rate_cats, gamma_shape, freerate_params, !fused_mix_rate, p_invar_sites, params.optimize_alg, tree);
801         } else if (posI != string::npos) {
802             site_rate = new RateInvar(p_invar_sites, tree);
803         } else if (posG != string::npos) {
804             site_rate = new RateGamma(num_rate_cats, gamma_shape, params.gamma_median, tree);
805         } else if (posR != string::npos) {
806             site_rate = new RateFree(num_rate_cats, gamma_shape, freerate_params, !fused_mix_rate, params.optimize_alg, tree);
807 //        } else if ((posX = rate_str.find("+M")) != string::npos) {
808 //            tree->setLikelihoodKernel(LK_NORMAL);
809 //            params.rate_mh_type = true;
810 //            if (rate_str.length() > posX+2 && isdigit(rate_str[posX+2])) {
811 //                num_rate_cats = convert_int(rate_str.substr(posX+2).c_str());
812 //                if (num_rate_cats < 0) outError("Wrong number of rate categories");
813 //            } else num_rate_cats = -1;
814 //            if (num_rate_cats >= 0)
815 //                site_rate = new RateMeyerDiscrete(num_rate_cats, params.mcat_type,
816 //                    params.rate_file, tree, params.rate_mh_type);
817 //            else
818 //                site_rate = new RateMeyerHaeseler(params.rate_file, tree, params.rate_mh_type);
819 //            site_rate->setTree(tree);
820 //        } else if ((posX = rate_str.find("+D")) != string::npos) {
821 //            tree->setLikelihoodKernel(LK_NORMAL);
822 //            params.rate_mh_type = false;
823 //            if (rate_str.length() > posX+2 && isdigit(rate_str[posX+2])) {
824 //                num_rate_cats = convert_int(rate_str.substr(posX+2).c_str());
825 //                if (num_rate_cats < 0) outError("Wrong number of rate categories");
826 //            } else num_rate_cats = -1;
827 //            if (num_rate_cats >= 0)
828 //                site_rate = new RateMeyerDiscrete(num_rate_cats, params.mcat_type,
829 //                    params.rate_file, tree, params.rate_mh_type);
830 //            else
831 //                site_rate = new RateMeyerHaeseler(params.rate_file, tree, params.rate_mh_type);
832 //            site_rate->setTree(tree);
833 //        } else if ((posX = rate_str.find("+NGS")) != string::npos) {
834 //            tree->setLikelihoodKernel(LK_NORMAL);
835 //            if (rate_str.length() > posX+4 && isdigit(rate_str[posX+4])) {
836 //                num_rate_cats = convert_int(rate_str.substr(posX+4).c_str());
837 //                if (num_rate_cats < 0) outError("Wrong number of rate categories");
838 //            } else num_rate_cats = -1;
839 //            site_rate = new NGSRateCat(tree, num_rate_cats);
840 //            site_rate->setTree(tree);
841 //        } else if ((posX = rate_str.find("+NGS")) != string::npos) {
842 //            tree->setLikelihoodKernel(LK_NORMAL);
843 //            if (rate_str.length() > posX+4 && isdigit(rate_str[posX+4])) {
844 //                num_rate_cats = convert_int(rate_str.substr(posX+4).c_str());
845 //                if (num_rate_cats < 0) outError("Wrong number of rate categories");
846 //            } else num_rate_cats = -1;
847 //            site_rate = new NGSRate(tree);
848 //            site_rate->setTree(tree);
849         } else if ((posX = rate_str.find("+K")) != string::npos) {
850             if (rate_str.length() > posX+2 && isdigit(rate_str[posX+2])) {
851                 num_rate_cats = convert_int(rate_str.substr(posX+2).c_str());
852                 if (num_rate_cats < 1) outError("Wrong number of rate categories");
853             }
854             site_rate = new RateKategory(num_rate_cats, tree);
855         } else
856             outError("Invalid rate heterogeneity type");
857 //        if (model_str.find('+') != string::npos)
858 //            model_str = model_str.substr(0, model_str.find('+'));
859 //        else
860 //            model_str = model_str.substr(0, model_str.find('*'));
861     } else {
862         site_rate = new RateHeterogeneity();
863         site_rate->setTree(tree);
864     }
865 
866     if (fused_mix_rate) {
867         if (!model->isMixture()) {
868             if (verbose_mode >= VB_MED)
869                 cout << endl << "NOTE: Using mixture model with unlinked " << model_str << " parameters" << endl;
870             string model_list = model_str;
871             delete model;
872             for (int i = 1; i < site_rate->getNRate(); i++)
873                 model_list += "," + model_str;
874             model = new ModelMixture(model_name, model_str, model_list, models_block, freq_type, freq_params, tree, optimize_mixmodel_weight);
875         }
876         if (model->getNMixtures() != site_rate->getNRate())
877             outError("Mixture model and site rate model do not have the same number of categories");
878 //        if (!tree->isMixlen()) {
879             // reset mixture model
880             model->setFixMixtureWeight(true);
881             int mix, nmix = model->getNMixtures();
882             for (mix = 0; mix < nmix; mix++) {
883                 ((ModelMarkov*)model->getMixtureClass(mix))->total_num_subst = 1.0;
884                 model->setMixtureWeight(mix, 1.0);
885             }
886             model->decomposeRateMatrix();
887 //        } else {
888 //            site_rate->setFixParams(1);
889 //            int c, ncat = site_rate->getNRate();
890 //            for (c = 0; c < ncat; c++)
891 //                site_rate->setProp(c, 1.0);
892 //        }
893     }
894 
895     tree->discardSaturatedSite(params.discard_saturated_site);
896 
897     } catch (const char* str) {
898         outError(str);
899     }
900 
901 }
902 
/**
 * Attach a checkpoint object to this factory and hand the same pointer
 * to the two objects it holds, `model` and `site_rate`.
 * @param checkpoint checkpoint object forwarded to the base class, model and site_rate
 */
void ModelFactory::setCheckpoint(Checkpoint *checkpoint) {
    CheckpointFactory::setCheckpoint(checkpoint);
    model->setCheckpoint(checkpoint);
    site_rate->setCheckpoint(checkpoint);
}
908 
/**
 * Open the checkpoint struct named "ModelFactory" on the attached checkpoint.
 */
void ModelFactory::startCheckpoint() {
    checkpoint->startStruct("ModelFactory");
}
912 
/**
 * Save the state of `model` and `site_rate`, then open and close the
 * factory's own checkpoint struct and call the base-class saveCheckpoint().
 * The factory-level CKP_SAVE entries are commented out, so the struct is
 * opened and closed without writing any field from here.
 */
void ModelFactory::saveCheckpoint() {
    model->saveCheckpoint();
    site_rate->saveCheckpoint();
    startCheckpoint();
//    CKP_SAVE(fused_mix_rate);
//    CKP_SAVE(unobserved_ptns);
//    CKP_SAVE(joint_optimize);
    endCheckpoint();
    CheckpointFactory::saveCheckpoint();
}
923 
/**
 * Restore the state of `model` and `site_rate` from the checkpoint, then
 * open and close the factory's own checkpoint struct.
 * The factory-level CKP_RESTORE entries are commented out, so no field of
 * the factory itself is read back here.
 */
void ModelFactory::restoreCheckpoint() {
    model->restoreCheckpoint();
    site_rate->restoreCheckpoint();
    startCheckpoint();
//    CKP_RESTORE(fused_mix_rate);
//    CKP_RESTORE(unobserved_ptns);
//    CKP_RESTORE(joint_optimize);
    endCheckpoint();
}
933 
getNParameters(int brlen_type)934 int ModelFactory::getNParameters(int brlen_type) {
935     int df = model->getNDim() + model->getNDimFreq() + site_rate->getNDim() +
936         site_rate->getTree()->getNBranchParameters(brlen_type);
937 
938     return df;
939 }
940 /*
941 double ModelFactory::initGTRGammaIParameters(RateHeterogeneity *rate, ModelSubst *model, double initAlpha,
942                                            double initPInvar, double *initRates, double *initStateFreqs)  {
943 
944     RateHeterogeneity* rateGammaInvar = rate;
945     ModelMarkov* modelGTR = (ModelMarkov*)(model);
946     modelGTR->setRateMatrix(initRates);
947     modelGTR->setStateFrequency(initStateFreqs);
948     rateGammaInvar->setGammaShape(initAlpha);
949     rateGammaInvar->setPInvar(initPInvar);
950     modelGTR->decomposeRateMatrix();
951     site_rate->phylo_tree->clearAllPartialLH();
952     return site_rate->phylo_tree->computeLikelihood();
953 }
954 */
955 
optimizeParametersOnly(int num_steps,double gradient_epsilon,double cur_logl)956 double ModelFactory::optimizeParametersOnly(int num_steps, double gradient_epsilon, double cur_logl) {
957     double logl;
958     /* Optimize substitution and heterogeneity rates independently */
959     if (!joint_optimize) {
960         // more steps for fused mix rate model
961         int steps;
962         if (false && fused_mix_rate && model->getNDim() > 0 && site_rate->getNDim() > 0) {
963             model->setOptimizeSteps(1);
964             site_rate->setOptimizeSteps(1);
965             steps = max(model->getNDim()+site_rate->getNDim(), num_steps) * 3;
966         } else {
967             steps = 1;
968         }
969         double prev_logl = cur_logl;
970         for (int step = 0; step < steps; step++) {
971             double model_lh = 0.0;
972             // only optimized if model is not linked
973             model_lh = model->optimizeParameters(gradient_epsilon);
974 
975             double rate_lh = site_rate->optimizeParameters(gradient_epsilon);
976 
977             if (rate_lh == 0.0)
978                 logl = model_lh;
979             else
980                 logl = rate_lh;
981             if (logl <= prev_logl + gradient_epsilon)
982                 break;
983             prev_logl = logl;
984         }
985     } else {
986         /* Optimize substitution and heterogeneity rates jointly using BFGS */
987         logl = optimizeAllParameters(gradient_epsilon);
988     }
989     return logl;
990 }
991 
optimizeAllParameters(double gradient_epsilon)992 double ModelFactory::optimizeAllParameters(double gradient_epsilon) {
993     int ndim = getNDim();
994 
995     // return if nothing to be optimized
996     if (ndim == 0) return 0.0;
997 
998     double *variables = new double[ndim+1];
999     double *upper_bound = new double[ndim+1];
1000     double *lower_bound = new double[ndim+1];
1001     bool *bound_check = new bool[ndim+1];
1002     int i;
1003     double score;
1004 
1005     // setup the bounds for model
1006     setVariables(variables);
1007     int model_ndim = model->getNDim();
1008     for (i = 1; i <= model_ndim; i++) {
1009         //cout << variables[i] << endl;
1010         lower_bound[i] = MIN_RATE;
1011         upper_bound[i] = MAX_RATE;
1012         bound_check[i] = false;
1013     }
1014 
1015     if (model->freq_type == FREQ_ESTIMATE) {
1016         for (i = model_ndim- model->num_states+2; i <= model_ndim; i++)
1017             upper_bound[i] = 1.0;
1018     }
1019 
1020     // setup the bounds for site_rate
1021     site_rate->setBounds(lower_bound+model_ndim, upper_bound+model_ndim, bound_check+model_ndim);
1022 
1023     score = -minimizeMultiDimen(variables, ndim, lower_bound, upper_bound, bound_check, max(gradient_epsilon, TOL_RATE));
1024 
1025     getVariables(variables);
1026     //if (freq_type == FREQ_ESTIMATE) scaleStateFreq(true);
1027     model->decomposeRateMatrix();
1028     site_rate->phylo_tree->clearAllPartialLH();
1029 
1030     score = site_rate->phylo_tree->computeLikelihood();
1031 
1032     delete [] bound_check;
1033     delete [] lower_bound;
1034     delete [] upper_bound;
1035     delete [] variables;
1036 
1037     return score;
1038 }
1039 
optimizeParametersGammaInvar(int fixed_len,bool write_info,double logl_epsilon,double gradient_epsilon)1040 double ModelFactory::optimizeParametersGammaInvar(int fixed_len, bool write_info, double logl_epsilon, double gradient_epsilon) {
1041     if (!site_rate->isGammai() || site_rate->isFixPInvar() || site_rate->isFixGammaShape() || site_rate->getTree()->aln->frac_const_sites == 0.0 || model->isMixture())
1042         return optimizeParameters(fixed_len, write_info, logl_epsilon, gradient_epsilon);
1043 
1044     double begin_time = getRealTime();
1045 
1046     PhyloTree *tree = site_rate->getTree();
1047     double frac_const = tree->aln->frac_const_sites;
1048     tree->setCurScore(tree->computeLikelihood());
1049 
1050     /* Back up branch lengths and substitutional rates */
1051     DoubleVector initBranLens;
1052     DoubleVector bestLens;
1053     tree->saveBranchLengths(initBranLens);
1054     bestLens = initBranLens;
1055 //    int numRateEntries = tree->getModel()->getNumRateEntries();
1056     Checkpoint *model_ckp = new Checkpoint;
1057     Checkpoint *best_ckp = new Checkpoint;
1058     Checkpoint *saved_ckp = model->getCheckpoint();
1059     *model_ckp = *saved_ckp;
1060 //    double *rates = new double[numRateEntries];
1061 //    double *bestRates = new double[numRateEntries];
1062 //    tree->getModel()->getRateMatrix(rates);
1063 //    int numStates = tree->aln->num_states;
1064 //    double *state_freqs = new double[numStates];
1065 //    tree->getModel()->getStateFrequency(state_freqs);
1066 
1067     /* Best estimates found */
1068 //    double *bestStateFreqs =  new double[numStates];
1069     double bestLogl = -DBL_MAX;
1070     double bestAlpha = 0.0;
1071     double bestPInvar = 0.0;
1072 
1073     double testInterval = (frac_const - MIN_PINVAR * 2) / 9;
1074     double initPInv = MIN_PINVAR;
1075     double initAlpha = site_rate->getGammaShape();
1076 
1077     if (Params::getInstance().opt_gammai_fast) {
1078         initPInv = frac_const/2;
1079         bool stop = false;
1080         while(!stop) {
1081             if (write_info) {
1082                 cout << endl;
1083                 cout << "Testing with init. pinv = " << initPInv << " / init. alpha = "  << initAlpha << endl;
1084             }
1085 
1086             vector<double> estResults = optimizeGammaInvWithInitValue(fixed_len, logl_epsilon, gradient_epsilon,
1087                                                                    initPInv, initAlpha, initBranLens, model_ckp);
1088 
1089 
1090             if (write_info) {
1091                 cout << "Est. p_inv: " << estResults[0] << " / Est. gamma shape: " << estResults[1]
1092                 << " / Logl: " << estResults[2] << endl;
1093             }
1094 
1095             if (estResults[2] > bestLogl) {
1096                 bestLogl = estResults[2];
1097                 bestAlpha = estResults[1];
1098                 bestPInvar = estResults[0];
1099                 bestLens.clear();
1100                 tree->saveBranchLengths(bestLens);
1101                 model->setCheckpoint(best_ckp);
1102                 model->saveCheckpoint();
1103                 model->setCheckpoint(saved_ckp);
1104 //                *best_ckp = *saved_ckp;
1105 
1106 //                tree->getModel()->getRateMatrix(bestRates);
1107 //                tree->getModel()->getStateFrequency(bestStateFreqs);
1108                 if (estResults[0] < initPInv) {
1109                     initPInv = estResults[0] - testInterval;
1110                     if (initPInv < 0.0)
1111                         initPInv = 0.0;
1112                 } else {
1113                     initPInv = estResults[0] + testInterval;
1114                     if (initPInv > frac_const)
1115                         initPInv = frac_const;
1116                 }
1117                 //cout << "New initPInv = " << initPInv << endl;
1118             }  else {
1119                 stop = true;
1120             }
1121         }
1122     } else {
1123         // Now perform testing different initial p_inv values
1124         if (write_info)
1125             cout << "Thoroughly optimizing +I+G parameters from 10 start values..." << endl;
1126         while (initPInv <= frac_const) {
1127             vector<double> estResults; // vector of p_inv, alpha and logl
1128             if (Params::getInstance().opt_gammai_keep_bran)
1129                 estResults = optimizeGammaInvWithInitValue(fixed_len, logl_epsilon, gradient_epsilon,
1130                     initPInv, initAlpha, bestLens, model_ckp);
1131             else
1132                 estResults = optimizeGammaInvWithInitValue(fixed_len, logl_epsilon, gradient_epsilon,
1133                     initPInv, initAlpha, initBranLens, model_ckp);
1134             if (write_info) {
1135                 cout << "Init pinv, alpha: " << initPInv << ", "  << initAlpha
1136                      << " / Estimate: " << estResults[0] << ", " << estResults[1]
1137                      << " / LogL: " << estResults[2] << endl;
1138             }
1139 
1140             initPInv = initPInv + testInterval;
1141 
1142             if (estResults[2] > bestLogl) {
1143                 bestLogl = estResults[2];
1144                 bestAlpha = estResults[1];
1145                 bestPInvar = estResults[0];
1146                 bestLens.clear();
1147                 tree->saveBranchLengths(bestLens);
1148                 model->setCheckpoint(best_ckp);
1149                 model->saveCheckpoint();
1150                 model->setCheckpoint(saved_ckp);
1151 //                *best_ckp = *saved_ckp;
1152 
1153 //                tree->getModel()->getRateMatrix(bestRates);
1154 //                tree->getModel()->getStateFrequency(bestStateFreqs);
1155             }
1156         }
1157     }
1158 
1159     site_rate->setGammaShape(bestAlpha);
1160     site_rate->setPInvar(bestPInvar);
1161 
1162     // -- Mon Apr 17 21:12:14 BST 2017
1163     // DONE Minh, merged correctly
1164     model->setCheckpoint(best_ckp);
1165     model->restoreCheckpoint();
1166     model->setCheckpoint(saved_ckp);
1167     // ((ModelGTR*) tree->getModel())->setRateMatrix(bestRates);
1168     // ((ModelGTR*) tree->getModel())->setStateFrequency(bestStateFreqs);
1169     // --
1170 
1171     tree->restoreBranchLengths(bestLens);
1172     // tree->getModel()->decomposeRateMatrix();
1173 
1174     tree->clearAllPartialLH();
1175     tree->setCurScore(tree->computeLikelihood());
1176     if (write_info) {
1177         cout << "Optimal pinv,alpha: " << bestPInvar << ", " << bestAlpha << " / ";
1178         cout << "LogL: " << tree->getCurScore() << endl << endl;
1179     }
1180     ASSERT(fabs(tree->getCurScore() - bestLogl) < 1.0);
1181 
1182 //    delete [] rates;
1183 //    delete [] state_freqs;
1184 //    delete [] bestRates;
1185 //    delete [] bestStateFreqs;
1186 
1187     delete model_ckp;
1188     delete best_ckp;
1189 
1190     double elapsed_secs = getRealTime() - begin_time;
1191     if (write_info)
1192         cout << "Parameters optimization took " << elapsed_secs << " sec" << endl;
1193 
1194     // updating global variable is not safe!
1195 //    Params::getInstance().testAlpha = false;
1196 
1197     // 2016-03-14: this was missing!
1198     return tree->getCurScore();
1199 }
1200 
optimizeGammaInvWithInitValue(int fixed_len,double logl_epsilon,double gradient_epsilon,double initPInv,double initAlpha,DoubleVector & lenvec,Checkpoint * model_ckp)1201 vector<double> ModelFactory::optimizeGammaInvWithInitValue(int fixed_len, double logl_epsilon, double gradient_epsilon,
1202                                                  double initPInv, double initAlpha,
1203                                                  DoubleVector &lenvec, Checkpoint *model_ckp) {
1204     PhyloTree *tree = site_rate->getTree();
1205     tree->restoreBranchLengths(lenvec);
1206 
1207     // -- Mon Apr 17 21:12:24 BST 2017
1208     // DONE Minh: merged correctly
1209     Checkpoint *saved_ckp = model->getCheckpoint();
1210     model->setCheckpoint(model_ckp);
1211     model->restoreCheckpoint();
1212     model->setCheckpoint(saved_ckp);
1213     site_rate->setPInvar(initPInv);
1214     site_rate->setGammaShape(initAlpha);
1215     // --
1216 
1217     tree->clearAllPartialLH();
1218     optimizeParameters(fixed_len, false, logl_epsilon, gradient_epsilon);
1219 
1220     vector<double> estResults;
1221     double estPInv = site_rate->getPInvar();
1222     double estAlpha = site_rate->getGammaShape();
1223     double logl = tree->getCurScore();
1224     estResults.push_back(estPInv);
1225     estResults.push_back(estAlpha);
1226     estResults.push_back(logl);
1227     return estResults;
1228 }
1229 
1230 
optimizeParameters(int fixed_len,bool write_info,double logl_epsilon,double gradient_epsilon)1231 double ModelFactory::optimizeParameters(int fixed_len, bool write_info,
1232                                         double logl_epsilon, double gradient_epsilon) {
1233     ASSERT(model);
1234     ASSERT(site_rate);
1235 
1236 //    double defaultEpsilon = logl_epsilon;
1237 
1238     double begin_time = getRealTime();
1239     double cur_lh;
1240     PhyloTree *tree = site_rate->getTree();
1241     ASSERT(tree);
1242 
1243     stopStoringTransMatrix();
1244     // modified by Thomas Wong on Sept 11, 15
1245     // no optimization of branch length in the first round
1246     cur_lh = tree->computeLikelihood();
1247     tree->setCurScore(cur_lh);
1248     if (verbose_mode >= VB_MED || write_info) {
1249     int p = -1;
1250 
1251     // SET precision to 17 (temporarily)
1252     if (verbose_mode >= VB_DEBUG) p = cout.precision(17);
1253 
1254     // PRINT Log-Likelihood
1255     cout << "1. Initial log-likelihood: " << cur_lh << endl;
1256 
1257     // RESTORE previous precision
1258     if (verbose_mode >= VB_DEBUG) cout.precision(p);
1259 
1260         if (verbose_mode >= VB_MAX) {
1261             tree->printTree(cout);
1262             cout << endl;
1263         }
1264     }
1265 
1266     // For UpperBounds -----------
1267     //cout<<"MLCheck = "<<tree->mlCheck <<endl;
1268     if(tree->mlCheck == 0){
1269         tree->mlInitial = cur_lh;
1270     }
1271     // ---------------------------
1272 
1273 
1274     int i;
1275     //bool optimize_rate = true;
1276 //    double gradient_epsilon = min(logl_epsilon, 0.01); // epsilon for parameters starts at epsilon for logl
1277     for (i = 2; i < tree->params->num_param_iterations; i++) {
1278         double new_lh;
1279 
1280         // changed to opimise edge length first, and then Q,W,R inside the loop by Thomas on Sept 11, 15
1281         if (fixed_len == BRLEN_OPTIMIZE)
1282             new_lh = tree->optimizeAllBranches(min(i,3), logl_epsilon);  // loop only 3 times in total (previously in v0.9.6 5 times)
1283         else if (fixed_len == BRLEN_SCALE) {
1284             double scaling = 1.0;
1285             new_lh = tree->optimizeTreeLengthScaling(MIN_BRLEN_SCALE, scaling, MAX_BRLEN_SCALE, gradient_epsilon);
1286         } else
1287             new_lh = cur_lh;
1288 
1289         new_lh = optimizeParametersOnly(i, gradient_epsilon, new_lh);
1290 
1291         if (new_lh == 0.0) {
1292             if (fixed_len == BRLEN_OPTIMIZE)
1293                 cur_lh = tree->optimizeAllBranches(tree->params->num_param_iterations, logl_epsilon);
1294             else if (fixed_len == BRLEN_SCALE) {
1295                 double scaling = 1.0;
1296                 cur_lh = tree->optimizeTreeLengthScaling(MIN_BRLEN_SCALE, scaling, MAX_BRLEN_SCALE, gradient_epsilon);
1297             }
1298             break;
1299         }
1300         if (verbose_mode >= VB_MED) {
1301             model->writeInfo(cout);
1302             site_rate->writeInfo(cout);
1303             if (fixed_len == BRLEN_SCALE)
1304                 cout << "Scaled tree length: " << tree->treeLength() << endl;
1305         }
1306         if (new_lh > cur_lh + logl_epsilon) {
1307             cur_lh = new_lh;
1308             if (write_info)
1309                 cout << i << ". Current log-likelihood: " << cur_lh << endl;
1310         } else {
1311             site_rate->classifyRates(new_lh);
1312             if (fixed_len == BRLEN_OPTIMIZE)
1313                 cur_lh = tree->optimizeAllBranches(100, logl_epsilon);
1314             else if (fixed_len == BRLEN_SCALE) {
1315                 double scaling = 1.0;
1316                 cur_lh = tree->optimizeTreeLengthScaling(MIN_BRLEN_SCALE, scaling, MAX_BRLEN_SCALE, gradient_epsilon);
1317             }
1318             break;
1319         }
1320     }
1321 
1322     // normalize rates s.t. branch lengths are #subst per site
1323 //    if (Params::getInstance().optimize_alg_gammai != "EM")
1324     {
1325         double mean_rate = site_rate->rescaleRates();
1326         if (fabs(mean_rate-1.0) > 1e-6) {
1327             if (fixed_len == BRLEN_FIX)
1328                 outError("Fixing branch lengths not supported under specified site rate model");
1329             tree->scaleLength(mean_rate);
1330             tree->clearAllPartialLH();
1331         }
1332     }
1333 
1334     if (Params::getInstance().root_find && tree->rooted && Params::getInstance().root_move_dist > 0) {
1335         cur_lh = tree->optimizeRootPosition(Params::getInstance().root_move_dist, write_info, logl_epsilon);
1336         if (verbose_mode >= VB_MED || write_info)
1337             cout << "Rooting log-likelihood: " << cur_lh << endl;
1338     }
1339 
1340     if (verbose_mode >= VB_MED || write_info)
1341         cout << "Optimal log-likelihood: " << cur_lh << endl;
1342 
1343     // For UpperBounds -----------
1344     if(tree->mlCheck == 0)
1345         tree->mlFirstOpt = cur_lh;
1346     // ---------------------------
1347 
1348     if (verbose_mode <= VB_MIN && write_info) {
1349         model->writeInfo(cout);
1350         site_rate->writeInfo(cout);
1351         if (fixed_len == BRLEN_SCALE)
1352             cout << "Scaled tree length: " << tree->treeLength() << endl;
1353     }
1354     double elapsed_secs = getRealTime() - begin_time;
1355     if (write_info)
1356         cout << "Parameters optimization took " << i-1 << " rounds (" << elapsed_secs << " sec)" << endl;
1357     startStoringTransMatrix();
1358 
1359     // For UpperBounds -----------
1360     tree->mlCheck = 1;
1361     // ---------------------------
1362 
1363     tree->setCurScore(cur_lh);
1364     return cur_lh;
1365 }
1366 
1367 /**
1368  * @return TRUE if parameters are at the boundary that may cause numerical unstability
1369  */
isUnstableParameters()1370 bool ModelFactory::isUnstableParameters() {
1371     if (model->isUnstableParameters()) return true;
1372     return false;
1373 }
1374 
startStoringTransMatrix()1375 void ModelFactory::startStoringTransMatrix() {
1376     if (!store_trans_matrix) return;
1377     is_storing = true;
1378 }
1379 
stopStoringTransMatrix()1380 void ModelFactory::stopStoringTransMatrix() {
1381     if (!store_trans_matrix) return;
1382     is_storing = false;
1383     if (!empty()) {
1384         for (iterator it = begin(); it != end(); it++)
1385             delete it->second;
1386         clear();
1387     }
1388 }
1389 
1390 
computeTrans(double time,int state1,int state2)1391 double ModelFactory::computeTrans(double time, int state1, int state2) {
1392     return model->computeTrans(time, state1, state2);
1393 }
1394 
computeTrans(double time,int state1,int state2,double & derv1,double & derv2)1395 double ModelFactory::computeTrans(double time, int state1, int state2, double &derv1, double &derv2) {
1396     return model->computeTrans(time, state1, state2, derv1, derv2);
1397 }
1398 
computeTransMatrix(double time,double * trans_matrix,int mixture)1399 void ModelFactory::computeTransMatrix(double time, double *trans_matrix, int mixture) {
1400     if (!store_trans_matrix || !is_storing || model->isSiteSpecificModel()) {
1401         model->computeTransMatrix(time, trans_matrix, mixture);
1402         return;
1403     }
1404     int mat_size = model->num_states * model->num_states;
1405     iterator ass_it = find(round(time * 1e6));
1406     if (ass_it == end()) {
1407         // allocate memory for 3 matricies
1408         double *trans_entry = new double[mat_size * 3];
1409         trans_entry[mat_size] = trans_entry[mat_size+1] = 0.0;
1410         model->computeTransMatrix(time, trans_entry, mixture);
1411         ass_it = insert(value_type(round(time * 1e6), trans_entry)).first;
1412     } else {
1413         //if (verbose_mode >= VB_MAX)
1414             //cout << "ModelFactory bingo" << endl;
1415     }
1416 
1417     memcpy(trans_matrix, ass_it->second, mat_size * sizeof(double));
1418 }
1419 
computeTransDerv(double time,double * trans_matrix,double * trans_derv1,double * trans_derv2,int mixture)1420 void ModelFactory::computeTransDerv(double time, double *trans_matrix,
1421     double *trans_derv1, double *trans_derv2, int mixture) {
1422     if (!store_trans_matrix || !is_storing || model->isSiteSpecificModel()) {
1423         model->computeTransDerv(time, trans_matrix, trans_derv1, trans_derv2, mixture);
1424         return;
1425     }
1426     int mat_size = model->num_states * model->num_states;
1427     iterator ass_it = find(round(time * 1e6));
1428     if (ass_it == end()) {
1429         // allocate memory for 3 matricies
1430         double *trans_entry = new double[mat_size * 3];
1431         trans_entry[mat_size] = trans_entry[mat_size+1] = 0.0;
1432         model->computeTransDerv(time, trans_entry, trans_entry+mat_size, trans_entry+(mat_size*2), mixture);
1433         ass_it = insert(value_type(round(time * 1e6), trans_entry)).first;
1434     } else if (ass_it->second[mat_size] == 0.0 && ass_it->second[mat_size+1] == 0.0) {
1435         double *trans_entry = ass_it->second;
1436         model->computeTransDerv(time, trans_entry, trans_entry+mat_size, trans_entry+(mat_size*2), mixture);
1437     }
1438     memcpy(trans_matrix, ass_it->second, mat_size * sizeof(double));
1439     memcpy(trans_derv1, ass_it->second + mat_size, mat_size * sizeof(double));
1440     memcpy(trans_derv2, ass_it->second + (mat_size*2), mat_size * sizeof(double));
1441 }
1442 
~ModelFactory()1443 ModelFactory::~ModelFactory()
1444 {
1445     for (iterator it = begin(); it != end(); it++)
1446         delete it->second;
1447     clear();
1448 }
1449 
/************* THE FOLLOWING FUNCTIONS SERVE THE JOINT OPTIMIZATION OF MODEL AND RATE PARAMETERS *******/
getNDim()1451 int ModelFactory::getNDim()
1452 {
1453     return model->getNDim() + site_rate->getNDim();
1454 }
1455 
targetFunk(double x[])1456 double ModelFactory::targetFunk(double x[]) {
1457     model->getVariables(x);
1458     // need to compute rates again if p_inv or Gamma shape changes!
1459     if (model->state_freq[model->num_states-1] < MIN_RATE) return 1.0e+12;
1460     model->decomposeRateMatrix();
1461     site_rate->phylo_tree->clearAllPartialLH();
1462     return site_rate->targetFunk(x + model->getNDim());
1463 }
1464 
setVariables(double * variables)1465 void ModelFactory::setVariables(double *variables) {
1466     model->setVariables(variables);
1467     site_rate->setVariables(variables + model->getNDim());
1468 }
1469 
getVariables(double * variables)1470 bool ModelFactory::getVariables(double *variables) {
1471     bool changed = model->getVariables(variables);
1472     changed |= site_rate->getVariables(variables + model->getNDim());
1473     return changed;
1474 }
1475