1 /***************************************************************************
2 * Copyright (C) 2009 by BUI Quang Minh *
3 * minh.bui@univie.ac.at *
4 * *
5 * This program is free software; you can redistribute it and/or modify *
6 * it under the terms of the GNU General Public License as published by *
7 * the Free Software Foundation; either version 2 of the License, or *
8 * (at your option) any later version. *
9 * *
10 * This program is distributed in the hope that it will be useful, *
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13 * GNU General Public License for more details. *
14 * *
15 * You should have received a copy of the GNU General Public License *
16 * along with this program; if not, write to the *
17 * Free Software Foundation, Inc., *
18 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19 ***************************************************************************/
20 #include "rateinvar.h"
21 #include "modelfactory.h"
22 #include "rategamma.h"
23 #include "rategammainvar.h"
24 #include "modelmarkov.h"
25 #include "modelliemarkov.h"
26 #include "modeldna.h"
27 #include "modelprotein.h"
28 #include "modelbin.h"
29 #include "modelcodon.h"
30 #include "modelmorphology.h"
31 #include "modelpomo.h"
32 #include "modelset.h"
33 #include "modelmixture.h"
34 #include "ratemeyerhaeseler.h"
35 #include "ratemeyerdiscrete.h"
36 #include "ratekategory.h"
37 #include "ratefree.h"
38 #include "ratefreeinvar.h"
39 #include "rateheterotachy.h"
40 #include "rateheterotachyinvar.h"
41 //#include "ngs.h"
42 #include <string>
43 #include "utils/timeutil.h"
44 #include "nclextra/myreader.h"
45 #include <sstream>
46
/**
 * Find the first occurrence of a two-character tag in @p name whose
 * following character (the one at offset +2 from the match) is not a letter.
 *
 * A match at the very end of the string (no following character) is accepted.
 * Matches followed by a letter are skipped and the search continues one
 * position further, so that e.g. "+R" is found in "+R4" but not in "+RY".
 *
 * @param name string to scan
 * @param tag  tag to look for; callers in this file pass two-character tags,
 *             and the look-ahead offset of 2 relies on that
 * @return position of the first accepted match, or string::npos if none
 */
static string::size_type findTagNotFollowedByLetter(const string &name, const string &tag) {
    string::size_type pos = name.find(tag);
    while (pos != string::npos) {
        const string::size_type next = pos + 2;
        // isalpha() has undefined behaviour for negative arguments, so the
        // (possibly signed) char must be converted to unsigned char first.
        if (next >= name.length() || !isalpha(static_cast<unsigned char>(name[next])))
            return pos;
        pos = name.find(tag, pos + 1);
    }
    return string::npos;
}

/**
 * Locate the earliest position in @p name at which either @p sub1 or @p sub2
 * occurs without being directly followed by a letter.
 *
 * @param name string to scan (not modified)
 * @param sub1 first two-character tag, e.g. "+R"
 * @param sub2 second two-character tag, e.g. "*R"
 * @return the smaller of the two match positions; if only one tag matches,
 *         its position; string::npos if neither matches
 */
string::size_type findSubStr(string &name, string sub1, string sub2) {
    const string::size_type pos1 = findTagNotFollowedByLetter(name, sub1);
    const string::size_type pos2 = findTagNotFollowedByLetter(name, sub2);
    // string::npos is the largest size_type value, so min() yields the
    // earlier real match when there is one and npos when there is none.
    return min(pos1, pos2);
}
73
posRateHeterotachy(string model_name)74 string::size_type posRateHeterotachy(string model_name) {
75 return findSubStr(model_name, "+H", "*H");
76 }
77
posRateFree(string & model_name)78 string::size_type posRateFree(string &model_name) {
79 return findSubStr(model_name, "+R", "*R");
80 }
81
posPOMO(string & model_name)82 string::size_type posPOMO(string &model_name) {
83 return findSubStr(model_name, "+P", "*P");
84 }
85
readModelsDefinition(Params & params)86 ModelsBlock *readModelsDefinition(Params ¶ms) {
87
88 ModelsBlock *models_block = new ModelsBlock;
89
90 try
91 {
92 // loading internal model definitions
93 stringstream in(builtin_mixmodels_definition);
94 ASSERT(in && "stringstream is OK");
95 NxsReader nexus;
96 nexus.Add(models_block);
97 MyToken token(in);
98 nexus.Execute(token);
99 } catch (...) {
100 ASSERT(0 && "predefined mixture models not initialized");
101 }
102
103 try
104 {
105 // loading internal protei model definitions
106 stringstream in(builtin_prot_models);
107 ASSERT(in && "stringstream is OK");
108 NxsReader nexus;
109 nexus.Add(models_block);
110 MyToken token(in);
111 nexus.Execute(token);
112 } catch (...) {
113 ASSERT(0 && "predefined protein models not initialized");
114 }
115
116 if (params.model_def_file) {
117 cout << "Reading model definition file " << params.model_def_file << " ... ";
118 MyReader nexus(params.model_def_file);
119 nexus.Add(models_block);
120 MyToken token(nexus.inf);
121 nexus.Execute(token);
122 int num_model = 0, num_freq = 0;
123 for (ModelsBlock::iterator it = models_block->begin(); it != models_block->end(); it++)
124 if (it->second.flag & NM_FREQ) num_freq++; else num_model++;
125 cout << num_model << " models and " << num_freq << " frequency vectors loaded" << endl;
126 }
127 return models_block;
128 }
129
ModelFactory()130 ModelFactory::ModelFactory() : CheckpointFactory() {
131 model = NULL;
132 site_rate = NULL;
133 store_trans_matrix = false;
134 is_storing = false;
135 joint_optimize = false;
136 fused_mix_rate = false;
137 ASC_type = ASC_NONE;
138 }
139
/**
 * Find the '}' that closes the bracket opened at @p start_pos.
 *
 * Scanning starts at the character after @p start_pos. Nested '{' ... '}'
 * pairs are skipped by tracking the nesting depth.
 *
 * @param str       string to scan (not modified)
 * @param start_pos index of the opening '{' whose partner is wanted
 * @return index of the matching '}', or string::npos if there is none
 */
