1 /*\t
2 
3 Copyright (c) by respective owners including Yahoo!, Microsoft, and
4 individual contributors. All rights reserved. Released under a BSD (revised)
5 license as described in the file LICENSE.node
6 */
7 #include <float.h>
8 #include <math.h>
9 #include <stdio.h>
10 #include <sstream>
11 
12 #include "reductions.h"
13 
14 using namespace std;
15 using namespace LEARNER;
16 
17   class node_pred
18   {
19   public:
20 
21     double Ehk;
22     float norm_Ehk;
23     uint32_t nk;
24     uint32_t label;
25     uint32_t label_count;
26 
operator ==(node_pred v)27     bool operator==(node_pred v){
28       return (label == v.label);
29     }
30 
operator >(node_pred v)31     bool operator>(node_pred v){
32       if(label > v.label) return true;
33       return false;
34     }
35 
operator <(node_pred v)36     bool operator<(node_pred v){
37       if(label < v.label) return true;
38       return false;
39     }
40 
node_pred(uint32_t l)41     node_pred(uint32_t l)
42     {
43       label = l;
44       Ehk = 0.f;
45       norm_Ehk = 0;
46       nk = 0;
47       label_count = 0;
48     }
49   };
50 
51   typedef struct
52   {//everyone has
53     uint32_t parent;//the parent node
54     v_array<node_pred> preds;//per-class state
55     uint32_t min_count;//the number of examples reaching this node (if it's a leaf) or the minimum reaching any grandchild.
56 
57     bool internal;//internal or leaf
58 
59     //internal nodes have
60     uint32_t base_predictor;//id of the base predictor
61     uint32_t left;//left child
62     uint32_t right;//right child
63     float norm_Eh;//the average margin at the node
64     double Eh;//total margin at the node
65     uint32_t n;//total events at the node
66 
67     //leaf has
68     uint32_t max_count;//the number of samples of the most common label
69     uint32_t max_count_label;//the most common label
70   } node;
71 
72   struct log_multi
73   {
74     uint32_t k;
75 
76     v_array<node> nodes;
77 
78     size_t max_predictors;
79     size_t predictors_used;
80 
81     bool progress;
82     uint32_t swap_resist;
83 
84     uint32_t nbofswaps;
85   };
86 
init_leaf(node & n)87   inline void init_leaf(node& n)
88   {
89     n.internal = false;
90     n.preds.erase();
91     n.base_predictor = 0;
92     n.norm_Eh = 0;
93     n.Eh = 0;
94     n.n = 0;
95     n.max_count = 0;
96     n.max_count_label = 1;
97     n.left = 0;
98     n.right = 0;
99   }
100 
init_node()101   inline node init_node()
102   {
103     node node;
104 
105     node.parent = 0;
106     node.min_count = 0;
107     node.preds = v_init<node_pred>();
108     init_leaf(node);
109 
110     return node;
111   }
112 
init_tree(log_multi & d)113   void init_tree(log_multi& d)
114   {
115     d.nodes.push_back(init_node());
116     d.nbofswaps = 0;
117   }
118 
min_left_right(log_multi & b,node & n)119   inline uint32_t min_left_right(log_multi& b, node& n)
120   {
121     return min(b.nodes[n.left].min_count, b.nodes[n.right].min_count);
122   }
123 
find_switch_node(log_multi & b)124   inline uint32_t find_switch_node(log_multi& b)
125   {
126     uint32_t node = 0;
127     while(b.nodes[node].internal)
128       if(b.nodes[b.nodes[node].left].min_count
129 	 < b.nodes[b.nodes[node].right].min_count)
130 	node = b.nodes[node].left;
131       else
132 	node = b.nodes[node].right;
133     return node;
134   }
135 
update_min_count(log_multi & b,uint32_t node)136   inline void update_min_count(log_multi& b, uint32_t node)
137   {//Constant time min count update.
138     while(node != 0)
139       {
140 	uint32_t prev = node;
141 	node = b.nodes[node].parent;
142 
143 	if (b.nodes[node].min_count == b.nodes[prev].min_count)
144 	  break;
145 	else
146 	  b.nodes[node].min_count = min_left_right(b,b.nodes[node]);
147       }
148   }
149 
display_tree_dfs(log_multi & b,node node,uint32_t depth)150   void display_tree_dfs(log_multi& b, node node, uint32_t depth)
151   {
152     for (uint32_t i = 0; i < depth; i++)
153       cout << "\t";
154     cout << node.min_count << " " << node.left
155 	 << " " << node.right;
156     cout << " label = " << node.max_count_label << " labels = ";
157     for (size_t i = 0; i < node.preds.size(); i++)
158       cout << node.preds[i].label << ":" << node.preds[i].label_count << "\t";
159     cout << endl;
160 
161     if (node.internal)
162       {
163 	cout << "Left";
164 	display_tree_dfs(b, b.nodes[node.left], depth+1);
165 
166 	cout << "Right";
167 	display_tree_dfs(b, b.nodes[node.right], depth+1);
168       }
169   }
170 
children(log_multi & b,uint32_t & current,uint32_t & class_index,uint32_t label)171   bool children(log_multi& b, uint32_t& current, uint32_t& class_index, uint32_t label)
172   {
173     class_index = (uint32_t)b.nodes[current].preds.unique_add_sorted(node_pred(label));
174     b.nodes[current].preds[class_index].label_count++;
175 
176     if(b.nodes[current].preds[class_index].label_count > b.nodes[current].max_count)
177       {
178 	b.nodes[current].max_count = b.nodes[current].preds[class_index].label_count;
179 	b.nodes[current].max_count_label = b.nodes[current].preds[class_index].label;
180       }
181 
182     if (b.nodes[current].internal)
183       return true;
184     else if( b.nodes[current].preds.size() > 1
185 	     && (b.predictors_used < b.max_predictors
186 				     || b.nodes[current].min_count - b.nodes[current].max_count > b.swap_resist*(b.nodes[0].min_count + 1)))
187       { //need children and we can make them.
188 	uint32_t left_child;
189 	uint32_t right_child;
190 	if (b.predictors_used < b.max_predictors)
191 	  {
192 	    left_child = (uint32_t)b.nodes.size();
193 	    b.nodes.push_back(init_node());
194 	    right_child = (uint32_t)b.nodes.size();
195 	    b.nodes.push_back(init_node());
196 	    b.nodes[current].base_predictor = (uint32_t)b.predictors_used++;
197 	  }
198 	else
199 	  {
200 	    uint32_t swap_child = find_switch_node(b);
201 	    uint32_t swap_parent = b.nodes[swap_child].parent;
202 	    uint32_t swap_grandparent = b.nodes[swap_parent].parent;
203 	    if (b.nodes[swap_child].min_count != b.nodes[0].min_count)
204 	      cout << "glargh " << b.nodes[swap_child].min_count << " != " << b.nodes[0].min_count << endl;
205 	    b.nbofswaps++;
206 
207 	    uint32_t nonswap_child;
208 	    if(swap_child == b.nodes[swap_parent].right)
209 	      nonswap_child = b.nodes[swap_parent].left;
210 	    else
211 	      nonswap_child = b.nodes[swap_parent].right;
212 
213 	    if(swap_parent == b.nodes[swap_grandparent].left)
214 	      b.nodes[swap_grandparent].left = nonswap_child;
215 	    else
216 	      b.nodes[swap_grandparent].right = nonswap_child;
217 	    b.nodes[nonswap_child].parent = swap_grandparent;
218 	    update_min_count(b, nonswap_child);
219 
220 	    init_leaf(b.nodes[swap_child]);
221 	    left_child = swap_child;
222 	    b.nodes[current].base_predictor = b.nodes[swap_parent].base_predictor;
223 	    init_leaf(b.nodes[swap_parent]);
224 	    right_child = swap_parent;
225 	  }
226 	b.nodes[current].left = left_child;
227 	b.nodes[left_child].parent = current;
228 	b.nodes[current].right = right_child;
229 	b.nodes[right_child].parent = current;
230 
231 	b.nodes[left_child].min_count = b.nodes[current].min_count/2;
232 	b.nodes[right_child].min_count = b.nodes[current].min_count - b.nodes[left_child].min_count;
233 	update_min_count(b, left_child);
234 
235 	b.nodes[left_child].max_count_label = b.nodes[current].max_count_label;
236 	b.nodes[right_child].max_count_label = b.nodes[current].max_count_label;
237 
238 	b.nodes[current].internal = true;
239       }
240     return b.nodes[current].internal;
241   }
242 
train_node(log_multi & b,base_learner & base,example & ec,uint32_t & current,uint32_t & class_index)243   void train_node(log_multi& b, base_learner& base, example& ec, uint32_t& current, uint32_t& class_index)
244   {
245     if(b.nodes[current].norm_Eh > b.nodes[current].preds[class_index].norm_Ehk)
246       ec.l.simple.label = -1.f;
247     else
248       ec.l.simple.label = 1.f;
249 
250     base.learn(ec, b.nodes[current].base_predictor);
251 
252     ec.l.simple.label = FLT_MAX;
253     base.predict(ec, b.nodes[current].base_predictor);
254 
255     b.nodes[current].Eh += (double)ec.partial_prediction;
256     b.nodes[current].preds[class_index].Ehk += (double)ec.partial_prediction;
257     b.nodes[current].n++;
258     b.nodes[current].preds[class_index].nk++;
259 
260     b.nodes[current].norm_Eh = (float)b.nodes[current].Eh / b.nodes[current].n;
261     b.nodes[current].preds[class_index].norm_Ehk = (float)b.nodes[current].preds[class_index].Ehk / b.nodes[current].preds[class_index].nk;
262   }
263 
verify_min_dfs(log_multi & b,node node)264   void verify_min_dfs(log_multi& b, node node)
265   {
266     if (node.internal)
267       {
268 	if (node.min_count != min_left_right(b, node))
269 	  {
270 	    cout << "badness! " << endl;
271 	    display_tree_dfs(b, b.nodes[0], 0);
272 	  }
273 	verify_min_dfs(b, b.nodes[node.left]);
274 	verify_min_dfs(b, b.nodes[node.right]);
275       }
276   }
277 
sum_count_dfs(log_multi & b,node node)278   size_t sum_count_dfs(log_multi& b, node node)
279   {
280     if (node.internal)
281       return sum_count_dfs(b, b.nodes[node.left]) + sum_count_dfs(b, b.nodes[node.right]);
282     else
283       return node.min_count;
284   }
285 
descend(node & n,float prediction)286   inline uint32_t descend(node& n, float prediction)
287   {
288     if (prediction < 0)
289       return n.left;
290     else
291       return n.right;
292   }
293 
predict(log_multi & b,base_learner & base,example & ec)294   void predict(log_multi& b,  base_learner& base, example& ec)
295   {
296     MULTICLASS::label_t mc = ec.l.multi;
297 
298     ec.l.simple = {FLT_MAX, 0.f, 0.f};
299     uint32_t cn = 0;
300     while(b.nodes[cn].internal)
301       {
302 	base.predict(ec, b.nodes[cn].base_predictor);
303 	cn = descend(b.nodes[cn], ec.pred.scalar);
304       }
305     ec.pred.multiclass = b.nodes[cn].max_count_label;
306     ec.l.multi = mc;
307   }
308 
learn(log_multi & b,base_learner & base,example & ec)309   void learn(log_multi& b, base_learner& base, example& ec)
310   {
311     //    verify_min_dfs(b, b.nodes[0]);
312     if (ec.l.multi.label == (uint32_t)-1 || b.progress)
313       predict(b,base,ec);
314 
315     if(ec.l.multi.label != (uint32_t)-1)	//if training the tree
316       {
317 	MULTICLASS::label_t mc = ec.l.multi;
318 	uint32_t start_pred = ec.pred.multiclass;
319 
320 	uint32_t class_index = 0;
321 	ec.l.simple = {FLT_MAX, mc.weight, 0.f};
322 	uint32_t cn = 0;
323 	while(children(b, cn, class_index, mc.label))
324 	  {
325 	    train_node(b, base, ec, cn, class_index);
326 	    cn = descend(b.nodes[cn], ec.pred.scalar);
327 	  }
328 
329 	b.nodes[cn].min_count++;
330 	update_min_count(b, cn);
331 	ec.pred.multiclass = start_pred;
332 	ec.l.multi = mc;
333       }
334   }
335 
save_node_stats(log_multi & d)336   void save_node_stats(log_multi& d)
337   {
338     FILE *fp;
339     uint32_t i, j;
340     uint32_t total;
341     log_multi* b = &d;
342 
343     fp = fopen("atxm_debug.csv", "wt");
344 
345     for(i = 0; i < b->nodes.size(); i++)
346       {
347 	fprintf(fp, "Node: %4d, Internal: %1d, Eh: %7.4f, n: %6d, \n", (int) i, (int) b->nodes[i].internal, b->nodes[i].Eh / b->nodes[i].n, b->nodes[i].n);
348 
349 	fprintf(fp, "Label:, ");
350 	for(j = 0; j < b->nodes[i].preds.size(); j++)
351 	  {
352 	    fprintf(fp, "%6d,", (int) b->nodes[i].preds[j].label);
353 	  }
354 	fprintf(fp, "\n");
355 
356 	fprintf(fp, "Ehk:, ");
357 	for(j = 0; j < b->nodes[i].preds.size(); j++)
358 	  {
359 	    fprintf(fp, "%7.4f,", b->nodes[i].preds[j].Ehk / b->nodes[i].preds[j].nk);
360 	  }
361 	fprintf(fp, "\n");
362 
363 	total = 0;
364 
365 	fprintf(fp, "nk:, ");
366 	for(j = 0; j < b->nodes[i].preds.size(); j++)
367 	  {
368 	    fprintf(fp, "%6d,", (int) b->nodes[i].preds[j].nk);
369 	    total += b->nodes[i].preds[j].nk;
370 	  }
371 	fprintf(fp, "\n");
372 
373 	fprintf(fp, "max(lab:cnt:tot):, %3d,%6d,%7d,\n", (int) b->nodes[i].max_count_label, (int) b->nodes[i].max_count, (int) total);
374 	fprintf(fp, "left: %4d, right: %4d", (int) b->nodes[i].left, (int) b->nodes[i].right);
375 	fprintf(fp, "\n\n");
376       }
377 
378     fclose(fp);
379   }
380 
finish(log_multi & b)381   void finish(log_multi& b)
382   {
383     //save_node_stats(b);
384   }
385 
save_load_tree(log_multi & b,io_buf & model_file,bool read,bool text)386   void save_load_tree(log_multi& b, io_buf& model_file, bool read, bool text)
387   {
388     if (model_file.files.size() > 0)
389       {
390 	char buff[512];
391 
392 	uint32_t text_len = sprintf(buff, "k = %d ",b.k);
393 	bin_text_read_write_fixed(model_file,(char*)&b.max_predictors, sizeof(b.k), "", read, buff, text_len, text);
394 	uint32_t temp = (uint32_t)b.nodes.size();
395 	text_len = sprintf(buff, "nodes = %d ",temp);
396 	bin_text_read_write_fixed(model_file,(char*)&temp, sizeof(temp), "", read, buff, text_len, text);
397 	if (read)
398 	  for (uint32_t j = 1; j < temp; j++)
399 	    b.nodes.push_back(init_node());
400 	text_len = sprintf(buff, "max_predictors = %ld ",b.max_predictors);
401 	bin_text_read_write_fixed(model_file,(char*)&b.max_predictors, sizeof(b.max_predictors), "", read, buff, text_len, text);
402 
403 	text_len = sprintf(buff, "predictors_used = %ld ",b.predictors_used);
404 	bin_text_read_write_fixed(model_file,(char*)&b.predictors_used, sizeof(b.predictors_used), "", read, buff, text_len, text);
405 
406 	text_len = sprintf(buff, "progress = %d ",b.progress);
407 	bin_text_read_write_fixed(model_file,(char*)&b.progress, sizeof(b.progress), "", read, buff, text_len, text);
408 
409 	text_len = sprintf(buff, "swap_resist = %d\n",b.swap_resist);
410 	bin_text_read_write_fixed(model_file,(char*)&b.swap_resist, sizeof(b.swap_resist), "", read, buff, text_len, text);
411 
412 	for (size_t j = 0; j < b.nodes.size(); j++)
413 	  {//Need to read or write nodes.
414 	    node& n = b.nodes[j];
415 	    text_len = sprintf(buff, " parent = %d",n.parent);
416 	    bin_text_read_write_fixed(model_file,(char*)&n.parent, sizeof(n.parent), "", read, buff, text_len, text);
417 
418 	    uint32_t temp = (uint32_t)n.preds.size();
419 	    text_len = sprintf(buff, " preds = %d",temp);
420 	    bin_text_read_write_fixed(model_file,(char*)&temp, sizeof(temp), "", read, buff, text_len, text);
421 	    if (read)
422 	      for (uint32_t k = 0; k < temp; k++)
423 		n.preds.push_back(node_pred(1));
424 
425 	    text_len = sprintf(buff, " min_count = %d",n.min_count);
426 	    bin_text_read_write_fixed(model_file,(char*)&n.min_count, sizeof(n.min_count), "", read, buff, text_len, text);
427 
428 	    uint32_t text_len = sprintf(buff, " internal = %d",n.internal);
429 	    bin_text_read_write_fixed(model_file,(char*)&n.internal, sizeof(n.internal), "", read, buff, text_len, text)
430 ;
431 
432 	    if (n.internal)
433 	      {
434 		text_len = sprintf(buff, " base_predictor = %d",n.base_predictor);
435 		bin_text_read_write_fixed(model_file,(char*)&n.base_predictor, sizeof(n.base_predictor), "", read, buff, text_len, text);
436 
437 		text_len = sprintf(buff, " left = %d",n.left);
438 		bin_text_read_write_fixed(model_file,(char*)&n.left, sizeof(n.left), "", read, buff, text_len, text);
439 
440 		text_len = sprintf(buff, " right = %d",n.right);
441 		bin_text_read_write_fixed(model_file,(char*)&n.right, sizeof(n.right), "", read, buff, text_len, text);
442 
443 		text_len = sprintf(buff, " norm_Eh = %f",n.norm_Eh);
444 		bin_text_read_write_fixed(model_file,(char*)&n.norm_Eh, sizeof(n.norm_Eh), "", read, buff, text_len, text);
445 
446 		text_len = sprintf(buff, " Eh = %f",n.Eh);
447 		bin_text_read_write_fixed(model_file,(char*)&n.Eh, sizeof(n.Eh), "", read, buff, text_len, text);
448 
449 		text_len = sprintf(buff, " n = %d\n",n.n);
450 		bin_text_read_write_fixed(model_file,(char*)&n.n, sizeof(n.n), "", read, buff, text_len, text);
451 	      }
452 	    else
453 	      {
454 		text_len = sprintf(buff, " max_count = %d",n.max_count);
455 		bin_text_read_write_fixed(model_file,(char*)&n.max_count, sizeof(n.max_count), "", read, buff, text_len, text);
456 		text_len = sprintf(buff, " max_count_label = %d\n",n.max_count_label);
457 		bin_text_read_write_fixed(model_file,(char*)&n.max_count_label, sizeof(n.max_count_label), "", read, buff, text_len, text);
458 	      }
459 
460 	    for (size_t k = 0; k < n.preds.size(); k++)
461 	      {
462 		node_pred& p = n.preds[k];
463 
464 		text_len = sprintf(buff, "  Ehk = %f",p.Ehk);
465 		bin_text_read_write_fixed(model_file,(char*)&p.Ehk, sizeof(p.Ehk), "", read, buff, text_len, text);
466 
467 		text_len = sprintf(buff, " norm_Ehk = %f",p.norm_Ehk);
468 		bin_text_read_write_fixed(model_file,(char*)&p.norm_Ehk, sizeof(p.norm_Ehk), "", read, buff, text_len, text);
469 
470 		text_len = sprintf(buff, " nk = %d",p.nk);
471 		bin_text_read_write_fixed(model_file,(char*)&p.nk, sizeof(p.nk), "", read, buff, text_len, text);
472 
473 		text_len = sprintf(buff, " label = %d",p.label);
474 		bin_text_read_write_fixed(model_file,(char*)&p.label, sizeof(p.label), "", read, buff, text_len, text);
475 
476 		text_len = sprintf(buff, " label_count = %d\n",p.label_count);
477 		bin_text_read_write_fixed(model_file,(char*)&p.label_count, sizeof(p.label_count), "", read, buff, text_len, text);
478 	      }
479 	  }
480       }
481   }
482 
log_multi_setup(vw & all)483 base_learner* log_multi_setup(vw& all)	//learner setup
484 {
485   if (missing_option<size_t, true>(all, "log_multi", "Use online tree for multiclass"))
486     return NULL;
487   new_options(all, "Logarithmic Time Multiclass options")
488     ("no_progress", "disable progressive validation")
489     ("swap_resistance", po::value<uint32_t>(), "higher = more resistance to swap, default=4");
490   add_options(all);
491 
492   po::variables_map& vm = all.vm;
493 
494   log_multi& data = calloc_or_die<log_multi>();
495   data.k = (uint32_t)vm["log_multi"].as<size_t>();
496   data.swap_resist = 4;
497 
498   if (vm.count("swap_resistance"))
499     data.swap_resist = vm["swap_resistance"].as<uint32_t>();
500 
501   if (vm.count("no_progress"))
502     data.progress = false;
503   else
504     data.progress = true;
505 
506   string loss_function = "quantile";
507   float loss_parameter = 0.5;
508   delete(all.loss);
509   all.loss = getLossFunction(all, loss_function, loss_parameter);
510 
511   data.max_predictors = data.k - 1;
512   init_tree(data);
513 
514   learner<log_multi>& l = init_multiclass_learner(&data, setup_base(all), learn, predict, all.p, data.max_predictors);
515   l.set_save_load(save_load_tree);
516   l.set_finish(finish);
517 
518   return make_base(l);
519 }
520