1 /*
2  * scrm is an implementation of the Sequential-Coalescent-with-Recombination Model.
3  *
4  * Copyright (C) 2013, 2014 Paul R. Staab, Sha (Joe) Zhu, Dirk Metzler and Gerton Lunter
5  *
6  * This file is part of scrm.
7  *
8  * scrm is free software: you can redistribute it and/or modify
9  * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 
17 */
18 
19 #include "forest.h"
20 
21 /******************************************************************
22  * Debugging Utils
23  *****************************************************************/
24 
/**
 * Replaces the forest's contents with a small hard-coded example, intended
 * for debugging and unit tests.
 *
 * Local tree (node heights in brackets, leaves at height 0):
 *
 *                root [10]
 *              /          \
 *       node12 [1]      node34 [3]
 *        /     \          /     \
 *     leaf1   leaf2    leaf3   leaf4
 *
 * In addition, a separate tree that is not part of the local tree is
 * created: nl_node [4] with parent nl_root [6], both made non-local at
 * current_rec_.
 *
 * The model's sample configuration is replaced by four samples at time 0
 * (presumably all in the first population -- confirm addSampleSizes()
 * semantics), and approximation is disabled; all other model settings are
 * kept. checkTreeLength() and checkTree() are asserted at the end.
 */
void Forest::createExampleTree() {
  this->clear();
  this->writable_model()->disable_approximation();
  // Only set the number of samples to 4, but keep rest of the model
  this->writable_model()->sample_times_.clear();
  this->writable_model()->sample_populations_.clear();
  this->writable_model()->addSampleSizes(0.0, std::vector<size_t>(1, 4));

  // Add a recombination base at 5.0 and make index 1 the current
  // recombination; the non-local nodes below are stamped with current_rec_.
  // NOTE(review): assumes clear() leaves exactly one entry in rec_bases_.
  this->rec_bases_.push_back(5.0);
  this->current_rec_ = 1;

  // Four leaves at height 0. The second createNode() argument presumably is
  // the label; it is set explicitly right below anyway.
  Node* leaf1 = nodes()->createNode(0, 1);
  Node* leaf2 = nodes()->createNode(0, 2);
  Node* leaf3 = nodes()->createNode(0, 3);
  Node* leaf4 = nodes()->createNode(0, 4);

  leaf1->set_label(1);
  leaf2->set_label(2);
  leaf3->set_label(3);
  leaf4->set_label(4);

  this->nodes()->add(leaf4);
  this->nodes()->add(leaf3);
  this->nodes()->add(leaf2);
  this->nodes()->add(leaf1);

  // Internal nodes: addNodeToTree() inserts each node into the container
  // and links it to its two children.
  Node* node12 = nodes()->createNode(1);
  this->addNodeToTree(node12, NULL, leaf1, leaf2);

  Node* node34 = nodes()->createNode(3);
  this->addNodeToTree(node34, NULL, leaf3, leaf4);

  Node* root = nodes()->createNode(10);
  this->addNodeToTree(root, NULL, node12, node34);
  this->set_local_root(root);
  this->set_primary_root(root);

  // Add a non-local tree
  Node* nl_node = nodes()->createNode(4);
  nl_node->make_nonlocal(current_rec_);
  Node* nl_root = nodes()->createNode(6);
  nl_root->make_nonlocal(current_rec_);

  // nl_node is the only (first) child of nl_root; linked by hand because
  // addNodeToTree() is not used for this tree.
  nl_node->set_parent(nl_root);
  nl_root->set_first_child(nl_node);
  this->nodes()->add(nl_node);
  this->nodes()->add(nl_root);
  updateAbove(nl_node);

  // Presumably recomputes the cached per-node values (samples_below,
  // length_below) along the path from each leaf to the root.
  updateAbove(leaf1);
  updateAbove(leaf2);
  updateAbove(leaf3);
  updateAbove(leaf4);

  this->set_sample_size(4);

  // Start with a fresh contemporaries container sized for the model.
  this->contemporaries_ = ContemporariesContainer(model().population_number(),
                                                  model().sample_size(),
                                                  random_generator());
  // Presumably marks the forest as idle: no pending event time and no
  // coalescence in progress.
  this->tmp_event_time_ = -1;
  this->coalescence_finished_ = true;

  assert( this->checkTreeLength() );
  assert( this->checkTree() );
}
90 
createScaledExampleTree()91 void Forest::createScaledExampleTree() {
92   this->createExampleTree();
93 
94   this->nodes()->at(4)->set_height(1 * 4 * model().default_pop_size());
95   this->nodes()->at(5)->set_height(3 * 4 * model().default_pop_size());
96   this->nodes()->at(6)->set_height(4 * 4 * model().default_pop_size());
97   this->nodes()->at(7)->set_height(6 * 4 * model().default_pop_size());
98   this->nodes()->at(8)->set_height(10 * 4 * model().default_pop_size());
99 
100   updateAbove(nodes()->at(4));
101   updateAbove(nodes()->at(5));
102   updateAbove(nodes()->at(6));
103 
104   assert( this->checkTreeLength() );
105   assert( this->checkTree() );
106 }
107 
calcTreeLength() const108 double Forest::calcTreeLength() const {
109   double local_length = 0;
110 
111   for (ConstNodeIterator it = getNodes()->iterator(); it.good(); ++it) {
112     if ( *it == local_root() ) return local_length;
113     if ( (*it)->is_root() || !(*it)->local() ) continue;
114     local_length += (*it)->height_above();
115   }
116 
117   return local_length;
118 }
119 
120 
addNodeToTree(Node * node,Node * parent,Node * first_child,Node * second_child)121 void Forest::addNodeToTree(Node *node, Node *parent, Node *first_child, Node *second_child) {
122   this->nodes()->add(node);
123 
124   if (parent != NULL) {
125     node->set_parent(parent);
126     if (parent->first_child() == NULL) parent->set_first_child(node);
127     else {
128       if (parent->first_child()->height() > node->height()) {
129         parent->set_second_child(parent->first_child());
130         parent->set_first_child(node);
131       } else {
132         parent->set_second_child(node);
133       }
134     }
135   }
136 
137   if (first_child != NULL) {
138     node->set_first_child(first_child);
139     first_child->set_parent(node);
140   }
141 
142   if (second_child != NULL) {
143     node->set_second_child(second_child);
144     second_child->set_parent(node);
145   }
146 }
147 
148 
checkTreeLength() const149 bool Forest::checkTreeLength() const {
150   double local_length = calcTreeLength();
151 
152   if ( !areSame(local_length, getLocalTreeLength(), 0.000001) ) {
153     dout << "Error: local tree length is " << this->getLocalTreeLength() << " ";
154     dout << "but should be " << local_length << std::endl;
155     return(0);
156   }
157 
158   return(1);
159 }
160 
161 
checkInvariants(Node const * node) const162 bool Forest::checkInvariants(Node const* node) const {
163   if (node == NULL) {
164     bool okay = 1;
165 
166     for (ConstNodeIterator it = getNodes()->iterator(); it.good(); ++it) {
167       if ( (*it)->height() >= local_root()->height()) {
168         if (!(*it)->local()) continue;
169         dout << "Node " << *it << " is above the local root and local!" << std::endl;
170         okay = 0;
171       } else {
172         okay *= checkInvariants(*it);
173       }
174     }
175     return(okay);
176   }
177 
178   size_t samples_below = node->in_sample();
179   double length_below = 0;
180 
181   if (node->first_child() != NULL) {
182     samples_below += node->first_child()->samples_below();
183     length_below += node->first_child()->length_below();
184     if (node->first_child()->local())
185       length_below += node->first_child()->height_above();
186   }
187 
188   if (node->second_child() != NULL) {
189     samples_below += node->second_child()->samples_below();
190     length_below += node->second_child()->length_below();
191     if (node->second_child()->local())
192       length_below += node->second_child()->height_above();
193   }
194 
195   if ( samples_below != node->samples_below() ||
196       !areSame(length_below, node->length_below(), 0.00001) ) {
197     dout << "Node " << node << " not up to date" << std::endl;
198     dout << "samples_below: is " << node->samples_below()
199          << " and should be " << samples_below << std::endl;
200     dout << "length_below: is " << node->length_below()
201          << " and should be " << length_below
202          << " ( Diff " << node->length_below() - length_below << " )" << std::endl;
203 
204     printNodes();
205     printTree();
206     return false;
207   }
208 
209   if ( (samples_below == 0 || samples_below == sample_size()) && node->local() ) {
210     dout << "Node " << node << " is local but should be non-local" << std::endl;
211     return false;
212   }
213 
214   return true;
215 }
216 
217 
checkLeafsOnLocalTree(Node const * node) const218 bool Forest::checkLeafsOnLocalTree(Node const* node) const {
219   if (node == NULL) {
220     size_t all_on_tree = 1;
221     bool on_tree = 0;
222     for (ConstNodeIterator it = getNodes()->iterator(); it.good(); ++it) {
223       if ( !(*it)->in_sample() ) continue;
224       on_tree = checkLeafsOnLocalTree(*it);
225       if (!on_tree) dout << "Leaf " << *it << " is not on local tree!" << std::endl;
226       all_on_tree *= on_tree;
227     }
228     return(all_on_tree);
229   }
230   if ( node->local() ) return( checkLeafsOnLocalTree(node->parent()) );
231   return( node == this->local_root() );
232 }
233 
234 
checkNodeProperties() const235 bool Forest::checkNodeProperties() const {
236   bool success = true;
237   for (ConstNodeIterator it = getNodes()->iterator(); it.good(); ++it) {
238     if ( !(*it)->local() ) {
239       if ( (*it)->last_update() == 0 && !(*it)->is_root() ) {
240         dout << "Error: Node " << *it << " non-local without update info" << std::endl;
241         success = false;
242       }
243     }
244   }
245   return success;
246 }
247 
248 
checkTree(Node const * root) const249 bool Forest::checkTree(Node const* root) const {
250   if (root == NULL) {
251     bool good = true;
252     // Default when called without argument
253     for (ConstNodeIterator it = getNodes()->iterator(); it.good(); ++it) {
254       if ( (*it)->is_root() ) good *= checkTree(*it);
255     }
256 
257     good *= this->checkInvariants();
258     good *= this->checkNodeProperties();
259     good *= this->checkTreeLength();
260     good *= this->checkRoots();
261     return good;
262   }
263   assert( root != NULL );
264 
265   Node* h_child = root->second_child();
266   Node* l_child = root->first_child();
267 
268   bool child1 = 1;
269   if (h_child != NULL) {
270     if (l_child == NULL) {
271       dout << root << ": only child is second child" << std::endl;
272       return 0;
273     }
274     if (h_child->parent() != root) {
275       dout << h_child << ": is child of non-parent" << std::endl;
276       return 0;
277     }
278     if (h_child->height() > root->height()) {
279       dout << root << ": has child with greater height" << std::endl;
280       return 0;
281     }
282     if (h_child->population() != root->population()) {
283       dout << root << ": has child of other population" << std::endl;
284       return 0;
285     }
286     if (l_child->population() != root->population()) {
287       dout << root << ": has child of other population" << std::endl;
288       return 0;
289     }
290     child1 = checkTree(h_child);
291   }
292 
293   bool child2 = 1;
294   if (l_child != NULL) {
295     if (l_child->parent() != root) {
296       dout << l_child << ": is child of non-parent" << std::endl;
297       return 0;
298     }
299     child2 = checkTree(l_child);
300 
301     if (l_child->height() > root->height()) {
302       dout << root << ": has child with greater height" << std::endl;
303       return 0;
304     }
305   }
306 
307   // Check that parent if above node
308   if (!root->is_root()) {
309     Node* parent = root->parent();
310     Node const* current = root;
311     while (current != parent) {
312       if (current->is_last()) {
313         dout << root << ": node is above it's parent.";
314         return 0;
315       }
316       current = current->next();
317     }
318   }
319 
320   return child1 && child2;
321 }
322 
323 
324 
325 
326 /******************************************************************
327  * Tree Printing
328  *****************************************************************/
printTree() const329 bool Forest::printTree() const {
330   //this->printNodes();
331   std::vector<Node const*> positions = this->determinePositions();
332   //this->printPositions(positions);
333   std::vector<Node const*>::iterator position;
334   int h_line;
335   double start_height = 0,
336          end_height = getNodes()->get(0)->height();
337 
338   for (ConstNodeIterator ni = getNodes()->iterator(); ni.good(); ) {
339     if ( !(*ni)->is_root() && (*ni)->height_above() == 0.0 ) {
340       std::cout << "A rare situation occurred were a parent and a child have exactly "
341            << "the same height. We can't print such trees here, the algorithm however"
342            << "should not be affected." << std::endl;
343       return 1;
344     }
345     h_line = 0;
346     start_height = end_height;
347     while ( ni.height() <= end_height ) ++ni;
348     end_height = ni.height();
349     //std::cout << start_height << " - " << end_height << std::endl;
350 
351     for (position = positions.begin(); position != positions.end(); ++position) {
352       assert( *position != NULL );
353       if ( (*position)->height() == start_height ) {
354         if ( (*position)->local() || *position == local_root() ) std::cout << "╦";
355         else std::cout << "┬";
356         if ( (*position)->countChildren() == 2 ) {
357           h_line = 1 + !((*position)->local());
358           if ( *position == local_root() ) h_line = 1;
359         }
360         if ( (*position)->countChildren() == 1 ) {
361           h_line = 0;
362         }
363       }
364       else if ( (*position)->height() < start_height &&
365                 (*position)->parent_height() >= end_height ) {
366         if ( (*position)->local() ) std::cout << "║";
367         else std::cout << "│";
368 
369       }
370       else if ( (*position)->parent_height() == start_height ) {
371         if ( *position == (*position)->parent()->first_child() ) {
372           if ( (*position)->local() ) {
373             std::cout << "╚";
374             h_line = 1;
375           }
376           else {
377             std::cout << "└";
378             h_line = 2;
379           }
380         }
381         else {
382           if ( (*position)->local() ) std::cout << "╝";
383           else std::cout << "┘";
384           h_line = 0;
385         }
386       }
387       else {
388         if ( h_line == 0 ) std::cout << " ";
389         else if ( h_line == 1 ) std::cout << "═";
390         else std::cout << "─";
391       }
392     }
393     std::cout << " - " << std::setw(7) << std::setprecision(7) << std::right << start_height << " - ";
394     for (position = positions.begin(); position != positions.end(); ++position) {
395       if (*position == NULL) continue;
396       if ( (*position)->height() == start_height ) {
397         if ((*position)->label() != 0) std::cout << (*position)->label() << ":";
398         if (!(*position)->is_migrating()) std::cout << *position << "(" << (*position)->population() << ") ";
399         else std::cout << *position << "(" << (*position)->first_child()->population()
400                   << "->" << (*position)->population() << ") ";
401         if (nodeIsOld(*position)) std::cout << "old ";
402       }
403     }
404     std::cout << std::endl;
405   }
406   return true;
407 }
408 
/**
 *  For printing the tree, each node gets assigned its own column in the printed area,
 *  referred to as its position. This function determines the position of every
 *  node and returns the nodes in a vector sorted by position.
 *
 *  \return Vector of all nodes, sorted by position
 */
