#include "tree_utils.hh"

#include <algorithm>
#include <exception>
#include <functional>
#include <stack>
#include <stdexcept>
#include <vector>

#include "../tree/node_tree.hh"
5
6 using namespace cpprofiler::tree;
7
8 namespace cpprofiler
9 {
10 namespace utils
11 {
12
count_descendants(const NodeTree & nt,NodeID nid)13 int count_descendants(const NodeTree &nt, NodeID nid)
14 {
15
16 /// NoNode has 0 descendants (as opposed to a leaf node, which counts itself)
17 if (nid == NodeID::NoNode)
18 return 0;
19
20 int count = 0;
21
22 auto fun = [&count](NodeID n) {
23 count++;
24 };
25
26 apply_below(nt, nid, fun);
27
28 return count;
29 }
30
calculate_depth(const NodeTree & nt,NodeID nid)31 int calculate_depth(const NodeTree &nt, NodeID nid)
32 {
33 int depth = 0;
34 while (nid != NodeID::NoNode)
35 {
36 nid = nt.getParent(nid);
37 ++depth;
38 }
39 return depth;
40 }
41
nodes_below(const tree::NodeTree & nt,NodeID nid)42 std::vector<NodeID> nodes_below(const tree::NodeTree &nt, NodeID nid)
43 {
44
45 if (nid == NodeID::NoNode)
46 {
47 throw std::exception();
48 }
49
50 std::vector<NodeID> nodes;
51
52 std::stack<NodeID> stk;
53
54 stk.push(nid);
55
56 while (!stk.empty())
57 {
58 const auto n = stk.top();
59 stk.pop();
60 nodes.push_back(n);
61
62 const auto kids = nt.childrenCount(n);
63 for (auto alt = 0; alt < kids; ++alt)
64 {
65 stk.push(nt.getChild(n, alt));
66 }
67 }
68
69 return nodes;
70 }
71
apply_below(const NodeTree & nt,NodeID nid,const NodeAction & action)72 void apply_below(const NodeTree &nt, NodeID nid, const NodeAction &action)
73 {
74
75 if (nid == NodeID::NoNode)
76 {
77 throw std::exception();
78 }
79
80 auto nodes = nodes_below(nt, nid);
81
82 for (auto n : nodes)
83 {
84 action(n);
85 }
86 }
87
pre_order_apply(const tree::NodeTree & nt,NodeID start,const NodeAction & action)88 void pre_order_apply(const tree::NodeTree &nt, NodeID start, const NodeAction &action)
89 {
90 std::stack<NodeID> stk;
91
92 stk.push(start);
93
94 while (stk.size() > 0)
95 {
96 auto nid = stk.top();
97 stk.pop();
98
99 action(nid);
100
101 for (auto i = nt.childrenCount(nid) - 1; i >= 0; --i)
102 {
103 auto child = nt.getChild(nid, i);
104 stk.push(child);
105 }
106 }
107 }
108
is_right_most_child(const tree::NodeTree & nt,NodeID nid)109 bool is_right_most_child(const tree::NodeTree &nt, NodeID nid)
110 {
111 const auto pid = nt.getParent(nid);
112
113 /// root is treated as the left-most child
114 if (pid == NodeID::NoNode)
115 return false;
116
117 const auto kids = nt.childrenCount(pid);
118 const auto alt = nt.getAlternative(nid);
119 return (alt == kids - 1);
120 }
121
pre_order(const NodeTree & tree)122 std::vector<NodeID> pre_order(const NodeTree &tree)
123 {
124 std::stack<NodeID> stk;
125 std::vector<NodeID> result;
126
127 NodeID root = NodeID{0};
128
129 stk.push(root);
130
131 while (stk.size() > 0)
132 {
133 auto nid = stk.top();
134 stk.pop();
135 result.push_back(nid);
136
137 for (auto i = tree.childrenCount(nid) - 1; i >= 0; --i)
138 {
139 auto child = tree.getChild(nid, i);
140 stk.push(child);
141 }
142 }
143
144 return result;
145 }
146
any_order(const NodeTree & tree)147 std::vector<NodeID> any_order(const NodeTree &tree)
148 {
149
150 auto count = tree.nodeCount();
151 std::vector<NodeID> result;
152 result.reserve(count);
153
154 for (auto i = 0; i < count; ++i)
155 {
156 result.push_back(NodeID(i));
157 }
158
159 return result;
160 }
161
post_order(const NodeTree & tree)162 std::vector<NodeID> post_order(const NodeTree &tree)
163 {
164 std::stack<NodeID> stk_1;
165 std::vector<NodeID> result;
166
167 result.reserve(tree.nodeCount());
168
169 auto root = tree.getRoot();
170
171 stk_1.push(root);
172
173 while (!stk_1.empty())
174 {
175 auto nid = stk_1.top();
176 stk_1.pop();
177
178 result.push_back(nid);
179
180 for (auto i = 0; i < tree.childrenCount(nid); ++i)
181 {
182 auto child = tree.getChild(nid, i);
183 stk_1.push(child);
184 }
185 }
186
187 std::reverse(result.begin(), result.end());
188
189 return result;
190 }
191
calc_subtree_sizes(const tree::NodeTree & nt)192 std::vector<int> calc_subtree_sizes(const tree::NodeTree &nt)
193 {
194 const int nc = nt.nodeCount();
195
196 std::vector<int> sizes(nc);
197
198 std::function<void(NodeID)> countDescendants;
199
200 /// Count descendants plus one (the node itself)
201 countDescendants = [&](NodeID n) {
202 auto nkids = nt.childrenCount(n);
203 if (nkids == 0)
204 {
205 sizes[n] = 1;
206 }
207 else
208 {
209 int count = 1; // the node itself
210 for (auto alt = 0u; alt < nkids; ++alt)
211 {
212 const auto kid = nt.getChild(n, alt);
213 countDescendants(kid);
214 count += sizes[kid];
215 }
216 sizes[n] = count;
217 }
218 };
219
220 const auto root = nt.getRoot();
221 countDescendants(root);
222
223 return sizes;
224 }
225
226 } // namespace utils
227 } // namespace cpprofiler