size_t findCloseBracket(string &str, size_t start_pos) {
    int depth = 0;
    size_t idx = start_pos + 1;
    while (idx < str.length()) {
        const char ch = str[idx];
        if (ch == '{') {
            // entering a nested bracket
            ++depth;
        } else if (ch == '}') {
            // at depth 0 this is the bracket we are looking for
            if (depth == 0)
                return idx;
            --depth;
        }
        ++idx;
    }
    return string::npos;
}
150
ModelFactory(Params & params,string & model_name,PhyloTree * tree,ModelsBlock * models_block)151 ModelFactory::ModelFactory(Params ¶ms, string &model_name, PhyloTree *tree, ModelsBlock *models_block) : CheckpointFactory() {
152 store_trans_matrix = params.store_trans_matrix;
153 is_storing = false;
154 joint_optimize = params.optimize_model_rate_joint;
155 fused_mix_rate = false;
156 ASC_type = ASC_NONE;
157 string model_str = model_name;
158 string rate_str;
159
160 try {
161
162
163 if (model_str == "") {
164 if (tree->aln->seq_type == SEQ_DNA) model_str = "HKY";
165 else if (tree->aln->seq_type == SEQ_PROTEIN) model_str = "LG";
166 else if (tree->aln->seq_type == SEQ_BINARY) model_str = "GTR2";
167 else if (tree->aln->seq_type == SEQ_CODON) model_str = "GY";
168 else if (tree->aln->seq_type == SEQ_MORPH) model_str = "MK";
169 else if (tree->aln->seq_type == SEQ_POMO) model_str = "HKY+P";
170 else model_str = "JC";
171 if (tree->aln->seq_type != SEQ_POMO && !params.model_joint)
172 outWarning("Default model "+model_str + " may be under-fitting. Use option '-m TEST' to determine the best-fit model.");
173 }
174
175 /********* preprocessing model string ****************/
176 NxsModel *nxsmodel = NULL;
177
178 string new_model_str = "";
179 size_t mix_pos;
180 for (mix_pos = 0; mix_pos < model_str.length(); mix_pos++) {
181 size_t next_mix_pos = model_str.find_first_of("+*", mix_pos);
182 string sub_model_str = model_str.substr(mix_pos, next_mix_pos-mix_pos);
183 nxsmodel = models_block->findMixModel(sub_model_str);
184 if (nxsmodel) sub_model_str = nxsmodel->description;
185 new_model_str += sub_model_str;
186 if (next_mix_pos != string::npos)
187 new_model_str += model_str[next_mix_pos];
188 else
189 break;
190 mix_pos = next_mix_pos;
191 }
192 if (new_model_str != model_str)
193 cout << "Model " << model_str << " is alias for " << new_model_str << endl;
194 model_str = new_model_str;
195
196 // nxsmodel = models_block->findModel(model_str);
197 // if (nxsmodel && nxsmodel->description.find_first_of("+*") != string::npos) {
198 // cout << "Model " << model_str << " is alias for " << nxsmodel->description << endl;
199 // model_str = nxsmodel->description;
200 // }
201
202 // Detect PoMo and throw error if sequence type is PoMo but +P is
203 // not given. This makes the model string cleaner and
204 // compareable.
205 string::size_type p_pos = posPOMO(model_str);
206 bool pomo = (p_pos != string::npos);
207
208 if ((p_pos == string::npos) &&
209 (tree->aln->seq_type == SEQ_POMO))
210 outError("Provided alignment is exclusively used by PoMo but model string does not contain, e.g., \"+P\".");
211
212 // Decompose model string into model_str and rate_str string.
213 size_t spec_pos = model_str.find_first_of("{+*");
214 if (spec_pos != string::npos) {
215 if (model_str[spec_pos] == '{') {
216 // Scan for the corresponding '}'.
217 size_t pos = findCloseBracket(model_str, spec_pos);
218 if (pos == string::npos)
219 outError("Model name has wrong bracket notation '{...}'");
220 rate_str = model_str.substr(pos+1);
221 model_str = model_str.substr(0, pos+1);
222 } else {
223 rate_str = model_str.substr(spec_pos);
224 model_str = model_str.substr(0, spec_pos);
225 }
226 }
227
228 // decompose +F from rate_str
229 string freq_str = "";
230 while ((spec_pos = rate_str.find("+F")) != string::npos) {
231 size_t end_pos = rate_str.find_first_of("+*", spec_pos+1);
232 if (end_pos == string::npos) {
233 freq_str += rate_str.substr(spec_pos);
234 rate_str = rate_str.substr(0, spec_pos);
235 } else {
236 freq_str += rate_str.substr(spec_pos, end_pos - spec_pos);
237 rate_str = rate_str.substr(0, spec_pos) + rate_str.substr(end_pos);
238 }
239 }
240
241 // set to model_joint if set
242 if (Params::getInstance().model_joint) {
243 model_str = Params::getInstance().model_joint;
244 freq_str = "";
245 while ((spec_pos = model_str.find("+F")) != string::npos) {
246 size_t end_pos = model_str.find_first_of("+*", spec_pos+1);
247 if (end_pos == string::npos) {
248 freq_str += model_str.substr(spec_pos);
249 model_str = model_str.substr(0, spec_pos);
250 } else {
251 freq_str += model_str.substr(spec_pos, end_pos - spec_pos);
252 model_str = model_str.substr(0, spec_pos) + model_str.substr(end_pos);
253 }
254 }
255 }
256
257 // PoMo; +NXX and +W or +S because those flags are handled when
258 // reading in the data. Set PoMo parameters (heterozygosity).
259 size_t n_pos_start = rate_str.find("+N");
260 size_t n_pos_end = rate_str.find_first_of("+", n_pos_start+1);
261 if (n_pos_start != string::npos) {
262 if (!pomo)
263 outError("Virtual population size can only be set with PoMo.");
264 if (n_pos_end != string::npos)
265 rate_str = rate_str.substr(0, n_pos_start)
266 + rate_str.substr(n_pos_end);
267 else
268 rate_str = rate_str.substr(0, n_pos_start);
269 }
270
271 size_t wb_pos = rate_str.find("+WB");
272 if (wb_pos != string::npos) {
273 if (!pomo)
274 outError("Weighted binomial sampling can only be used with PoMo.");
275 rate_str = rate_str.substr(0, wb_pos)
276 + rate_str.substr(wb_pos+3);
277 }
278 size_t wh_pos = rate_str.find("+WH");
279 if (wh_pos != string::npos) {
280 if (!pomo)
281 outError("Weighted hypergeometric sampling can only be used with PoMo.");
282 rate_str = rate_str.substr(0, wh_pos)
283 + rate_str.substr(wh_pos+3);
284 }
285 size_t s_pos = rate_str.find("+S");
286 if ( s_pos != string::npos) {
287 if (!pomo)
288 outError("Binomial sampling can only be used with PoMo.");
289 rate_str = rate_str.substr(0, s_pos)
290 + rate_str.substr(s_pos+2);
291 }
292
293 // In case of PoMo, check that only supported flags are given.
294 if (pomo) {
295 if (rate_str.find("+ASC") != string::npos)
296 // TODO DS: This is an important feature, because then,
297 // PoMo can be applied to SNP data only.
298 outError("PoMo does not yet support ascertainment bias correction (+ASC).");
299 if (posRateFree(rate_str) != string::npos)
300 outError("PoMo does not yet support free rate models (+R).");
301 if (rate_str.find("+FMIX") != string::npos)
302 outError("PoMo does not yet support frequency mixture models (+FMIX).");
303 if (posRateHeterotachy(rate_str) != string::npos)
304 outError("PoMo does not yet support heterotachy models (+H).");
305 }
306
307 // PoMo. The +P{}, +GXX and +I flags are interpreted during model creation.
308 // This is necessary for compatibility with mixture models. If there is no
309 // mixture model, move +P{}, +GXX and +I flags to model string. For mixture
310 // models, the heterozygosity can be set separately for each model and the
311 // +P{}, +GXX and +I flags should already be inside the model definition.
312 if (model_str.substr(0, 3) != "MIX" && pomo) {
313 // +P{} flag.
314 p_pos = posPOMO(rate_str);
315 if (p_pos != string::npos) {
316 if (rate_str[p_pos+2] == '{') {
317 string::size_type close_bracket = rate_str.find("}");
318 if (close_bracket == string::npos)
319 outError("No closing bracket in PoMo parameters.");
320 else {
321 string pomo_heterozygosity = rate_str.substr(p_pos+3,close_bracket-p_pos-3);
322 rate_str = rate_str.substr(0, p_pos) + rate_str.substr(close_bracket+1);
323 model_str += "+P{" + pomo_heterozygosity + "}";
324 }
325 }
326 else {
327 rate_str = rate_str.substr(0, p_pos) + rate_str.substr(p_pos + 2);
328 model_str += "+P";
329 }
330 }
331
332 // +G flag.
333 size_t pomo_rate_start_pos;
334 if ((pomo_rate_start_pos = rate_str.find("+G")) != string::npos) {
335 string pomo_rate_str = "";
336 size_t pomo_rate_end_pos = rate_str.find_first_of("+*", pomo_rate_start_pos+1);
337 if (pomo_rate_end_pos == string::npos) {
338 pomo_rate_str = rate_str.substr(pomo_rate_start_pos, rate_str.length() - pomo_rate_start_pos);
339 rate_str = rate_str.substr(0, pomo_rate_start_pos);
340 model_str += pomo_rate_str;
341 } else {
342 pomo_rate_str = rate_str.substr(pomo_rate_start_pos, pomo_rate_end_pos - pomo_rate_start_pos);
343 rate_str = rate_str.substr(0, pomo_rate_start_pos) + rate_str.substr(pomo_rate_end_pos);
344 model_str += pomo_rate_str;
345 }
346 }
347
348 // // +I flag.
349 // size_t pomo_irate_start_pos;
350 // if ((pomo_irate_start_pos = rate_str.find("+I")) != string::npos) {
351 // string pomo_irate_str = "";
352 // size_t pomo_irate_end_pos = rate_str.find_first_of("+*", pomo_irate_start_pos+1);
353 // if (pomo_irate_end_pos == string::npos) {
354 // pomo_irate_str = rate_str.substr(pomo_irate_start_pos, rate_str.length() - pomo_irate_start_pos);
355 // rate_str = rate_str.substr(0, pomo_irate_start_pos);
356 // model_str += pomo_irate_str;
357 // } else {
358 // pomo_irate_str = rate_str.substr(pomo_irate_start_pos, pomo_irate_end_pos - pomo_irate_start_pos);
359 // rate_str = rate_str.substr(0, pomo_irate_start_pos) + rate_str.substr(pomo_irate_end_pos);
360 // model_str += pomo_irate_str;
361 // }
362 }
363
364 // nxsmodel = models_block->findModel(model_str);
365 // if (nxsmodel && nxsmodel->description.find("MIX") != string::npos) {
366 // cout << "Model " << model_str << " is alias for " << nxsmodel->description << endl;
367 // model_str = nxsmodel->description;
368 // }
369
370 /******************** initialize state frequency ****************************/
371
372 StateFreqType freq_type = params.freq_type;
373
374 if (freq_type == FREQ_UNKNOWN) {
375 switch (tree->aln->seq_type) {
376 case SEQ_BINARY: freq_type = FREQ_ESTIMATE; break; // default for binary: optimized frequencies
377 case SEQ_PROTEIN: break; // let ModelProtein decide by itself
378 case SEQ_MORPH: freq_type = FREQ_EQUAL; break;
379 case SEQ_CODON: freq_type = FREQ_UNKNOWN; break;
380 break;
381 default: freq_type = FREQ_EMPIRICAL; break; // default for DNA, PoMo and others: counted frequencies from alignment
382 }
383 }
384
385 // first handle mixture frequency
386 string::size_type posfreq = freq_str.find("+FMIX");
387 string freq_params;
388 size_t close_bracket;
389
390 if (posfreq != string::npos) {
391 string fmix_str;
392 size_t last_pos = freq_str.find_first_of("+*", posfreq+1);
393
394 if (last_pos == string::npos) {
395 fmix_str = freq_str.substr(posfreq);
396 freq_str = freq_str.substr(0, posfreq);
397 } else {
398 fmix_str = freq_str.substr(posfreq, last_pos-posfreq);
399 freq_str = freq_str.substr(0, posfreq) + freq_str.substr(last_pos);
400 }
401
402 if (fmix_str[5] != OPEN_BRACKET)
403 outError("Mixture-frequency must start with +FMIX{");
404 close_bracket = fmix_str.find(CLOSE_BRACKET);
405 if (close_bracket == string::npos)
406 outError("Close bracket not found in ", fmix_str);
407 if (close_bracket != fmix_str.length()-1)
408 outError("Wrong close bracket position ", fmix_str);
409 freq_type = FREQ_MIXTURE;
410 freq_params = fmix_str.substr(6, close_bracket-6);
411 }
412
413 // then normal frequency
414 if (freq_str.find("+FO") != string::npos)
415 posfreq = freq_str.find("+FO");
416 else if (freq_str.find("+Fo") != string::npos)
417 posfreq = freq_str.find("+Fo");
418 else
419 posfreq = freq_str.find("+F");
420
421 bool optimize_mixmodel_weight = params.optimize_mixmodel_weight;
422
423 if (posfreq != string::npos) {
424 string fstr;
425 size_t last_pos = freq_str.find_first_of("+*", posfreq+1);
426 if (last_pos == string::npos) {
427 fstr = freq_str.substr(posfreq);
428 freq_str = freq_str.substr(0, posfreq);
429 } else {
430 fstr = freq_str.substr(posfreq, last_pos-posfreq);
431 freq_str = freq_str.substr(0, posfreq) + freq_str.substr(last_pos);
432 }
433
434 if (fstr.length() > 2 && fstr[2] == OPEN_BRACKET) {
435 if (freq_type == FREQ_MIXTURE)
436 outError("Mixture frequency with user-defined frequency is not allowed");
437 close_bracket = fstr.find(CLOSE_BRACKET);
438 if (close_bracket == string::npos)
439 outError("Close bracket not found in ", fstr);
440 if (close_bracket != fstr.length()-1)
441 outError("Wrong close bracket position ", fstr);
442 freq_type = FREQ_USER_DEFINED;
443 freq_params = fstr.substr(3, close_bracket-3);
444 } else if (fstr == "+FC" || fstr == "+Fc" || fstr == "+F") {
445 if (freq_type == FREQ_MIXTURE) {
446 freq_params = "empirical," + freq_params;
447 optimize_mixmodel_weight = true;
448 } else
449 freq_type = FREQ_EMPIRICAL;
450 } else if (fstr == "+FU" || fstr == "+Fu") {
451 if (freq_type == FREQ_MIXTURE)
452 outError("Mixture frequency with user-defined frequency is not allowed");
453 else
454 freq_type = FREQ_USER_DEFINED;
455 } else if (fstr == "+FQ" || fstr == "+Fq") {
456 if (freq_type == FREQ_MIXTURE)
457 outError("Mixture frequency with equal frequency is not allowed");
458 else
459 freq_type = FREQ_EQUAL;
460 } else if (fstr == "+FO" || fstr == "+Fo") {
461 if (freq_type == FREQ_MIXTURE) {
462 freq_params = "optimize," + freq_params;
463 optimize_mixmodel_weight = true;
464 } else
465 freq_type = FREQ_ESTIMATE;
466 } else if (fstr == "+F1x4" || fstr == "+F1X4") {
467 if (freq_type == FREQ_MIXTURE)
468 outError("Mixture frequency with " + fstr + " is not allowed");
469 else
470 freq_type = FREQ_CODON_1x4;
471 } else if (fstr == "+F3x4" || fstr == "+F3X4") {
472 if (freq_type == FREQ_MIXTURE)
473 outError("Mixture frequency with " + fstr + " is not allowed");
474 else
475 freq_type = FREQ_CODON_3x4;
476 } else if (fstr == "+F3x4C" || fstr == "+F3x4c" || fstr == "+F3X4C" || fstr == "+F3X4c") {
477 if (freq_type == FREQ_MIXTURE)
478 outError("Mixture frequency with " + fstr + " is not allowed");
479 else
480 freq_type = FREQ_CODON_3x4C;
481 } else if (fstr == "+FRY") {
482 // MDW to Minh: I don't know how these should interact with FREQ_MIXTURE,
483 // so as nearly everything else treats it as an error, I do too.
484 // BQM answer: that's fine
485 if (freq_type == FREQ_MIXTURE)
486 outError("Mixture frequency with " + fstr + " is not allowed");
487 else
488 freq_type = FREQ_DNA_RY;
489 } else if (fstr == "+FWS") {
490 if (freq_type == FREQ_MIXTURE)
491 outError("Mixture frequency with " + fstr + " is not allowed");
492 else
493 freq_type = FREQ_DNA_WS;
494 } else if (fstr == "+FMK") {
495 if (freq_type == FREQ_MIXTURE)
496 outError("Mixture frequency with " + fstr + " is not allowed");
497 else
498 freq_type = FREQ_DNA_MK;
499 } else {
500 // might be "+F####" where # are digits
501 try {
502 freq_type = parseStateFreqDigits(fstr.substr(2)); // throws an error if not in +F#### format
503 } catch (...) {
504 outError("Unknown state frequency type ",fstr);
505 }
506 }
507 // model_str = model_str.substr(0, posfreq);
508 }
509
510 /******************** initialize model ****************************/
511
512 if (tree->aln->site_state_freq.empty()) {
513 if (model_str.substr(0, 3) == "MIX" || freq_type == FREQ_MIXTURE) {
514 string model_list;
515 if (model_str.substr(0, 3) == "MIX") {
516 if (model_str[3] != OPEN_BRACKET)
517 outError("Mixture model name must start with 'MIX{'");
518 if (model_str.rfind(CLOSE_BRACKET) != model_str.length()-1)
519 outError("Close bracket not found at the end of ", model_str);
520 model_list = model_str.substr(4, model_str.length()-5);
521 model_str = model_str.substr(0, 3);
522 }
523 model = new ModelMixture(model_name, model_str, model_list, models_block, freq_type, freq_params, tree, optimize_mixmodel_weight);
524 } else {
525 // string model_desc;
526 // NxsModel *nxsmodel = models_block->findModel(model_str);
527 // if (nxsmodel) model_desc = nxsmodel->description;
528 model = createModel(model_str, models_block, freq_type, freq_params, tree);
529 }
530 // fused_mix_rate &= model->isMixture() && site_rate->getNRate() > 1;
531 } else {
532 // site-specific model
533 if (model_str == "JC" || model_str == "POISSON")
534 outError("JC is not suitable for site-specific model");
535 model = new ModelSet(model_str.c_str(), tree);
536 ModelSet *models = (ModelSet*)model; // assign pointer for convenience
537 models->init((params.freq_type != FREQ_UNKNOWN) ? params.freq_type : FREQ_EMPIRICAL);
538 int i;
539 models->pattern_model_map.resize(tree->aln->getNPattern(), -1);
540 for (i = 0; i < tree->aln->getNSite(); i++) {
541 models->pattern_model_map[tree->aln->getPatternID(i)] = tree->aln->site_model[i];
542 //cout << "site " << i << " ptn " << tree->aln->getPatternID(i) << " -> model " << site_model[i] << endl;
543 }
544 double *state_freq = new double[model->num_states];
545 double *rates = new double[model->getNumRateEntries()];
546 for (i = 0; i < tree->aln->site_state_freq.size(); i++) {
547 ModelMarkov *modeli;
548 if (i == 0) {
549 modeli = (ModelMarkov*)createModel(model_str, models_block, (params.freq_type != FREQ_UNKNOWN) ? params.freq_type : FREQ_EMPIRICAL, "", tree);
550 modeli->getStateFrequency(state_freq);
551 modeli->getRateMatrix(rates);
552 } else {
553 modeli = (ModelMarkov*)createModel(model_str, models_block, FREQ_EQUAL, "", tree);
554 modeli->setStateFrequency(state_freq);
555 modeli->setRateMatrix(rates);
556 }
557 if (tree->aln->site_state_freq[i])
558 modeli->setStateFrequency (tree->aln->site_state_freq[i]);
559
560 modeli->init(FREQ_USER_DEFINED);
561 models->push_back(modeli);
562 }
563 delete [] rates;
564 delete [] state_freq;
565
566 models->joinEigenMemory();
567 models->decomposeRateMatrix();
568
569 // delete information of the old alignment
570 // tree->aln->ordered_pattern.clear();
571 // tree->deleteAllPartialLh();
572 }
573
574 // if (model->isMixture())
575 // cout << "Mixture model with " << model->getNMixtures() << " components!" << endl;
576
577 /******************** initialize ascertainment bias correction model ****************************/
578
579 string::size_type posasc;
580
581 if ((posasc = rate_str.find("+ASC_INF")) != string::npos) {
582 // ascertainment bias correction
583 ASC_type = ASC_INFORMATIVE;
584 tree->aln->getUnobservedConstPatterns(ASC_type, unobserved_ptns);
585
586 // rebuild the seq_states to contain states of unobserved constant patterns
587 //tree->aln->buildSeqStates(model->seq_states, true);
588 if (tree->aln->num_informative_sites != tree->getAlnNSite()) {
589 if (!params.partition_file) {
590 string infsites_file = ((string)params.out_prefix + ".infsites.phy");
591 tree->aln->printAlignment(params.aln_output_format, infsites_file.c_str(), false, NULL, EXCLUDE_UNINF);
592 cerr << "For your convenience alignment with parsimony-informative sites printed to " << infsites_file << endl;
593 }
594 outError("Invalid use of +ASC_INF because of " + convertIntToString(tree->getAlnNSite() - tree->aln->num_informative_sites) +
595 " parsimony-uninformative sites in the alignment");
596 }
597 if (verbose_mode >= VB_MED)
598 cout << "Ascertainment bias correction: " << unobserved_ptns.size() << " unobservable uninformative patterns"<< endl;
599 rate_str = rate_str.substr(0, posasc) + rate_str.substr(posasc+8);
600 } else if ((posasc = rate_str.find("+ASC_MIS")) != string::npos) {
601 // initialize Holder's ascertainment bias correction model
602 ASC_type = ASC_VARIANT_MISSING;
603 tree->aln->getUnobservedConstPatterns(ASC_type, unobserved_ptns);
604 // rebuild the seq_states to contain states of unobserved constant patterns
605 //tree->aln->buildSeqStates(model->seq_states, true);
606 if (tree->aln->frac_invariant_sites > 0) {
607 if (!params.partition_file) {
608 string varsites_file = ((string)params.out_prefix + ".varsites.phy");
609 tree->aln->printAlignment(params.aln_output_format, varsites_file.c_str(), false, NULL, EXCLUDE_INVAR);
610 cerr << "For your convenience alignment with variable sites printed to " << varsites_file << endl;
611 }
612 outError("Invalid use of +ASC_MIS because of " + convertIntToString(tree->aln->frac_invariant_sites*tree->aln->getNSite()) +
613 " invariant sites in the alignment");
614 }
615 if (verbose_mode >= VB_MED)
616 cout << "Holder's ascertainment bias correction: " << unobserved_ptns.size() << " unobservable constant patterns" << endl;
617 rate_str = rate_str.substr(0, posasc) + rate_str.substr(posasc+8);
618 } else if ((posasc = rate_str.find("+ASC")) != string::npos) {
619 // ascertainment bias correction
620 ASC_type = ASC_VARIANT;
621 tree->aln->getUnobservedConstPatterns(ASC_type, unobserved_ptns);
622
623 // delete rarely observed state
624 for (int i = unobserved_ptns.size()-1; i >= 0; i--)
625 if (model->state_freq[(int)unobserved_ptns[i][0]] < 1e-8)
626 unobserved_ptns.erase(unobserved_ptns.begin() + i);
627
628 // rebuild the seq_states to contain states of unobserved constant patterns
629 //tree->aln->buildSeqStates(model->seq_states, true);
630 // if (unobserved_ptns.size() <= 0)
631 // outError("Invalid use of +ASC because all constant patterns are observed in the alignment");
632 if (tree->aln->frac_invariant_sites > 0) {
633 // cerr << tree->aln->frac_invariant_sites*tree->aln->getNSite() << " invariant sites observed in the alignment" << endl;
634 // for (Alignment::iterator pit = tree->aln->begin(); pit != tree->aln->end(); pit++)
635 // if (pit->isInvariant()) {
636 // string pat_str = "";
637 // for (Pattern::iterator it = pit->begin(); it != pit->end(); it++)
638 // pat_str += tree->aln->convertStateBackStr(*it);
639 // cerr << pat_str << " is invariant site pattern" << endl;
640 // }
641 if (!params.partition_file) {
642 string varsites_file = ((string)params.out_prefix + ".varsites.phy");
643 tree->aln->printAlignment(params.aln_output_format, varsites_file.c_str(), false, NULL, EXCLUDE_INVAR);
644 cerr << "For your convenience alignment with variable sites printed to " << varsites_file << endl;
645 }
646 outError("Invalid use of +ASC because of " + convertIntToString(tree->aln->frac_invariant_sites*tree->aln->getNSite()) +
647 " invariant sites in the alignment");
648 }
649 if (verbose_mode >= VB_MED)
650 cout << "Ascertainment bias correction: " << unobserved_ptns.size() << " unobservable constant patterns"<< endl;
651 rate_str = rate_str.substr(0, posasc) + rate_str.substr(posasc+4);
652 } else {
653 //tree->aln->buildSeqStates(model->seq_states, false);
654 }
655
656 /******************** initialize site rate heterogeneity ****************************/
657
658 string::size_type posI = rate_str.find("+I");
659 string::size_type posG = rate_str.find("+G");
660 string::size_type posG2 = rate_str.find("*G");
661 if (posG != string::npos && posG2 != string::npos) {
662 cout << "NOTE: both +G and *G were specified, continue with "
663 << ((posG < posG2)? rate_str.substr(posG,2) : rate_str.substr(posG2,2)) << endl;
664 }
665 if (posG2 != string::npos && posG2 < posG) {
666 posG = posG2;
667 fused_mix_rate = true;
668 }
669
670 string::size_type posR = rate_str.find("+R"); // FreeRate model
671 string::size_type posR2 = rate_str.find("*R"); // FreeRate model
672
673 if (posG != string::npos && (posR != string::npos || posR2 != string::npos)) {
674 outWarning("Both Gamma and FreeRate models were specified, continue with FreeRate model");
675 posG = string::npos;
676 fused_mix_rate = false;
677 }
678
679 if (posR != string::npos && posR2 != string::npos) {
680 cout << "NOTE: both +R and *R were specified, continue with "
681 << ((posR < posR2)? rate_str.substr(posR,2) : rate_str.substr(posR2,2)) << endl;
682 }
683
684 if (posR2 != string::npos && posR2 < posR) {
685 posR = posR2;
686 fused_mix_rate = true;
687 }
688
689 string::size_type posH = rate_str.find("+H"); // heterotachy model
690 string::size_type posH2 = rate_str.find("*H"); // heterotachy model
691
692 if (posG != string::npos && (posH != string::npos || posH2 != string::npos)) {
693 outWarning("Both Gamma and heterotachy models were specified, continue with heterotachy model");
694 posG = string::npos;
695 fused_mix_rate = false;
696 }
697
698 if (posR != string::npos && (posH != string::npos || posH2 != string::npos)) {
699 outWarning("Both FreeRate and heterotachy models were specified, continue with heterotachy model");
700 posR = string::npos;
701 fused_mix_rate = false;
702 }
703
704 if (posH != string::npos && posH2 != string::npos) {
705 cout << "NOTE: both +H and *H were specified, continue with "
706 << ((posH < posH2)? rate_str.substr(posH,2) : rate_str.substr(posH2,2)) << endl;
707 }
708 if (posH2 != string::npos && posH2 < posH) {
709 posH = posH2;
710 fused_mix_rate = true;
711 }
712
713 string::size_type posX;
714 /* create site-rate heterogeneity */
715 int num_rate_cats = params.num_rate_cats;
716 if (fused_mix_rate && model->isMixture()) num_rate_cats = model->getNMixtures();
717 double gamma_shape = params.gamma_shape;
718 double p_invar_sites = params.p_invar_sites;
719 string freerate_params = "";
720 if (posI != string::npos) {
721 // invariable site model
722 if (rate_str.length() > posI+2 && rate_str[posI+2] == OPEN_BRACKET) {
723 close_bracket = rate_str.find(CLOSE_BRACKET, posI);
724 if (close_bracket == string::npos)
725 outError("Close bracket not found in ", rate_str);
726 p_invar_sites = convert_double(rate_str.substr(posI+3, close_bracket-posI-3).c_str());
727 if (p_invar_sites < 0 || p_invar_sites >= 1)
728 outError("p_invar must be in [0,1)");
729 } else if (rate_str.length() > posI+2 && rate_str[posI+2] != '+' && rate_str[posI+2] != '*')
730 outError("Wrong model name ", rate_str);
731 }
732 if (posG != string::npos) {
733 // Gamma rate model
734 int end_pos = 0;
735 if (rate_str.length() > posG+2 && isdigit(rate_str[posG+2])) {
736 num_rate_cats = convert_int(rate_str.substr(posG+2).c_str(), end_pos);
737 if (num_rate_cats < 1) outError("Wrong number of rate categories");
738 }
739 if (rate_str.length() > posG+2+end_pos && rate_str[posG+2+end_pos] == OPEN_BRACKET) {
740 close_bracket = rate_str.find(CLOSE_BRACKET, posG);
741 if (close_bracket == string::npos)
742 outError("Close bracket not found in ", rate_str);
743 gamma_shape = convert_double(rate_str.substr(posG+3+end_pos, close_bracket-posG-3-end_pos).c_str());
744 // if (gamma_shape < MIN_GAMMA_SHAPE || gamma_shape > MAX_GAMMA_SHAPE) {
745 // stringstream str;
746 // str << "Gamma shape parameter " << gamma_shape << "out of range ["
747 // << MIN_GAMMA_SHAPE << ',' << MAX_GAMMA_SHAPE << "]" << endl;
748 // outError(str.str());
749 // }
750 } else if (rate_str.length() > posG+2+end_pos && rate_str[posG+2+end_pos] != '+')
751 outError("Wrong model name ", rate_str);
752 }
753 if (posR != string::npos) {
754 // FreeRate model
755 int end_pos = 0;
756 if (rate_str.length() > posR+2 && isdigit(rate_str[posR+2])) {
757 num_rate_cats = convert_int(rate_str.substr(posR+2).c_str(), end_pos);
758 if (num_rate_cats < 1) outError("Wrong number of rate categories");
759 }
760 if (rate_str.length() > posR+2+end_pos && rate_str[posR+2+end_pos] == OPEN_BRACKET) {
761 close_bracket = rate_str.find(CLOSE_BRACKET, posR);
762 if (close_bracket == string::npos)
763 outError("Close bracket not found in ", rate_str);
764 freerate_params = rate_str.substr(posR+3+end_pos, close_bracket-posR-3-end_pos).c_str();
765 } else if (rate_str.length() > posR+2+end_pos && rate_str[posR+2+end_pos] != '+')
766 outError("Wrong model name ", rate_str);
767 }
768
769 string heterotachy_params = "";
770 if (posH != string::npos) {
771 // Heterotachy model
772 int end_pos = 0;
773 if (rate_str.length() > posH+2 && isdigit(rate_str[posH+2])) {
774 num_rate_cats = convert_int(rate_str.substr(posH+2).c_str(), end_pos);
775 if (num_rate_cats < 1) outError("Wrong number of rate categories");
776 } else {
777 if (!model->isMixture() || !fused_mix_rate)
778 outError("Please specify number of heterotachy classes (e.g., +H2)");
779 }
780 if (rate_str.length() > posH+2+end_pos && rate_str[posH+2+end_pos] == OPEN_BRACKET) {
781 close_bracket = rate_str.find(CLOSE_BRACKET, posH);
782 if (close_bracket == string::npos)
783 outError("Close bracket not found in ", rate_str);
784 heterotachy_params = rate_str.substr(posH+3+end_pos, close_bracket-posH-3-end_pos).c_str();
785 } else if (rate_str.length() > posH+2+end_pos && rate_str[posH+2+end_pos] != '+')
786 outError("Wrong model name ", rate_str);
787 }
788
789
790 if (rate_str.find('+') != string::npos || rate_str.find('*') != string::npos) {
791 //string rate_str = model_str.substr(pos);
792 if (posI != string::npos && posH != string::npos) {
793 site_rate = new RateHeterotachyInvar(num_rate_cats, heterotachy_params, p_invar_sites, tree);
794 } else if (posH != string::npos) {
795 site_rate = new RateHeterotachy(num_rate_cats, heterotachy_params, tree);
796 } else if (posI != string::npos && posG != string::npos) {
797 site_rate = new RateGammaInvar(num_rate_cats, gamma_shape, params.gamma_median,
798 p_invar_sites, params.optimize_alg_gammai, tree, false);
799 } else if (posI != string::npos && posR != string::npos) {
800 site_rate = new RateFreeInvar(num_rate_cats, gamma_shape, freerate_params, !fused_mix_rate, p_invar_sites, params.optimize_alg, tree);
801 } else if (posI != string::npos) {
802 site_rate = new RateInvar(p_invar_sites, tree);
803 } else if (posG != string::npos) {
804 site_rate = new RateGamma(num_rate_cats, gamma_shape, params.gamma_median, tree);
805 } else if (posR != string::npos) {
806 site_rate = new RateFree(num_rate_cats, gamma_shape, freerate_params, !fused_mix_rate, params.optimize_alg, tree);
807 // } else if ((posX = rate_str.find("+M")) != string::npos) {
808 // tree->setLikelihoodKernel(LK_NORMAL);
809 // params.rate_mh_type = true;
810 // if (rate_str.length() > posX+2 && isdigit(rate_str[posX+2])) {
811 // num_rate_cats = convert_int(rate_str.substr(posX+2).c_str());
812 // if (num_rate_cats < 0) outError("Wrong number of rate categories");
813 // } else num_rate_cats = -1;
814 // if (num_rate_cats >= 0)
815 // site_rate = new RateMeyerDiscrete(num_rate_cats, params.mcat_type,
816 // params.rate_file, tree, params.rate_mh_type);
817 // else
818 // site_rate = new RateMeyerHaeseler(params.rate_file, tree, params.rate_mh_type);
819 // site_rate->setTree(tree);
820 // } else if ((posX = rate_str.find("+D")) != string::npos) {
821 // tree->setLikelihoodKernel(LK_NORMAL);
822 // params.rate_mh_type = false;
823 // if (rate_str.length() > posX+2 && isdigit(rate_str[posX+2])) {
824 // num_rate_cats = convert_int(rate_str.substr(posX+2).c_str());
825 // if (num_rate_cats < 0) outError("Wrong number of rate categories");
826 // } else num_rate_cats = -1;
827 // if (num_rate_cats >= 0)
828 // site_rate = new RateMeyerDiscrete(num_rate_cats, params.mcat_type,
829 // params.rate_file, tree, params.rate_mh_type);
830 // else
831 // site_rate = new RateMeyerHaeseler(params.rate_file, tree, params.rate_mh_type);
832 // site_rate->setTree(tree);
833 // } else if ((posX = rate_str.find("+NGS")) != string::npos) {
834 // tree->setLikelihoodKernel(LK_NORMAL);
835 // if (rate_str.length() > posX+4 && isdigit(rate_str[posX+4])) {
836 // num_rate_cats = convert_int(rate_str.substr(posX+4).c_str());
837 // if (num_rate_cats < 0) outError("Wrong number of rate categories");
838 // } else num_rate_cats = -1;
839 // site_rate = new NGSRateCat(tree, num_rate_cats);
840 // site_rate->setTree(tree);
841 // } else if ((posX = rate_str.find("+NGS")) != string::npos) {
842 // tree->setLikelihoodKernel(LK_NORMAL);
843 // if (rate_str.length() > posX+4 && isdigit(rate_str[posX+4])) {
844 // num_rate_cats = convert_int(rate_str.substr(posX+4).c_str());
845 // if (num_rate_cats < 0) outError("Wrong number of rate categories");
846 // } else num_rate_cats = -1;
847 // site_rate = new NGSRate(tree);
848 // site_rate->setTree(tree);
849 } else if ((posX = rate_str.find("+K")) != string::npos) {
850 if (rate_str.length() > posX+2 && isdigit(rate_str[posX+2])) {
851 num_rate_cats = convert_int(rate_str.substr(posX+2).c_str());
852 if (num_rate_cats < 1) outError("Wrong number of rate categories");
853 }
854 site_rate = new RateKategory(num_rate_cats, tree);
855 } else
856 outError("Invalid rate heterogeneity type");
857 // if (model_str.find('+') != string::npos)
858 // model_str = model_str.substr(0, model_str.find('+'));
859 // else
860 // model_str = model_str.substr(0, model_str.find('*'));
861 } else {
862 site_rate = new RateHeterogeneity();
863 site_rate->setTree(tree);
864 }
865
866 if (fused_mix_rate) {
867 if (!model->isMixture()) {
868 if (verbose_mode >= VB_MED)
869 cout << endl << "NOTE: Using mixture model with unlinked " << model_str << " parameters" << endl;
870 string model_list = model_str;
871 delete model;
872 for (int i = 1; i < site_rate->getNRate(); i++)
873 model_list += "," + model_str;
874 model = new ModelMixture(model_name, model_str, model_list, models_block, freq_type, freq_params, tree, optimize_mixmodel_weight);
875 }
876 if (model->getNMixtures() != site_rate->getNRate())
877 outError("Mixture model and site rate model do not have the same number of categories");
878 // if (!tree->isMixlen()) {
879 // reset mixture model
880 model->setFixMixtureWeight(true);
881 int mix, nmix = model->getNMixtures();
882 for (mix = 0; mix < nmix; mix++) {
883 ((ModelMarkov*)model->getMixtureClass(mix))->total_num_subst = 1.0;
884 model->setMixtureWeight(mix, 1.0);
885 }
886 model->decomposeRateMatrix();
887 // } else {
888 // site_rate->setFixParams(1);
889 // int c, ncat = site_rate->getNRate();
890 // for (c = 0; c < ncat; c++)
891 // site_rate->setProp(c, 1.0);
892 // }
893 }
894
895 tree->discardSaturatedSite(params.discard_saturated_site);
896
897 } catch (const char* str) {
898 outError(str);
899 }
900
901 }
902
/**
 * Attach a checkpoint object to this factory and to the objects it manages.
 * The checkpoint is first registered through CheckpointFactory::setCheckpoint
 * and then handed to the substitution model and to the site-rate model.
 * @param checkpoint checkpoint object to attach
 */
