/*
Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved. Released under a BSD (revised)
license as described in the file LICENSE.
*/
7 #include <float.h>
8 #include <math.h>
9 #include <stdio.h>
10 #include <sstream>
11
12 #include "reductions.h"
13
14 using namespace std;
15 using namespace LEARNER;
16
// Per-class statistics kept at each tree node.  Ordering and equality are
// defined by `label` only, so a node's preds array can be kept sorted and
// searched by label.
class node_pred
{
public:

  double Ehk;           // running sum of margins observed for this label
  float norm_Ehk;       // Ehk / nk, the per-label average margin
  uint32_t nk;          // number of margin updates for this label
  uint32_t label;       // the class label this entry tracks
  uint32_t label_count; // number of examples with this label seen at the node

  // Comparisons take const references and are const-qualified (the
  // originals copied the whole struct per comparison and could not be
  // called on const objects); behavior for callers is unchanged.
  bool operator==(const node_pred& v) const
  {
    return label == v.label;
  }

  bool operator>(const node_pred& v) const
  {
    return label > v.label;
  }

  bool operator<(const node_pred& v) const
  {
    return label < v.label;
  }

  node_pred(uint32_t l)
  {
    label = l;
    Ehk = 0.; // double member; was initialized with a float literal
    norm_Ehk = 0;
    nk = 0;
    label_count = 0;
  }
};
50
51 typedef struct
52 {//everyone has
53 uint32_t parent;//the parent node
54 v_array<node_pred> preds;//per-class state
55 uint32_t min_count;//the number of examples reaching this node (if it's a leaf) or the minimum reaching any grandchild.
56
57 bool internal;//internal or leaf
58
59 //internal nodes have
60 uint32_t base_predictor;//id of the base predictor
61 uint32_t left;//left child
62 uint32_t right;//right child
63 float norm_Eh;//the average margin at the node
64 double Eh;//total margin at the node
65 uint32_t n;//total events at the node
66
67 //leaf has
68 uint32_t max_count;//the number of samples of the most common label
69 uint32_t max_count_label;//the most common label
70 } node;
71
72 struct log_multi
73 {
74 uint32_t k;
75
76 v_array<node> nodes;
77
78 size_t max_predictors;
79 size_t predictors_used;
80
81 bool progress;
82 uint32_t swap_resist;
83
84 uint32_t nbofswaps;
85 };
86
init_leaf(node & n)87 inline void init_leaf(node& n)
88 {
89 n.internal = false;
90 n.preds.erase();
91 n.base_predictor = 0;
92 n.norm_Eh = 0;
93 n.Eh = 0;
94 n.n = 0;
95 n.max_count = 0;
96 n.max_count_label = 1;
97 n.left = 0;
98 n.right = 0;
99 }
100
init_node()101 inline node init_node()
102 {
103 node node;
104
105 node.parent = 0;
106 node.min_count = 0;
107 node.preds = v_init<node_pred>();
108 init_leaf(node);
109
110 return node;
111 }
112
init_tree(log_multi & d)113 void init_tree(log_multi& d)
114 {
115 d.nodes.push_back(init_node());
116 d.nbofswaps = 0;
117 }
118
min_left_right(log_multi & b,node & n)119 inline uint32_t min_left_right(log_multi& b, node& n)
120 {
121 return min(b.nodes[n.left].min_count, b.nodes[n.right].min_count);
122 }
123
find_switch_node(log_multi & b)124 inline uint32_t find_switch_node(log_multi& b)
125 {
126 uint32_t node = 0;
127 while(b.nodes[node].internal)
128 if(b.nodes[b.nodes[node].left].min_count
129 < b.nodes[b.nodes[node].right].min_count)
130 node = b.nodes[node].left;
131 else
132 node = b.nodes[node].right;
133 return node;
134 }
135
update_min_count(log_multi & b,uint32_t node)136 inline void update_min_count(log_multi& b, uint32_t node)
137 {//Constant time min count update.
138 while(node != 0)
139 {
140 uint32_t prev = node;
141 node = b.nodes[node].parent;
142
143 if (b.nodes[node].min_count == b.nodes[prev].min_count)
144 break;
145 else
146 b.nodes[node].min_count = min_left_right(b,b.nodes[node]);
147 }
148 }
149
display_tree_dfs(log_multi & b,node node,uint32_t depth)150 void display_tree_dfs(log_multi& b, node node, uint32_t depth)
151 {
152 for (uint32_t i = 0; i < depth; i++)
153 cout << "\t";
154 cout << node.min_count << " " << node.left
155 << " " << node.right;
156 cout << " label = " << node.max_count_label << " labels = ";
157 for (size_t i = 0; i < node.preds.size(); i++)
158 cout << node.preds[i].label << ":" << node.preds[i].label_count << "\t";
159 cout << endl;
160
161 if (node.internal)
162 {
163 cout << "Left";
164 display_tree_dfs(b, b.nodes[node.left], depth+1);
165
166 cout << "Right";
167 display_tree_dfs(b, b.nodes[node.right], depth+1);
168 }
169 }
170
children(log_multi & b,uint32_t & current,uint32_t & class_index,uint32_t label)171 bool children(log_multi& b, uint32_t& current, uint32_t& class_index, uint32_t label)
172 {
173 class_index = (uint32_t)b.nodes[current].preds.unique_add_sorted(node_pred(label));
174 b.nodes[current].preds[class_index].label_count++;
175
176 if(b.nodes[current].preds[class_index].label_count > b.nodes[current].max_count)
177 {
178 b.nodes[current].max_count = b.nodes[current].preds[class_index].label_count;
179 b.nodes[current].max_count_label = b.nodes[current].preds[class_index].label;
180 }
181
182 if (b.nodes[current].internal)
183 return true;
184 else if( b.nodes[current].preds.size() > 1
185 && (b.predictors_used < b.max_predictors
186 || b.nodes[current].min_count - b.nodes[current].max_count > b.swap_resist*(b.nodes[0].min_count + 1)))
187 { //need children and we can make them.
188 uint32_t left_child;
189 uint32_t right_child;
190 if (b.predictors_used < b.max_predictors)
191 {
192 left_child = (uint32_t)b.nodes.size();
193 b.nodes.push_back(init_node());
194 right_child = (uint32_t)b.nodes.size();
195 b.nodes.push_back(init_node());
196 b.nodes[current].base_predictor = (uint32_t)b.predictors_used++;
197 }
198 else
199 {
200 uint32_t swap_child = find_switch_node(b);
201 uint32_t swap_parent = b.nodes[swap_child].parent;
202 uint32_t swap_grandparent = b.nodes[swap_parent].parent;
203 if (b.nodes[swap_child].min_count != b.nodes[0].min_count)
204 cout << "glargh " << b.nodes[swap_child].min_count << " != " << b.nodes[0].min_count << endl;
205 b.nbofswaps++;
206
207 uint32_t nonswap_child;
208 if(swap_child == b.nodes[swap_parent].right)
209 nonswap_child = b.nodes[swap_parent].left;
210 else
211 nonswap_child = b.nodes[swap_parent].right;
212
213 if(swap_parent == b.nodes[swap_grandparent].left)
214 b.nodes[swap_grandparent].left = nonswap_child;
215 else
216 b.nodes[swap_grandparent].right = nonswap_child;
217 b.nodes[nonswap_child].parent = swap_grandparent;
218 update_min_count(b, nonswap_child);
219
220 init_leaf(b.nodes[swap_child]);
221 left_child = swap_child;
222 b.nodes[current].base_predictor = b.nodes[swap_parent].base_predictor;
223 init_leaf(b.nodes[swap_parent]);
224 right_child = swap_parent;
225 }
226 b.nodes[current].left = left_child;
227 b.nodes[left_child].parent = current;
228 b.nodes[current].right = right_child;
229 b.nodes[right_child].parent = current;
230
231 b.nodes[left_child].min_count = b.nodes[current].min_count/2;
232 b.nodes[right_child].min_count = b.nodes[current].min_count - b.nodes[left_child].min_count;
233 update_min_count(b, left_child);
234
235 b.nodes[left_child].max_count_label = b.nodes[current].max_count_label;
236 b.nodes[right_child].max_count_label = b.nodes[current].max_count_label;
237
238 b.nodes[current].internal = true;
239 }
240 return b.nodes[current].internal;
241 }
242
train_node(log_multi & b,base_learner & base,example & ec,uint32_t & current,uint32_t & class_index)243 void train_node(log_multi& b, base_learner& base, example& ec, uint32_t& current, uint32_t& class_index)
244 {
245 if(b.nodes[current].norm_Eh > b.nodes[current].preds[class_index].norm_Ehk)
246 ec.l.simple.label = -1.f;
247 else
248 ec.l.simple.label = 1.f;
249
250 base.learn(ec, b.nodes[current].base_predictor);
251
252 ec.l.simple.label = FLT_MAX;
253 base.predict(ec, b.nodes[current].base_predictor);
254
255 b.nodes[current].Eh += (double)ec.partial_prediction;
256 b.nodes[current].preds[class_index].Ehk += (double)ec.partial_prediction;
257 b.nodes[current].n++;
258 b.nodes[current].preds[class_index].nk++;
259
260 b.nodes[current].norm_Eh = (float)b.nodes[current].Eh / b.nodes[current].n;
261 b.nodes[current].preds[class_index].norm_Ehk = (float)b.nodes[current].preds[class_index].Ehk / b.nodes[current].preds[class_index].nk;
262 }
263
verify_min_dfs(log_multi & b,node node)264 void verify_min_dfs(log_multi& b, node node)
265 {
266 if (node.internal)
267 {
268 if (node.min_count != min_left_right(b, node))
269 {
270 cout << "badness! " << endl;
271 display_tree_dfs(b, b.nodes[0], 0);
272 }
273 verify_min_dfs(b, b.nodes[node.left]);
274 verify_min_dfs(b, b.nodes[node.right]);
275 }
276 }
277
sum_count_dfs(log_multi & b,node node)278 size_t sum_count_dfs(log_multi& b, node node)
279 {
280 if (node.internal)
281 return sum_count_dfs(b, b.nodes[node.left]) + sum_count_dfs(b, b.nodes[node.right]);
282 else
283 return node.min_count;
284 }
285
descend(node & n,float prediction)286 inline uint32_t descend(node& n, float prediction)
287 {
288 if (prediction < 0)
289 return n.left;
290 else
291 return n.right;
292 }
293
predict(log_multi & b,base_learner & base,example & ec)294 void predict(log_multi& b, base_learner& base, example& ec)
295 {
296 MULTICLASS::label_t mc = ec.l.multi;
297
298 ec.l.simple = {FLT_MAX, 0.f, 0.f};
299 uint32_t cn = 0;
300 while(b.nodes[cn].internal)
301 {
302 base.predict(ec, b.nodes[cn].base_predictor);
303 cn = descend(b.nodes[cn], ec.pred.scalar);
304 }
305 ec.pred.multiclass = b.nodes[cn].max_count_label;
306 ec.l.multi = mc;
307 }
308
learn(log_multi & b,base_learner & base,example & ec)309 void learn(log_multi& b, base_learner& base, example& ec)
310 {
311 // verify_min_dfs(b, b.nodes[0]);
312 if (ec.l.multi.label == (uint32_t)-1 || b.progress)
313 predict(b,base,ec);
314
315 if(ec.l.multi.label != (uint32_t)-1) //if training the tree
316 {
317 MULTICLASS::label_t mc = ec.l.multi;
318 uint32_t start_pred = ec.pred.multiclass;
319
320 uint32_t class_index = 0;
321 ec.l.simple = {FLT_MAX, mc.weight, 0.f};
322 uint32_t cn = 0;
323 while(children(b, cn, class_index, mc.label))
324 {
325 train_node(b, base, ec, cn, class_index);
326 cn = descend(b.nodes[cn], ec.pred.scalar);
327 }
328
329 b.nodes[cn].min_count++;
330 update_min_count(b, cn);
331 ec.pred.multiclass = start_pred;
332 ec.l.multi = mc;
333 }
334 }
335
save_node_stats(log_multi & d)336 void save_node_stats(log_multi& d)
337 {
338 FILE *fp;
339 uint32_t i, j;
340 uint32_t total;
341 log_multi* b = &d;
342
343 fp = fopen("atxm_debug.csv", "wt");
344
345 for(i = 0; i < b->nodes.size(); i++)
346 {
347 fprintf(fp, "Node: %4d, Internal: %1d, Eh: %7.4f, n: %6d, \n", (int) i, (int) b->nodes[i].internal, b->nodes[i].Eh / b->nodes[i].n, b->nodes[i].n);
348
349 fprintf(fp, "Label:, ");
350 for(j = 0; j < b->nodes[i].preds.size(); j++)
351 {
352 fprintf(fp, "%6d,", (int) b->nodes[i].preds[j].label);
353 }
354 fprintf(fp, "\n");
355
356 fprintf(fp, "Ehk:, ");
357 for(j = 0; j < b->nodes[i].preds.size(); j++)
358 {
359 fprintf(fp, "%7.4f,", b->nodes[i].preds[j].Ehk / b->nodes[i].preds[j].nk);
360 }
361 fprintf(fp, "\n");
362
363 total = 0;
364
365 fprintf(fp, "nk:, ");
366 for(j = 0; j < b->nodes[i].preds.size(); j++)
367 {
368 fprintf(fp, "%6d,", (int) b->nodes[i].preds[j].nk);
369 total += b->nodes[i].preds[j].nk;
370 }
371 fprintf(fp, "\n");
372
373 fprintf(fp, "max(lab:cnt:tot):, %3d,%6d,%7d,\n", (int) b->nodes[i].max_count_label, (int) b->nodes[i].max_count, (int) total);
374 fprintf(fp, "left: %4d, right: %4d", (int) b->nodes[i].left, (int) b->nodes[i].right);
375 fprintf(fp, "\n\n");
376 }
377
378 fclose(fp);
379 }
380
finish(log_multi & b)381 void finish(log_multi& b)
382 {
383 //save_node_stats(b);
384 }
385
save_load_tree(log_multi & b,io_buf & model_file,bool read,bool text)386 void save_load_tree(log_multi& b, io_buf& model_file, bool read, bool text)
387 {
388 if (model_file.files.size() > 0)
389 {
390 char buff[512];
391
392 uint32_t text_len = sprintf(buff, "k = %d ",b.k);
393 bin_text_read_write_fixed(model_file,(char*)&b.max_predictors, sizeof(b.k), "", read, buff, text_len, text);
394 uint32_t temp = (uint32_t)b.nodes.size();
395 text_len = sprintf(buff, "nodes = %d ",temp);
396 bin_text_read_write_fixed(model_file,(char*)&temp, sizeof(temp), "", read, buff, text_len, text);
397 if (read)
398 for (uint32_t j = 1; j < temp; j++)
399 b.nodes.push_back(init_node());
400 text_len = sprintf(buff, "max_predictors = %ld ",b.max_predictors);
401 bin_text_read_write_fixed(model_file,(char*)&b.max_predictors, sizeof(b.max_predictors), "", read, buff, text_len, text);
402
403 text_len = sprintf(buff, "predictors_used = %ld ",b.predictors_used);
404 bin_text_read_write_fixed(model_file,(char*)&b.predictors_used, sizeof(b.predictors_used), "", read, buff, text_len, text);
405
406 text_len = sprintf(buff, "progress = %d ",b.progress);
407 bin_text_read_write_fixed(model_file,(char*)&b.progress, sizeof(b.progress), "", read, buff, text_len, text);
408
409 text_len = sprintf(buff, "swap_resist = %d\n",b.swap_resist);
410 bin_text_read_write_fixed(model_file,(char*)&b.swap_resist, sizeof(b.swap_resist), "", read, buff, text_len, text);
411
412 for (size_t j = 0; j < b.nodes.size(); j++)
413 {//Need to read or write nodes.
414 node& n = b.nodes[j];
415 text_len = sprintf(buff, " parent = %d",n.parent);
416 bin_text_read_write_fixed(model_file,(char*)&n.parent, sizeof(n.parent), "", read, buff, text_len, text);
417
418 uint32_t temp = (uint32_t)n.preds.size();
419 text_len = sprintf(buff, " preds = %d",temp);
420 bin_text_read_write_fixed(model_file,(char*)&temp, sizeof(temp), "", read, buff, text_len, text);
421 if (read)
422 for (uint32_t k = 0; k < temp; k++)
423 n.preds.push_back(node_pred(1));
424
425 text_len = sprintf(buff, " min_count = %d",n.min_count);
426 bin_text_read_write_fixed(model_file,(char*)&n.min_count, sizeof(n.min_count), "", read, buff, text_len, text);
427
428 uint32_t text_len = sprintf(buff, " internal = %d",n.internal);
429 bin_text_read_write_fixed(model_file,(char*)&n.internal, sizeof(n.internal), "", read, buff, text_len, text)
430 ;
431
432 if (n.internal)
433 {
434 text_len = sprintf(buff, " base_predictor = %d",n.base_predictor);
435 bin_text_read_write_fixed(model_file,(char*)&n.base_predictor, sizeof(n.base_predictor), "", read, buff, text_len, text);
436
437 text_len = sprintf(buff, " left = %d",n.left);
438 bin_text_read_write_fixed(model_file,(char*)&n.left, sizeof(n.left), "", read, buff, text_len, text);
439
440 text_len = sprintf(buff, " right = %d",n.right);
441 bin_text_read_write_fixed(model_file,(char*)&n.right, sizeof(n.right), "", read, buff, text_len, text);
442
443 text_len = sprintf(buff, " norm_Eh = %f",n.norm_Eh);
444 bin_text_read_write_fixed(model_file,(char*)&n.norm_Eh, sizeof(n.norm_Eh), "", read, buff, text_len, text);
445
446 text_len = sprintf(buff, " Eh = %f",n.Eh);
447 bin_text_read_write_fixed(model_file,(char*)&n.Eh, sizeof(n.Eh), "", read, buff, text_len, text);
448
449 text_len = sprintf(buff, " n = %d\n",n.n);
450 bin_text_read_write_fixed(model_file,(char*)&n.n, sizeof(n.n), "", read, buff, text_len, text);
451 }
452 else
453 {
454 text_len = sprintf(buff, " max_count = %d",n.max_count);
455 bin_text_read_write_fixed(model_file,(char*)&n.max_count, sizeof(n.max_count), "", read, buff, text_len, text);
456 text_len = sprintf(buff, " max_count_label = %d\n",n.max_count_label);
457 bin_text_read_write_fixed(model_file,(char*)&n.max_count_label, sizeof(n.max_count_label), "", read, buff, text_len, text);
458 }
459
460 for (size_t k = 0; k < n.preds.size(); k++)
461 {
462 node_pred& p = n.preds[k];
463
464 text_len = sprintf(buff, " Ehk = %f",p.Ehk);
465 bin_text_read_write_fixed(model_file,(char*)&p.Ehk, sizeof(p.Ehk), "", read, buff, text_len, text);
466
467 text_len = sprintf(buff, " norm_Ehk = %f",p.norm_Ehk);
468 bin_text_read_write_fixed(model_file,(char*)&p.norm_Ehk, sizeof(p.norm_Ehk), "", read, buff, text_len, text);
469
470 text_len = sprintf(buff, " nk = %d",p.nk);
471 bin_text_read_write_fixed(model_file,(char*)&p.nk, sizeof(p.nk), "", read, buff, text_len, text);
472
473 text_len = sprintf(buff, " label = %d",p.label);
474 bin_text_read_write_fixed(model_file,(char*)&p.label, sizeof(p.label), "", read, buff, text_len, text);
475
476 text_len = sprintf(buff, " label_count = %d\n",p.label_count);
477 bin_text_read_write_fixed(model_file,(char*)&p.label_count, sizeof(p.label_count), "", read, buff, text_len, text);
478 }
479 }
480 }
481 }
482
log_multi_setup(vw & all)483 base_learner* log_multi_setup(vw& all) //learner setup
484 {
485 if (missing_option<size_t, true>(all, "log_multi", "Use online tree for multiclass"))
486 return NULL;
487 new_options(all, "Logarithmic Time Multiclass options")
488 ("no_progress", "disable progressive validation")
489 ("swap_resistance", po::value<uint32_t>(), "higher = more resistance to swap, default=4");
490 add_options(all);
491
492 po::variables_map& vm = all.vm;
493
494 log_multi& data = calloc_or_die<log_multi>();
495 data.k = (uint32_t)vm["log_multi"].as<size_t>();
496 data.swap_resist = 4;
497
498 if (vm.count("swap_resistance"))
499 data.swap_resist = vm["swap_resistance"].as<uint32_t>();
500
501 if (vm.count("no_progress"))
502 data.progress = false;
503 else
504 data.progress = true;
505
506 string loss_function = "quantile";
507 float loss_parameter = 0.5;
508 delete(all.loss);
509 all.loss = getLossFunction(all, loss_function, loss_parameter);
510
511 data.max_predictors = data.k - 1;
512 init_tree(data);
513
514 learner<log_multi>& l = init_multiclass_learner(&data, setup_base(all), learn, predict, all.p, data.max_predictors);
515 l.set_save_load(save_load_tree);
516 l.set_finish(finish);
517
518 return make_base(l);
519 }
520