std::vector<Node const*> Forest::determinePositions() const {
  // One column per node; NULL marks a column that has not been assigned yet.
  std::vector<Node const*> positions(this->getNodes()->size(), NULL);

  ReverseConstNodeIterator it;
  std::vector<const Node*>::iterator cit;
  // root_offset: first free column to the right of all trees placed so far.
  size_t lines_left, lines_right, position, root_offset = 0;
  Node const* current_node;

  // Reverse iteration presumably visits nodes from the top down (assuming
  // the container is ordered by height). This way every non-root node has
  // already been given its column by its parent when it is reached.
  for (it = getNodes()->reverse_iterator(); it.good(); ++it) {
    current_node = *it;

    // Column distance from this node to its first / second child.
    lines_left = countLinesLeft(current_node);
    lines_right = countLinesRight(current_node);

    if ( current_node->is_root() ) {
      // Add root to the right of all current trees
      position = countBelowLinesLeft(current_node->first_child()) + lines_left + root_offset;
      //std::cout << current_node << " " << position << " " << lines_left << " "
      //          << lines_right << " "
      //          << countBelowLinesLeft(current_node->first_child()) << std::endl;

      // Reserve all columns used by this tree before placing the next root.
      root_offset = position +
                    countBelowLinesRight(current_node->second_child()) +
                    lines_right + 1;

      assert( positions[position] == NULL );
      positions[position] = current_node;
    } else {
      // Get the position of the node (which was assigned when looking at its
      // parent
      // Linear search: yields positions.size() if the node was not found.
      position = 0;
      for (cit = positions.begin(); cit < positions.end(); ++cit) {
        if ( *cit == current_node ) break;
        ++position;
      }
    }

    // Insert the child/children into branches
    // First child goes lines_left columns to the left, second child
    // lines_right columns to the right. The .at() calls in the asserts
    // bounds-check the computed column (size_t arithmetic).
    if (current_node->first_child() != NULL) {
        assert( positions.at(position - lines_left) == NULL );
        positions[position - lines_left] =  current_node->first_child();
    }


    if (current_node->second_child() != NULL) {
        assert( positions.at(position + lines_right) == NULL );
        positions[position + lines_right] = current_node->second_child();
    }

  }
  return positions;
}
468 
printPositions(const std::vector<Node const * > & positions) const469   void Forest::printPositions(const std::vector<Node const*> &positions) const {
470       for (size_t col = 0; col < positions.size() ; ++col) {
471         std::cout << positions[col] << " ";
472       }
473       std::cout << std::endl;
474   }
475 
countLinesLeft(Node const * node) const476   int Forest::countLinesLeft(Node const* node) const {
477     if ( node->first_child() == NULL ) return 0;
478     //if ( node->second_child() == NULL ) return 1;
479     return ( 1 + countBelowLinesRight(node->first_child()) );
480   }
481 
countLinesRight(Node const * node) const482   int Forest::countLinesRight(Node const* node) const {
483     if ( node->first_child() == NULL ) return 0;
484     if ( node->second_child() == NULL ) return 0;
485     return ( 1 + countBelowLinesLeft(node->second_child()) );
486   }
487 
countBelowLinesLeft(Node const * node) const488   int Forest::countBelowLinesLeft(Node const* node) const {
489     if ( node == NULL ) return 0;
490     if ( node->first_child() == NULL ) return 0;
491     else return ( countLinesLeft(node) + countBelowLinesLeft(node->first_child()) );
492   }
493 
countBelowLinesRight(Node const * node) const494   int Forest::countBelowLinesRight(Node const* node) const {
495     if ( node == NULL ) return 0;
496     if ( node->second_child() == NULL ) return 0;
497     else return ( countLinesRight(node) + countBelowLinesRight(node->second_child()) );
498   }
499 
printNodes() const500   bool Forest::printNodes() const {
501     std::cout << std::setw(15) << std::right << "Node";
502     std::cout << std::setw(15) << std::right << "Height";
503     std::cout << std::setw(6) << std::right << "label";
504     std::cout << std::setw(15) << std::right << "Parent";
505     std::cout << std::setw(15) << std::right << "1th_child";
506     std::cout << std::setw(15) << std::right << "2nd_child";
507     std::cout << std::setw(6) << std::right << "local";
508     std::cout << std::setw(6) << std::right << "pop";
509     std::cout << std::setw(10) << std::right << "l_upd";
510     std::cout << std::setw(6) << std::right << "s_bel";
511     std::cout << std::setw(10) << std::right << "l_bel";
512     std::cout << std::endl;
513 
514     for(size_t i = 0; i < this->getNodes()->size(); ++i) {
515       std::cout << std::setw(15) << std::right << this->getNodes()->get(i);
516       std::cout << std::setw(15) << std::right << this->getNodes()->get(i)->height();
517       std::cout << std::setw(6) << std::right << this->getNodes()->get(i)->label();
518       if (!getNodes()->get(i)->is_root())
519         std::cout << std::setw(15) << std::right << this->getNodes()->get(i)->parent();
520       else std::cout << std::setw(15) << std::right << 0;
521       std::cout << std::setw(15) << std::right << this->getNodes()->get(i)->first_child();
522       std::cout << std::setw(15) << std::right << this->getNodes()->get(i)->second_child();
523       std::cout << std::setw(6) << std::right << this->getNodes()->get(i)->local();
524       std::cout << std::setw(6) << std::right << this->getNodes()->get(i)->population();
525       std::cout << std::setw(10) << std::right << this->getNodes()->get(i)->last_update();
526       std::cout << std::setw(6) << std::right << this->getNodes()->get(i)->samples_below();
527       std::cout << std::setw(10) << std::right << this->getNodes()->get(i)->length_below();
528       std::cout << std::endl;
529     }
530     std::cout << "Local Root:    " << this->local_root() << std::endl;
531     std::cout << "Primary Root:  " << this->primary_root() << std::endl;
532     return true;
533   }
534 
535 
checkForNodeAtHeight(const double height) const536 bool Forest::checkForNodeAtHeight(const double height) const {
537   for (auto it = getNodes()->iterator(); it.good(); ++it) {
538     if ((*it)->height() == height) return true;
539     if ((*it)->height() > height) return false;
540   }
541   return false;
542 }
543 
544 // Checks if all nodes in contemporaries are contemporaries.
/**
 * Verifies contemporaries_ against the node container at `time`, in both
 * directions:
 *  1. every entry of contemporaries_ is non-NULL, not a root, spans `time`
 *     (height <= time < parent_height), is not old (nodeIsOld), and is not
 *     an active node whose state is 1 ("coalescing");
 *  2. every node that spans `time`, except such coalescing active nodes, is
 *     stored in contemporaries_ under its own population.
 *
 * \return true (1) if consistent; otherwise prints a diagnostic for the
 *         first problem found and returns false (0).
 */