void ModelFactory::setCheckpoint(Checkpoint *checkpoint) {
    CheckpointFactory::setCheckpoint(checkpoint);
    model->setCheckpoint(checkpoint);
    site_rate->setCheckpoint(checkpoint);
}
908
/**
 * Open the checkpoint struct named "ModelFactory" on the attached checkpoint.
 */
void ModelFactory::startCheckpoint() {
    checkpoint->startStruct("ModelFactory");
}
912
/**
 * Save the state of the model and of the site-rate model, then open and close
 * the factory's own checkpoint struct and call the base-class saveCheckpoint().
 * No factory-level field is currently written: the CKP_SAVE lines are disabled.
 */
void ModelFactory::saveCheckpoint() {
    model->saveCheckpoint();
    site_rate->saveCheckpoint();
    startCheckpoint();
//    CKP_SAVE(fused_mix_rate);
//    CKP_SAVE(unobserved_ptns);
//    CKP_SAVE(joint_optimize);
    endCheckpoint();
    CheckpointFactory::saveCheckpoint();
}
923
/**
 * Restore the state of the model and of the site-rate model, then open and
 * close the factory's own checkpoint struct.
 * No factory-level field is currently read: the CKP_RESTORE lines are disabled.
 */
void ModelFactory::restoreCheckpoint() {
    model->restoreCheckpoint();
    site_rate->restoreCheckpoint();
    startCheckpoint();
//    CKP_RESTORE(fused_mix_rate);
//    CKP_RESTORE(unobserved_ptns);
//    CKP_RESTORE(joint_optimize);
    endCheckpoint();
}
933
getNParameters(int brlen_type)934 int ModelFactory::getNParameters(int brlen_type) {
935 int df = model->getNDim() + model->getNDimFreq() + site_rate->getNDim() +
936 site_rate->getTree()->getNBranchParameters(brlen_type);
937
938 return df;
939 }
940 /*
941 double ModelFactory::initGTRGammaIParameters(RateHeterogeneity *rate, ModelSubst *model, double initAlpha,
942 double initPInvar, double *initRates, double *initStateFreqs) {
943
944 RateHeterogeneity* rateGammaInvar = rate;
945 ModelMarkov* modelGTR = (ModelMarkov*)(model);
946 modelGTR->setRateMatrix(initRates);
947 modelGTR->setStateFrequency(initStateFreqs);
948 rateGammaInvar->setGammaShape(initAlpha);
949 rateGammaInvar->setPInvar(initPInvar);
950 modelGTR->decomposeRateMatrix();
951 site_rate->phylo_tree->clearAllPartialLH();
952 return site_rate->phylo_tree->computeLikelihood();
953 }
954 */
955
optimizeParametersOnly(int num_steps,double gradient_epsilon,double cur_logl)956 double ModelFactory::optimizeParametersOnly(int num_steps, double gradient_epsilon, double cur_logl) {
957 double logl;
958 /* Optimize substitution and heterogeneity rates independently */
959 if (!joint_optimize) {
960 // more steps for fused mix rate model
961 int steps;
962 if (false && fused_mix_rate && model->getNDim() > 0 && site_rate->getNDim() > 0) {
963 model->setOptimizeSteps(1);
964 site_rate->setOptimizeSteps(1);
965 steps = max(model->getNDim()+site_rate->getNDim(), num_steps) * 3;
966 } else {
967 steps = 1;
968 }
969 double prev_logl = cur_logl;
970 for (int step = 0; step < steps; step++) {
971 double model_lh = 0.0;
972 // only optimized if model is not linked
973 model_lh = model->optimizeParameters(gradient_epsilon);
974
975 double rate_lh = site_rate->optimizeParameters(gradient_epsilon);
976
977 if (rate_lh == 0.0)
978 logl = model_lh;
979 else
980 logl = rate_lh;
981 if (logl <= prev_logl + gradient_epsilon)
982 break;
983 prev_logl = logl;
984 }
985 } else {
986 /* Optimize substitution and heterogeneity rates jointly using BFGS */
987 logl = optimizeAllParameters(gradient_epsilon);
988 }
989 return logl;
990 }
991
optimizeAllParameters(double gradient_epsilon)992 double ModelFactory::optimizeAllParameters(double gradient_epsilon) {
993 int ndim = getNDim();
994
995 // return if nothing to be optimized
996 if (ndim == 0) return 0.0;
997
998 double *variables = new double[ndim+1];
999 double *upper_bound = new double[ndim+1];
1000 double *lower_bound = new double[ndim+1];
1001 bool *bound_check = new bool[ndim+1];
1002 int i;
1003 double score;
1004
1005 // setup the bounds for model
1006 setVariables(variables);
1007 int model_ndim = model->getNDim();
1008 for (i = 1; i <= model_ndim; i++) {
1009 //cout << variables[i] << endl;
1010 lower_bound[i] = MIN_RATE;
1011 upper_bound[i] = MAX_RATE;
1012 bound_check[i] = false;
1013 }
1014
1015 if (model->freq_type == FREQ_ESTIMATE) {
1016 for (i = model_ndim- model->num_states+2; i <= model_ndim; i++)
1017 upper_bound[i] = 1.0;
1018 }
1019
1020 // setup the bounds for site_rate
1021 site_rate->setBounds(lower_bound+model_ndim, upper_bound+model_ndim, bound_check+model_ndim);
1022
1023 score = -minimizeMultiDimen(variables, ndim, lower_bound, upper_bound, bound_check, max(gradient_epsilon, TOL_RATE));
1024
1025 getVariables(variables);
1026 //if (freq_type == FREQ_ESTIMATE) scaleStateFreq(true);
1027 model->decomposeRateMatrix();
1028 site_rate->phylo_tree->clearAllPartialLH();
1029
1030 score = site_rate->phylo_tree->computeLikelihood();
1031
1032 delete [] bound_check;
1033 delete [] lower_bound;
1034 delete [] upper_bound;
1035 delete [] variables;
1036
1037 return score;
1038 }
1039
/**
 * Optimize parameters for +I+G rate models by restarting the optimization from
 * several initial proportions of invariable sites (p_inv) and keeping the start
 * that gives the highest log-likelihood.
 *
 * Falls back to a single optimizeParameters() call when the site-rate model is
 * not Gamma+Invar, when p_inv or the Gamma shape is fixed, when the alignment
 * has no constant sites, or when the model is a mixture.
 *
 * @param fixed_len branch-length mode, forwarded to optimizeParameters()
 * @param write_info TRUE to print progress to cout
 * @param logl_epsilon log-likelihood tolerance, forwarded to optimizeParameters()
 * @param gradient_epsilon gradient tolerance, forwarded to optimizeParameters()
 * @return the tree's current score after the best parameter set was restored
 */
