1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved.  Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
6 /*
7   Initial implementation by Hal Daume and John Langford.  Reimplementation
8   by John Langford.
9 */
10 
11 #include <math.h>
12 #include <iostream>
13 #include <fstream>
14 #include <float.h>
15 #include <time.h>
16 
17 #include "reductions.h"
18 
19 using namespace std;
20 using namespace LEARNER;
21 
// One node of the tournament circuit built by create_circuit.
// Nodes 0..k-1 of ect::directions are the label leaves; every later node is a
// pairwise match.  create_circuit initializes this struct positionally
// ({id, tournament, winner, loser, left, right, last}), so the field order
// must not change.
struct direction {
  size_t id; //index of this node in ect::directions (assigned sequentially by create_circuit)
  size_t tournament; //index of the elimination tournament this node belongs to (0 for leaves)
  uint32_t winner; //up traversal: node reached when this node wins; 0 = no successor (eliminated)
  uint32_t loser; //up traversal: node reached when this node loses; 0 = no successor (eliminated)
  uint32_t left; //down traversal: left contestant of this match (0 for leaves)
  uint32_t right; //down traversal: right contestant of this match (0 for leaves)
  bool last; //true when this match is the final of its tournament
};
31 
32 struct ect{
33   uint32_t k;
34   uint32_t errors;
35   v_array<direction> directions;//The nodes of the tournament datastructure
36 
37   v_array<v_array<v_array<uint32_t > > > all_levels;
38 
39   v_array<uint32_t> final_nodes; //The final nodes of each tournament.
40 
41   v_array<size_t> up_directions; //On edge e, which node n is in the up direction?
42   v_array<size_t> down_directions;//On edge e, which node n is in the down direction?
43 
44   size_t tree_height; //The height of the final tournament.
45 
46   uint32_t last_pair;
47 
48   v_array<bool> tournaments_won;
49 };
50 
exists(v_array<size_t> db)51 bool exists(v_array<size_t> db)
52 {
53   for (size_t i = 0; i< db.size();i++)
54     if (db[i] != 0)
55       return true;
56   return false;
57 }
58 
// Depth of the final bracket needed for `eliminations` contestants: the
// number of significant bits in (eliminations - 1), searched over [0, 32).
// If no depth below 32 suffices, a warning is printed and 31 is returned.
size_t final_depth(size_t eliminations)
{
  const size_t highest_index = eliminations - 1;
  size_t depth = 0;
  while (depth < 32 && (highest_index >> depth) != 0)
    ++depth;
  if (depth < 32)
    return depth;
  cerr << "too many eliminations" << endl;
  return 31;
}
68 
not_empty(v_array<v_array<uint32_t>> tournaments)69 bool not_empty(v_array<v_array<uint32_t > > tournaments)
70 {
71   for (size_t i = 0; i < tournaments.size(); i++)
72     {
73       if (tournaments[i].size() > 0)
74         return true;
75     }
76   return false;
77 }
78 
print_level(v_array<v_array<uint32_t>> level)79 void print_level(v_array<v_array<uint32_t> > level)
80 {
81   for (size_t t = 0; t < level.size(); t++)
82     {
83       for (size_t i = 0; i < level[t].size(); i++)
84 	cout << " " << level[t][i];
85       cout << " | ";
86     }
87   cout << endl;
88 }
89 
// Builds the tournament circuit for max_label labels and `eliminations`
// elimination tournaments (ect_setup passes errors+1).
//
// Nodes 0..max_label-1 of e.directions are the label leaves.  e.all_levels
// holds one entry per level, each a list of node ids per tournament.  Every
// pass of the while loop pairs adjacent ids of each tournament into new match
// nodes, wires the winner/loser edges of the two contestants, and appends
// the next level.  The final match of each tournament is flagged `last` and
// recorded in e.final_nodes.
//
// Returns the value ect_setup hands to init_multiclass_learner as wpp:
// e.last_pair problems (presumably one per match node; ect_predict uses
// problem id - k for match id) plus eliminations-1 problems for the final
// bracket (ect_predict/ect_train use e.last_pair + offset for those).
// With a single label nothing is built and 0 is returned; last_pair and
// tree_height are then left at their zero-initialized values.
size_t create_circuit(ect& e, uint32_t max_label, uint32_t eliminations)
{
  if (max_label == 1)
    return 0;

  v_array<v_array<uint32_t > > tournaments = v_init<v_array<uint32_t > >();
  v_array<uint32_t> t = v_init<uint32_t>();

  // One leaf per label; all labels start in tournament 0.
  for (uint32_t i = 0; i < max_label; i++)
    {
      t.push_back(i);
      direction d = {i,0,0,0,0,0, false};
      e.directions.push_back(d);
    }

  tournaments.push_back(t);

  // The remaining tournaments start out empty and are filled by losers.
  for (size_t i = 0; i < eliminations-1; i++)
    tournaments.push_back(v_array<uint32_t>());

  e.all_levels.push_back(tournaments);

  size_t level = 0;

  // New match nodes are numbered after the leaves, so id == index in e.directions.
  uint32_t node = (uint32_t)e.directions.size();

  while (not_empty(e.all_levels[level]))
    {
      v_array<v_array<uint32_t > > new_tournaments = v_init<v_array<uint32_t>>();
      tournaments = e.all_levels[level];

      for (size_t t = 0; t < tournaments.size(); t++)
	{
	  v_array<uint32_t> empty = v_init<uint32_t>();
	  new_tournaments.push_back(empty);
	}

      for (size_t t = 0; t < tournaments.size(); t++)
	{
	  // Pair adjacent entries (2j, 2j+1) of tournament t into a new match.
	  for (size_t j = 0; j < tournaments[t].size()/2; j++)
	    {
	      uint32_t id = node++;
	      uint32_t left = tournaments[t][2*j];
	      uint32_t right = tournaments[t][2*j+1];

	      direction d = {id,t,0,0,left,right, false};
	      e.directions.push_back(d);
	      uint32_t direction_index = (uint32_t)e.directions.size()-1;
		// A contestant from this tournament reaches the match on its winner
		// edge; one that came from another tournament arrives on its loser edge.
		if (e.directions[left].tournament == t)
		  e.directions[left].winner = direction_index;
		else
		  e.directions[left].loser = direction_index;
		if (e.directions[right].tournament == t)
		  e.directions[right].winner = direction_index;
		else
		  e.directions[right].loser = direction_index;
		// A previous tournament final continues on its winner edge.
		if (e.directions[left].last == true)
		  e.directions[left].winner = direction_index;

		// Last two entries of the earliest non-empty tournament: this is its final.
		if (tournaments[t].size() == 2 && (t == 0 || tournaments[t-1].size() == 0))
		  {
		    e.directions[direction_index].last = true;
		    // NOTE(review): together with the push below, a final that has a
		    // following tournament is pushed into new_tournaments[t+1] twice.
		    if (t+1 < tournaments.size())
		      new_tournaments[t+1].push_back(id);
		    else // winner eliminated.
		      e.directions[direction_index].winner = 0;
		    e.final_nodes.push_back((uint32_t)(e.directions.size()- 1));
		  }
		else
		  new_tournaments[t].push_back(id);
		// The loser of every match drops into the next tournament, if any.
		if (t+1 < tournaments.size())
		  new_tournaments[t+1].push_back(id);
		else // loser eliminated.
		  e.directions[direction_index].loser = 0;
	      }
	    // An unpaired entry advances to the next level unchanged.
	    if (tournaments[t].size() % 2 == 1)
	      new_tournaments[t].push_back(tournaments[t].last());
	  }
	e.all_levels.push_back(new_tournaments);
	level++;
      }

    e.last_pair = (max_label - 1)*(eliminations);

    // Always true here: max_label == 1 returned early above.
    if ( max_label > 1)
      e.tree_height = final_depth(eliminations);

    return e.last_pair + (eliminations-1);
  }