bool Forest::checkContemporaries(const double time) const {
  // Check if all nodes in contemporaries() are contemporaries
  for (size_t pop = 0; pop < model().population_number(); ++pop) {
    for (auto it = contemporaries_.begin(pop); it != contemporaries_.end(pop); ++it) {
      if ( *it == NULL ) {
        dout << "NULL in contemporaries" << std::endl;
        return 0;
      }

      if ( (*it)->is_root() ) {
        dout << "Root " << *it << " in contemporaries" << std::endl;
        return 0;
      }

      // A contemporary must start at or below `time` and end strictly above it.
      if ( (*it)->height() > time || (*it)->parent_height() <= time ) {
        dout << "Non-contemporary node " << *it << " in contemporaries "
             << "at time " << time << " (node at " << (*it)->height()
             << "; parent at " << (*it)->parent_height() << ")." << std::endl;
        printNodes();
        return 0;
      }

      if ( nodeIsOld(*it) ) {
        // NOTE(review): roots were already rejected above, so !is_root()
        // always holds here and this branch also always reports an error.
        if ( *it == local_root() ) {
          if ( !(*it)->is_root() ) {
            dout << "Branch above local root should be pruned but is not" << std::endl;
            return 0;
          }
        } else {
          dout << "Contemporary node " << *it << " should be pruned by now!" << std::endl;
          return 0;
        }
      }

      // Active nodes in state 1 are currently coalescing and must not be
      // listed as contemporaries.
      for (size_t i = 0; i < 2; ++i) {
        if ( *it == active_node(i) && states_[i] == 1 ) {
          dout << "Coalescing node a" << i << " in contemporaries!" << std::endl;
          return 0;
        }
      }
    }
  }

  // Check if all contemporaries are in contemporaries()
  for (auto ni = getNodes()->iterator(); ni.good(); ++ni) {
    if ( (*ni)->height() <= time && time < (*ni)->parent_height()) {
      // Coalescing active nodes are deliberately absent (see above).
      if ( *ni == active_node(0) && states_[0] == 1 ) continue;
      if ( *ni == active_node(1) && states_[1] == 1 ) continue;

      // Linear search in the container of the node's own population.
      bool found = false;
      size_t pop = (*ni)->population();
      for (auto it = contemporaries_.begin(pop); it != contemporaries_.end(pop); ++it) {
        if ( *it == *ni ) {
          found = true;
          break;
        }
      }
      if (!found) {
        dout << "Node " << *ni << " (height " << (*ni)->height()
             << ") not in contemporaries at time " << time << std::endl;
        return 0;
      }
    }
  }

  return 1;
}
612 
checkRoots() const613 bool Forest::checkRoots() const {
614   // Check that local_root() really is the local root:
615   if (local_root()->samples_below() != sample_size() ||
616       local_root()->first_child() == NULL ||
617       local_root()->second_child() == NULL ||
618       (!local_root()->first_child()->local()) ||
619       (!local_root()->second_child()->local()) ) {
620     dout << local_root() << " is registered as local root, but is not." << std::endl;
621     return false;
622   }
623 
624   // Check that primary_root() really is the primary root:
625   Node* node = local_root();
626   while (!node->is_root()) node = node->parent();
627   if (node != primary_root()) {
628     dout << primary_root() << " is registered as primary root, but "
629          << node << " is." << std::endl;
630     return false;
631   }
632 
633   return true;
634 }
635