double ModelFactory::optimizeParametersGammaInvar(int fixed_len, bool write_info, double logl_epsilon, double gradient_epsilon) {
    if (!site_rate->isGammai() || site_rate->isFixPInvar() || site_rate->isFixGammaShape() || site_rate->getTree()->aln->frac_const_sites == 0.0 || model->isMixture())
        return optimizeParameters(fixed_len, write_info, logl_epsilon, gradient_epsilon);

    double begin_time = getRealTime();

    PhyloTree *tree = site_rate->getTree();
    double frac_const = tree->aln->frac_const_sites;
    tree->setCurScore(tree->computeLikelihood());

    /* Back up branch lengths and substitutional rates */
    DoubleVector initBranLens;
    DoubleVector bestLens;
    tree->saveBranchLengths(initBranLens);
    bestLens = initBranLens;
//    int numRateEntries = tree->getModel()->getNumRateEntries();
    // model_ckp: copy of the checkpoint object currently attached to the model,
    //            every restart reloads the model from it
    // best_ckp:  receives the model state of the best restart so far
    // saved_ckp: the checkpoint attached on entry, re-attached after each swap
    Checkpoint *model_ckp = new Checkpoint;
    Checkpoint *best_ckp = new Checkpoint;
    Checkpoint *saved_ckp = model->getCheckpoint();
    *model_ckp = *saved_ckp;
//    double *rates = new double[numRateEntries];
//    double *bestRates = new double[numRateEntries];
//    tree->getModel()->getRateMatrix(rates);
//    int numStates = tree->aln->num_states;
//    double *state_freqs = new double[numStates];
//    tree->getModel()->getStateFrequency(state_freqs);

    /* Best estimates found */
//    double *bestStateFreqs =  new double[numStates];
    double bestLogl = -DBL_MAX;
    double bestAlpha = 0.0;
    double bestPInvar = 0.0;

    // step between successive starting p_inv values
    // NOTE(review): testInterval is <= 0 when frac_const <= 2*MIN_PINVAR; the
    // thorough loop below then cannot advance initPInv - confirm this case
    // cannot occur in practice.
    double testInterval = (frac_const - MIN_PINVAR * 2) / 9;
    double initPInv = MIN_PINVAR;
    double initAlpha = site_rate->getGammaShape();

    if (Params::getInstance().opt_gammai_fast) {
        // Fast mode: start from frac_const/2 and keep moving the start value in
        // the direction of the last estimate until the likelihood stops improving.
        initPInv = frac_const/2;
        bool stop = false;
        while(!stop) {
            if (write_info) {
                cout << endl;
                cout << "Testing with init. pinv = " << initPInv << " / init. alpha = "  << initAlpha << endl;
            }

            vector<double> estResults = optimizeGammaInvWithInitValue(fixed_len, logl_epsilon, gradient_epsilon,
                                                                      initPInv, initAlpha, initBranLens, model_ckp);


            if (write_info) {
                cout << "Est. p_inv: " << estResults[0] << " / Est. gamma shape: " << estResults[1]
                << " / Logl: " << estResults[2] << endl;
            }

            if (estResults[2] > bestLogl) {
                bestLogl = estResults[2];
                bestAlpha = estResults[1];
                bestPInvar = estResults[0];
                bestLens.clear();
                tree->saveBranchLengths(bestLens);
                // write the current model state into best_ckp, then re-attach
                // the original checkpoint
                model->setCheckpoint(best_ckp);
                model->saveCheckpoint();
                model->setCheckpoint(saved_ckp);
//                *best_ckp = *saved_ckp;

//                tree->getModel()->getRateMatrix(bestRates);
//                tree->getModel()->getStateFrequency(bestStateFreqs);
                // next start: one interval further in the direction the estimate
                // moved, clamped to [0, frac_const]
                if (estResults[0] < initPInv) {
                    initPInv = estResults[0] - testInterval;
                    if (initPInv < 0.0)
                        initPInv = 0.0;
                } else {
                    initPInv = estResults[0] + testInterval;
                    if (initPInv > frac_const)
                        initPInv = frac_const;
                }
                //cout << "New initPInv = " << initPInv << endl;
            }  else {
                stop = true;
            }
        }
    } else {
        // Now perform testing different initial p_inv values
        if (write_info)
            cout << "Thoroughly optimizing +I+G parameters from 10 start values..." << endl;
        while (initPInv <= frac_const) {
            vector<double> estResults; // vector of p_inv, alpha and logl
            // start either from the best branch lengths found so far or from
            // the branch lengths saved on entry
            if (Params::getInstance().opt_gammai_keep_bran)
                estResults = optimizeGammaInvWithInitValue(fixed_len, logl_epsilon, gradient_epsilon,
                                                           initPInv, initAlpha, bestLens, model_ckp);
            else
                estResults = optimizeGammaInvWithInitValue(fixed_len, logl_epsilon, gradient_epsilon,
                                                           initPInv, initAlpha, initBranLens, model_ckp);
            if (write_info) {
                cout << "Init pinv, alpha: " << initPInv << ", "  << initAlpha
                << " / Estimate: " << estResults[0] << ", " << estResults[1]
                << " / LogL: " << estResults[2] << endl;
            }

            initPInv = initPInv + testInterval;

            if (estResults[2] > bestLogl) {
                bestLogl = estResults[2];
                bestAlpha = estResults[1];
                bestPInvar = estResults[0];
                bestLens.clear();
                tree->saveBranchLengths(bestLens);
                // write the current model state into best_ckp, then re-attach
                // the original checkpoint
                model->setCheckpoint(best_ckp);
                model->saveCheckpoint();
                model->setCheckpoint(saved_ckp);
//                *best_ckp = *saved_ckp;

//                tree->getModel()->getRateMatrix(bestRates);
//                tree->getModel()->getStateFrequency(bestStateFreqs);
            }
        }
    }

    // re-install the best rate parameters, model state and branch lengths
    site_rate->setGammaShape(bestAlpha);
    site_rate->setPInvar(bestPInvar);

    // -- Mon Apr 17 21:12:14 BST 2017
    // DONE Minh, merged correctly
    model->setCheckpoint(best_ckp);
    model->restoreCheckpoint();
    model->setCheckpoint(saved_ckp);
//    ((ModelGTR*) tree->getModel())->setRateMatrix(bestRates);
//    ((ModelGTR*) tree->getModel())->setStateFrequency(bestStateFreqs);
    // --

    tree->restoreBranchLengths(bestLens);
//    tree->getModel()->decomposeRateMatrix();

    tree->clearAllPartialLH();
    tree->setCurScore(tree->computeLikelihood());
    if (write_info) {
        cout << "Optimal pinv,alpha: " << bestPInvar << ", " << bestAlpha << " / ";
        cout << "LogL: " << tree->getCurScore() << endl << endl;
    }
    // the recomputed likelihood must agree with the best restart to within 1.0
    ASSERT(fabs(tree->getCurScore() - bestLogl) < 1.0);

//    delete [] rates;
//    delete [] state_freqs;
//    delete [] bestRates;
//    delete [] bestStateFreqs;

    delete model_ckp;
    delete best_ckp;

    double elapsed_secs = getRealTime() - begin_time;
    if (write_info)
        cout << "Parameters optimization took " << elapsed_secs << " sec" << endl;

    // updating global variable is not safe!
//    Params::getInstance().testAlpha = false;

    // 2016-03-14: this was missing!
    return tree->getCurScore();
}
1200
optimizeGammaInvWithInitValue(int fixed_len,double logl_epsilon,double gradient_epsilon,double initPInv,double initAlpha,DoubleVector & lenvec,Checkpoint * model_ckp)1201 vector<double> ModelFactory::optimizeGammaInvWithInitValue(int fixed_len, double logl_epsilon, double gradient_epsilon,
1202 double initPInv, double initAlpha,
1203 DoubleVector &lenvec, Checkpoint *model_ckp) {
1204 PhyloTree *tree = site_rate->getTree();
1205 tree->restoreBranchLengths(lenvec);
1206
1207 // -- Mon Apr 17 21:12:24 BST 2017
1208 // DONE Minh: merged correctly
1209 Checkpoint *saved_ckp = model->getCheckpoint();
1210 model->setCheckpoint(model_ckp);
1211 model->restoreCheckpoint();
1212 model->setCheckpoint(saved_ckp);
1213 site_rate->setPInvar(initPInv);
1214 site_rate->setGammaShape(initAlpha);
1215 // --
1216
1217 tree->clearAllPartialLH();
1218 optimizeParameters(fixed_len, false, logl_epsilon, gradient_epsilon);
1219
1220 vector<double> estResults;
1221 double estPInv = site_rate->getPInvar();
1222 double estAlpha = site_rate->getGammaShape();
1223 double logl = tree->getCurScore();
1224 estResults.push_back(estPInv);
1225 estResults.push_back(estAlpha);
1226 estResults.push_back(logl);
1227 return estResults;
1228 }
1229
1230
/**
 * Alternate between branch-length optimization and model/rate parameter
 * optimization until the log-likelihood stops improving by more than
 * logl_epsilon or tree->params->num_param_iterations is reached.
 *
 * Afterwards the site rates are rescaled (the tree is rescaled by the returned
 * mean rate if it differs from 1), the root position is optionally optimized,
 * and the final score is stored on the tree.
 *
 * @param fixed_len branch-length mode: BRLEN_OPTIMIZE optimizes all branches,
 *        BRLEN_SCALE optimizes a tree-length scaling, any other value leaves
 *        branch lengths untouched
 * @param write_info TRUE to print progress to cout
 * @param logl_epsilon minimal log-likelihood improvement to continue iterating
 * @param gradient_epsilon tolerance forwarded to the parameter optimizers
 * @return the final log-likelihood (also stored via tree->setCurScore)
 */