179 
ect_predict(ect & e,base_learner & base,example & ec)180   uint32_t ect_predict(ect& e, base_learner& base, example& ec)
181   {
182     if (e.k == (size_t)1)
183       return 1;
184 
185     uint32_t finals_winner = 0;
186 
187     //Binary final elimination tournament first
188     ec.l.simple = {FLT_MAX, 0., 0.};
189 
190     for (size_t i = e.tree_height-1; i != (size_t)0 -1; i--)
191       {
192         if ((finals_winner | (((size_t)1) << i)) <= e.errors)
193           {// a real choice exists
194             uint32_t problem_number = e.last_pair + (finals_winner | (((uint32_t)1) << i)) - 1; //This is unique.
195 
196             base.learn(ec, problem_number);
197 
198 	    if (ec.pred.scalar > 0.)
199               finals_winner = finals_winner | (((size_t)1) << i);
200           }
201       }
202 
203     uint32_t id = e.final_nodes[finals_winner];
204     while (id >= e.k)
205       {
206 	base.learn(ec, id - e.k);
207 
208 	if (ec.pred.scalar > 0.)
209 	  id = e.directions[id].right;
210 	else
211 	  id = e.directions[id].left;
212       }
213     return id+1;
214   }
215 
member(size_t t,v_array<size_t> ar)216   bool member(size_t t, v_array<size_t> ar)
217   {
218     for (size_t i = 0; i < ar.size(); i++)
219       if (ar[i] == t)
220         return true;
221     return false;
222   }
223 
ect_train(ect & e,base_learner & base,example & ec)224   void ect_train(ect& e, base_learner& base, example& ec)
225   {
226     if (e.k == 1)//nothing to do
227       return;
228     MULTICLASS::label_t mc = ec.l.multi;
229 
230     label_data simple_temp;
231 
232     simple_temp.initial = 0.;
233     simple_temp.weight = mc.weight;
234 
235     e.tournaments_won.erase();
236 
237     uint32_t id = e.directions[mc.label - 1].winner;
238     bool left = e.directions[id].left == mc.label - 1;
239     do
240       {
241 	if (left)
242 	  simple_temp.label = -1;
243 	else
244 	  simple_temp.label = 1;
245 
246 	ec.l.simple = simple_temp;
247 	base.learn(ec, id-e.k);
248 	ec.l.simple.weight = 0.;
249 	base.learn(ec, id-e.k);//inefficient, we should extract final prediction exactly.
250 
251 	bool won = ec.pred.scalar * simple_temp.label > 0;
252 
253 	if (won)
254 	  {
255 	    if (!e.directions[id].last)
256 	      left = e.directions[e.directions[id].winner].left == id;
257 	    else
258 	      e.tournaments_won.push_back(true);
259 	    id = e.directions[id].winner;
260 	  }
261 	else
262 	  {
263 	    if (!e.directions[id].last)
264 	      {
265 		left = e.directions[e.directions[id].loser].left == id;
266 		if (e.directions[id].loser == 0)
267 		  e.tournaments_won.push_back(false);
268 	      }
269 	    else
270 	      e.tournaments_won.push_back(false);
271 	    id = e.directions[id].loser;
272 	  }
273       }
274     while(id != 0);
275 
276     if (e.tournaments_won.size() < 1)
277       cout << "badness!" << endl;
278 
279     //tournaments_won is a bit vector determining which tournaments the label won.
280     for (size_t i = 0; i < e.tree_height; i++)
281       {
282         for (uint32_t j = 0; j < e.tournaments_won.size()/2; j++)
283           {
284             bool left = e.tournaments_won[j*2];
285             bool right = e.tournaments_won[j*2+1];
286             if (left == right)//no query to do
287               e.tournaments_won[j] = left;
288             else //query to do
289               {
290                 if (left)
291                   simple_temp.label = -1;
292                 else
293                   simple_temp.label = 1;
294 		simple_temp.weight = (float)(1 << (e.tree_height -i -1));
295                 ec.l.simple = simple_temp;
296 
297                 uint32_t problem_number = e.last_pair + j*(1 << (i+1)) + (1 << i) -1;
298 
299 		base.learn(ec, problem_number);
300 
301 		if (ec.pred.scalar > 0.)
302                   e.tournaments_won[j] = right;
303                 else
304                   e.tournaments_won[j] = left;
305               }
306             if (e.tournaments_won.size() %2 == 1)
307               e.tournaments_won[e.tournaments_won.size()/2] = e.tournaments_won[e.tournaments_won.size()-1];
308             e.tournaments_won.end = e.tournaments_won.begin+(1+e.tournaments_won.size())/2;
309           }
310       }
311   }
312 
predict(ect & e,base_learner & base,example & ec)313   void predict(ect& e, base_learner& base, example& ec) {
314     MULTICLASS::label_t mc = ec.l.multi;
315     if (mc.label == 0 || (mc.label > e.k && mc.label != (uint32_t)-1))
316       cout << "label " << mc.label << " is not in {1,"<< e.k << "} This won't work right." << endl;
317     ec.pred.multiclass = ect_predict(e, base, ec);
318     ec.l.multi = mc;
319   }
320 
learn(ect & e,base_learner & base,example & ec)321   void learn(ect& e, base_learner& base, example& ec)
322   {
323     MULTICLASS::label_t mc = ec.l.multi;
324     predict(e, base, ec);
325     uint32_t pred = ec.pred.multiclass;
326 
327     if (mc.label != (uint32_t)-1)
328       ect_train(e, base, ec);
329     ec.l.multi = mc;
330     ec.pred.multiclass = pred;
331   }
332 
finish(ect & e)333   void finish(ect& e)
334   {
335     for (size_t l = 0; l < e.all_levels.size(); l++)
336       {
337 	for (size_t t = 0; t < e.all_levels[l].size(); t++)
338 	  e.all_levels[l][t].delete_v();
339 	e.all_levels[l].delete_v();
340       }
341     e.final_nodes.delete_v();
342 
343     e.up_directions.delete_v();
344 
345     e.directions.delete_v();
346 
347     e.down_directions.delete_v();
348 
349     e.tournaments_won.delete_v();
350   }
351 
ect_setup(vw & all)352 base_learner* ect_setup(vw& all)
353 {
354   if (missing_option<size_t, true>(all, "ect", "Error correcting tournament with <k> labels"))
355     return NULL;
356   new_options(all, "Error Correcting Tournament options")
357     ("error", po::value<size_t>()->default_value(0), "error in ECT");
358   add_options(all);
359 
360   ect& data = calloc_or_die<ect>();
361   data.k = (int)all.vm["ect"].as<size_t>();
362   data.errors = (uint32_t)all.vm["error"].as<size_t>();
363   //append error flag to options_from_file so it is saved in regressor file later
364   *all.file_options << " --error " << data.errors;
365 
366   size_t wpp = create_circuit(data, data.k, data.errors+1);
367 
368   learner<ect>& l = init_multiclass_learner(&data, setup_base(all), learn, predict, all.p, wpp);
369   l.set_finish(finish);
370   return make_base(l);
371 }
372