double ModelFactory::optimizeParameters(int fixed_len, bool write_info,
                                        double logl_epsilon, double gradient_epsilon) {
    ASSERT(model);
    ASSERT(site_rate);

//    double defaultEpsilon = logl_epsilon;

    double begin_time = getRealTime();
    double cur_lh;
    PhyloTree *tree = site_rate->getTree();
    ASSERT(tree);

    // disable (and empty) the transition-matrix cache while parameters change
    stopStoringTransMatrix();
    // modified by Thomas Wong on Sept 11, 15
    // no optimization of branch length in the first round
    cur_lh = tree->computeLikelihood();
    tree->setCurScore(cur_lh);
    if (verbose_mode >= VB_MED || write_info) {
        int p = -1;

        // SET precision to 17 (temporarily)
        if (verbose_mode >= VB_DEBUG) p = cout.precision(17);

        // PRINT Log-Likelihood
        cout << "1. Initial log-likelihood: " << cur_lh << endl;

        // RESTORE previous precision
        if (verbose_mode >= VB_DEBUG) cout.precision(p);

        if (verbose_mode >= VB_MAX) {
            tree->printTree(cout);
            cout << endl;
        }
    }

    // For UpperBounds -----------
    //cout<<"MLCheck = "<<tree->mlCheck <<endl;
    if(tree->mlCheck == 0){
        tree->mlInitial = cur_lh;
    }
    // ---------------------------


    // round counter; starts at 2 because the initial likelihood is reported as round 1
    int i;
    //bool optimize_rate = true;
//    double gradient_epsilon = min(logl_epsilon, 0.01); // epsilon for parameters starts at epsilon for logl
    for (i = 2; i < tree->params->num_param_iterations; i++) {
        double new_lh;

        // changed to optimise edge length first, and then Q,W,R inside the loop by Thomas on Sept 11, 15
        if (fixed_len == BRLEN_OPTIMIZE)
            new_lh = tree->optimizeAllBranches(min(i,3), logl_epsilon);  // loop only 3 times in total (previously in v0.9.6 5 times)
        else if (fixed_len == BRLEN_SCALE) {
            double scaling = 1.0;
            new_lh = tree->optimizeTreeLengthScaling(MIN_BRLEN_SCALE, scaling, MAX_BRLEN_SCALE, gradient_epsilon);
        } else
            new_lh = cur_lh;

        new_lh = optimizeParametersOnly(i, gradient_epsilon, new_lh);

        // a result of exactly 0.0 is treated as "no parameters were optimized":
        // finish with one more branch-length pass and stop
        if (new_lh == 0.0) {
            if (fixed_len == BRLEN_OPTIMIZE)
                cur_lh = tree->optimizeAllBranches(tree->params->num_param_iterations, logl_epsilon);
            else if (fixed_len == BRLEN_SCALE) {
                double scaling = 1.0;
                cur_lh = tree->optimizeTreeLengthScaling(MIN_BRLEN_SCALE, scaling, MAX_BRLEN_SCALE, gradient_epsilon);
            }
            break;
        }
        if (verbose_mode >= VB_MED) {
            model->writeInfo(cout);
            site_rate->writeInfo(cout);
            if (fixed_len == BRLEN_SCALE)
                cout << "Scaled tree length: " << tree->treeLength() << endl;
        }
        if (new_lh > cur_lh + logl_epsilon) {
            // sufficient improvement: accept and go on with the next round
            cur_lh = new_lh;
            if (write_info)
                cout << i << ". Current log-likelihood: " << cur_lh << endl;
        } else {
            // converged: final branch-length pass, then leave the loop
            site_rate->classifyRates(new_lh);
            if (fixed_len == BRLEN_OPTIMIZE)
                cur_lh = tree->optimizeAllBranches(100, logl_epsilon);
            else if (fixed_len == BRLEN_SCALE) {
                double scaling = 1.0;
                cur_lh = tree->optimizeTreeLengthScaling(MIN_BRLEN_SCALE, scaling, MAX_BRLEN_SCALE, gradient_epsilon);
            }
            break;
        }
    }

    // normalize rates s.t. branch lengths are #subst per site
//    if (Params::getInstance().optimize_alg_gammai != "EM")
    {
        double mean_rate = site_rate->rescaleRates();
        if (fabs(mean_rate-1.0) > 1e-6) {
            // rescaling changes branch lengths, which is incompatible with fixed lengths
            if (fixed_len == BRLEN_FIX)
                outError("Fixing branch lengths not supported under specified site rate model");
            tree->scaleLength(mean_rate);
            tree->clearAllPartialLH();
        }
    }

    if (Params::getInstance().root_find && tree->rooted && Params::getInstance().root_move_dist > 0) {
        cur_lh = tree->optimizeRootPosition(Params::getInstance().root_move_dist, write_info, logl_epsilon);
        if (verbose_mode >= VB_MED || write_info)
            cout << "Rooting log-likelihood: " << cur_lh << endl;
    }

    if (verbose_mode >= VB_MED || write_info)
        cout << "Optimal log-likelihood: " << cur_lh << endl;

    // For UpperBounds -----------
    if(tree->mlCheck == 0)
        tree->mlFirstOpt = cur_lh;
    // ---------------------------

    if (verbose_mode <= VB_MIN && write_info) {
        model->writeInfo(cout);
        site_rate->writeInfo(cout);
        if (fixed_len == BRLEN_SCALE)
            cout << "Scaled tree length: " << tree->treeLength() << endl;
    }
    double elapsed_secs = getRealTime() - begin_time;
    if (write_info)
        cout << "Parameters optimization took " << i-1 << " rounds (" << elapsed_secs << " sec)" << endl;
    // re-enable the transition-matrix cache now that parameters are final
    startStoringTransMatrix();

    // For UpperBounds -----------
    tree->mlCheck = 1;
    // ---------------------------

    tree->setCurScore(cur_lh);
    return cur_lh;
}
1366
1367 /**
1368 * @return TRUE if parameters are at the boundary that may cause numerical unstability
1369 */
isUnstableParameters()1370 bool ModelFactory::isUnstableParameters() {
1371 if (model->isUnstableParameters()) return true;
1372 return false;
1373 }
1374
startStoringTransMatrix()1375 void ModelFactory::startStoringTransMatrix() {
1376 if (!store_trans_matrix) return;
1377 is_storing = true;
1378 }
1379
stopStoringTransMatrix()1380 void ModelFactory::stopStoringTransMatrix() {
1381 if (!store_trans_matrix) return;
1382 is_storing = false;
1383 if (!empty()) {
1384 for (iterator it = begin(); it != end(); it++)
1385 delete it->second;
1386 clear();
1387 }
1388 }
1389
1390
/**
 * Forward to model->computeTrans() for a single pair of states.
 * @param time first argument passed to the model
 * @param state1 second argument passed to the model
 * @param state2 third argument passed to the model
 * @return the value returned by the model
 */
double ModelFactory::computeTrans(double time, int state1, int state2) {
    return model->computeTrans(time, state1, state2);
}
1394
/**
 * Forward to the model->computeTrans() overload that also takes two
 * derivative references.
 * @param time first argument passed to the model
 * @param state1 second argument passed to the model
 * @param state2 third argument passed to the model
 * @param derv1 passed by reference to the model (presumably receives the 1st derivative)
 * @param derv2 passed by reference to the model (presumably receives the 2nd derivative)
 * @return the value returned by the model
 */
double ModelFactory::computeTrans(double time, int state1, int state2, double &derv1, double &derv2) {
    return model->computeTrans(time, state1, state2, derv1, derv2);
}
1398
computeTransMatrix(double time,double * trans_matrix,int mixture)1399 void ModelFactory::computeTransMatrix(double time, double *trans_matrix, int mixture) {
1400 if (!store_trans_matrix || !is_storing || model->isSiteSpecificModel()) {
1401 model->computeTransMatrix(time, trans_matrix, mixture);
1402 return;
1403 }
1404 int mat_size = model->num_states * model->num_states;
1405 iterator ass_it = find(round(time * 1e6));
1406 if (ass_it == end()) {
1407 // allocate memory for 3 matricies
1408 double *trans_entry = new double[mat_size * 3];
1409 trans_entry[mat_size] = trans_entry[mat_size+1] = 0.0;
1410 model->computeTransMatrix(time, trans_entry, mixture);
1411 ass_it = insert(value_type(round(time * 1e6), trans_entry)).first;
1412 } else {
1413 //if (verbose_mode >= VB_MAX)
1414 //cout << "ModelFactory bingo" << endl;
1415 }
1416
1417 memcpy(trans_matrix, ass_it->second, mat_size * sizeof(double));
1418 }
1419
computeTransDerv(double time,double * trans_matrix,double * trans_derv1,double * trans_derv2,int mixture)1420 void ModelFactory::computeTransDerv(double time, double *trans_matrix,
1421 double *trans_derv1, double *trans_derv2, int mixture) {
1422 if (!store_trans_matrix || !is_storing || model->isSiteSpecificModel()) {
1423 model->computeTransDerv(time, trans_matrix, trans_derv1, trans_derv2, mixture);
1424 return;
1425 }
1426 int mat_size = model->num_states * model->num_states;
1427 iterator ass_it = find(round(time * 1e6));
1428 if (ass_it == end()) {
1429 // allocate memory for 3 matricies
1430 double *trans_entry = new double[mat_size * 3];
1431 trans_entry[mat_size] = trans_entry[mat_size+1] = 0.0;
1432 model->computeTransDerv(time, trans_entry, trans_entry+mat_size, trans_entry+(mat_size*2), mixture);
1433 ass_it = insert(value_type(round(time * 1e6), trans_entry)).first;
1434 } else if (ass_it->second[mat_size] == 0.0 && ass_it->second[mat_size+1] == 0.0) {
1435 double *trans_entry = ass_it->second;
1436 model->computeTransDerv(time, trans_entry, trans_entry+mat_size, trans_entry+(mat_size*2), mixture);
1437 }
1438 memcpy(trans_matrix, ass_it->second, mat_size * sizeof(double));
1439 memcpy(trans_derv1, ass_it->second + mat_size, mat_size * sizeof(double));
1440 memcpy(trans_derv2, ass_it->second + (mat_size*2), mat_size * sizeof(double));
1441 }
1442
~ModelFactory()1443 ModelFactory::~ModelFactory()
1444 {
1445 for (iterator it = begin(); it != end(); it++)
1446 delete it->second;
1447 clear();
1448 }
1449
/************* THE FOLLOWING FUNCTIONS SERVE THE JOINT OPTIMIZATION OF MODEL AND RATE PARAMETERS *******/
getNDim()1451 int ModelFactory::getNDim()
1452 {
1453 return model->getNDim() + site_rate->getNDim();
1454 }
1455
targetFunk(double x[])1456 double ModelFactory::targetFunk(double x[]) {
1457 model->getVariables(x);
1458 // need to compute rates again if p_inv or Gamma shape changes!
1459 if (model->state_freq[model->num_states-1] < MIN_RATE) return 1.0e+12;
1460 model->decomposeRateMatrix();
1461 site_rate->phylo_tree->clearAllPartialLH();
1462 return site_rate->targetFunk(x + model->getNDim());
1463 }
1464
/**
 * Pass the joint variable vector to the model and, offset by
 * model->getNDim(), to the site-rate model via their setVariables() methods.
 * @param variables joint variable vector shared by model and site-rate model
 */
void ModelFactory::setVariables(double *variables) {
    model->setVariables(variables);
    site_rate->setVariables(variables + model->getNDim());
}
1469
getVariables(double * variables)1470 bool ModelFactory::getVariables(double *variables) {
1471 bool changed = model->getVariables(variables);
1472 changed |= site_rate->getVariables(variables + model->getNDim());
1473 return changed;
1474 }
1475