1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5 */
6 /*
7 Initial implementation by Hal Daume and John Langford. Reimplementation
8 by John Langford.
9 */
10
11 #include <math.h>
12 #include <iostream>
13 #include <fstream>
14 #include <float.h>
15 #include <time.h>
16
17 #include "reductions.h"
18
19 using namespace std;
20 using namespace LEARNER;
21
// One node of the tournament circuit stored in ect::directions.
// Entries 0..k-1 are the label leaves; entries created later by
// create_circuit are binary match nodes. A winner/loser value of 0 is used
// by create_circuit to mean "eliminated" (no further node), and ect_train
// stops walking when it reaches id 0.
struct direction {
  size_t id; //unique id for node (equals its index in ect::directions)
  size_t tournament; //index of the tournament this node belongs to
  uint32_t winner; //up traversal, winner
  uint32_t loser; //up traversal, loser
  uint32_t left; //down traversal, left
  uint32_t right; //down traversal, right
  bool last; //true for the final match node of a tournament
};
31
// State of the error correcting tournament reduction (allocated zeroed by
// calloc_or_die in ect_setup and released by finish).
struct ect{
  uint32_t k; //number of labels (--ect)
  uint32_t errors; //tolerated errors (--error); errors+1 tournaments are built
  v_array<direction> directions;//The nodes of the tournament datastructure

  // all_levels[level][tournament] holds the node ids still competing in that
  // tournament at that construction level (filled by create_circuit).
  v_array<v_array<v_array<uint32_t > > > all_levels;

  v_array<uint32_t> final_nodes; //The final nodes of each tournament.

  // NOTE(review): up_directions/down_directions are only released in finish();
  // nothing in this file fills them — possibly dead fields.
  v_array<size_t> up_directions; //On edge e, which node n is in the up direction?
  v_array<size_t> down_directions;//On edge e, which node n is in the down direction?

  size_t tree_height; //The height of the final tournament.

  uint32_t last_pair; //base-problem offset of the final-elimination tree: (k-1)*(errors+1)

  v_array<bool> tournaments_won; //scratch buffer reused by ect_train for every example
};
50
exists(v_array<size_t> db)51 bool exists(v_array<size_t> db)
52 {
53 for (size_t i = 0; i< db.size();i++)
54 if (db[i] != 0)
55 return true;
56 return false;
57 }
58
// Height of the binary final-elimination tree over `eliminations`
// tournaments: the bit length of (eliminations - 1), i.e. ceil(log2).
// Values needing more than 31 bits are reported on stderr and clamped to 31.
size_t final_depth(size_t eliminations)
{
  const size_t remaining = eliminations - 1;
  size_t depth = 0;
  while (depth < 32 && (remaining >> depth) != 0)
    ++depth;

  if (depth < 32)
    return depth;

  cerr << "too many eliminations" << endl;
  return 31;
}
68
not_empty(v_array<v_array<uint32_t>> tournaments)69 bool not_empty(v_array<v_array<uint32_t > > tournaments)
70 {
71 for (size_t i = 0; i < tournaments.size(); i++)
72 {
73 if (tournaments[i].size() > 0)
74 return true;
75 }
76 return false;
77 }
78
print_level(v_array<v_array<uint32_t>> level)79 void print_level(v_array<v_array<uint32_t> > level)
80 {
81 for (size_t t = 0; t < level.size(); t++)
82 {
83 for (size_t i = 0; i < level[t].size(); i++)
84 cout << " " << level[t][i];
85 cout << " | ";
86 }
87 cout << endl;
88 }
89
/* Build the tournament circuit for max_label labels spread over
   `eliminations` single-elimination tournaments.

   Side effects on e: appends the max_label leaves and every match node to
   e.directions, each construction level to e.all_levels, and each
   tournament's final node to e.final_nodes; sets e.last_pair and
   e.tree_height.

   Returns e.last_pair + (eliminations - 1), i.e. (max_label-1)*eliminations
   match problems plus eliminations-1 final-elimination problems; ect_setup
   passes this as the learner's weights-per-problem count. Returns 0 (and
   builds nothing) when max_label == 1. */
size_t create_circuit(ect& e, uint32_t max_label, uint32_t eliminations)
{
  if (max_label == 1)
    return 0;

  v_array<v_array<uint32_t > > tournaments = v_init<v_array<uint32_t > >();
  v_array<uint32_t> t = v_init<uint32_t>();

  // Leaves: one node per label, all entering tournament 0.
  for (uint32_t i = 0; i < max_label; i++)
  {
    t.push_back(i);
    direction d = {i,0,0,0,0,0, false};
    e.directions.push_back(d);
  }

  tournaments.push_back(t);

  // The remaining tournaments start empty; they are fed by losers.
  for (size_t i = 0; i < eliminations-1; i++)
    tournaments.push_back(v_array<uint32_t>());

  e.all_levels.push_back(tournaments);

  size_t level = 0;

  // Match nodes get ids right after the leaves, so id == index in e.directions.
  uint32_t node = (uint32_t)e.directions.size();

  while (not_empty(e.all_levels[level]))
  {
    v_array<v_array<uint32_t > > new_tournaments = v_init<v_array<uint32_t>>();
    tournaments = e.all_levels[level];

    // One (initially empty) entrant list per tournament for the next level.
    for (size_t t = 0; t < tournaments.size(); t++)
    {
      v_array<uint32_t> empty = v_init<uint32_t>();
      new_tournaments.push_back(empty);
    }

    for (size_t t = 0; t < tournaments.size(); t++)
    {
      // Pair consecutive entrants of tournament t into match nodes.
      for (size_t j = 0; j < tournaments[t].size()/2; j++)
      {
        uint32_t id = node++;
        uint32_t left = tournaments[t][2*j];
        uint32_t right = tournaments[t][2*j+1];

        direction d = {id,t,0,0,left,right, false};
        e.directions.push_back(d);
        uint32_t direction_index = (uint32_t)e.directions.size()-1;
        // A child from this same tournament reaches the match by winning
        // (winner edge); a child from another tournament reaches it by
        // losing there (loser edge).
        if (e.directions[left].tournament == t)
          e.directions[left].winner = direction_index;
        else
          e.directions[left].loser = direction_index;
        if (e.directions[right].tournament == t)
          e.directions[right].winner = direction_index;
        else
          e.directions[right].loser = direction_index;
        // A tournament's final node also advances here via its winner edge.
        if (e.directions[left].last == true)
          e.directions[left].winner = direction_index;

        // Two entrants left and no earlier tournament can still feed in:
        // this match is the final of tournament t.
        if (tournaments[t].size() == 2 && (t == 0 || tournaments[t-1].size() == 0))
        {
          e.directions[direction_index].last = true;
          if (t+1 < tournaments.size())
            new_tournaments[t+1].push_back(id);
          else // winner eliminated.
            e.directions[direction_index].winner = 0;
          e.final_nodes.push_back((uint32_t)(e.directions.size()- 1));
        }
        else
          new_tournaments[t].push_back(id);
        // Unconditional: the match node also enters the next tournament
        // (loser path), or the loser is eliminated in the last tournament.
        if (t+1 < tournaments.size())
          new_tournaments[t+1].push_back(id);
        else // loser eliminated.
          e.directions[direction_index].loser = 0;
      }
      // An odd entrant gets a bye into the next level of the same tournament.
      if (tournaments[t].size() % 2 == 1)
        new_tournaments[t].push_back(tournaments[t].last());
    }
    e.all_levels.push_back(new_tournaments);
    level++;
  }

  // Final-elimination problems are numbered after all match problems.
  e.last_pair = (max_label - 1)*(eliminations);

  if ( max_label > 1)
    e.tree_height = final_depth(eliminations);

  return e.last_pair + (eliminations-1);
}
179
ect_predict(ect & e,base_learner & base,example & ec)180 uint32_t ect_predict(ect& e, base_learner& base, example& ec)
181 {
182 if (e.k == (size_t)1)
183 return 1;
184
185 uint32_t finals_winner = 0;
186
187 //Binary final elimination tournament first
188 ec.l.simple = {FLT_MAX, 0., 0.};
189
190 for (size_t i = e.tree_height-1; i != (size_t)0 -1; i--)
191 {
192 if ((finals_winner | (((size_t)1) << i)) <= e.errors)
193 {// a real choice exists
194 uint32_t problem_number = e.last_pair + (finals_winner | (((uint32_t)1) << i)) - 1; //This is unique.
195
196 base.learn(ec, problem_number);
197
198 if (ec.pred.scalar > 0.)
199 finals_winner = finals_winner | (((size_t)1) << i);
200 }
201 }
202
203 uint32_t id = e.final_nodes[finals_winner];
204 while (id >= e.k)
205 {
206 base.learn(ec, id - e.k);
207
208 if (ec.pred.scalar > 0.)
209 id = e.directions[id].right;
210 else
211 id = e.directions[id].left;
212 }
213 return id+1;
214 }
215
member(size_t t,v_array<size_t> ar)216 bool member(size_t t, v_array<size_t> ar)
217 {
218 for (size_t i = 0; i < ar.size(); i++)
219 if (ar[i] == t)
220 return true;
221 return false;
222 }
223
ect_train(ect & e,base_learner & base,example & ec)224 void ect_train(ect& e, base_learner& base, example& ec)
225 {
226 if (e.k == 1)//nothing to do
227 return;
228 MULTICLASS::label_t mc = ec.l.multi;
229
230 label_data simple_temp;
231
232 simple_temp.initial = 0.;
233 simple_temp.weight = mc.weight;
234
235 e.tournaments_won.erase();
236
237 uint32_t id = e.directions[mc.label - 1].winner;
238 bool left = e.directions[id].left == mc.label - 1;
239 do
240 {
241 if (left)
242 simple_temp.label = -1;
243 else
244 simple_temp.label = 1;
245
246 ec.l.simple = simple_temp;
247 base.learn(ec, id-e.k);
248 ec.l.simple.weight = 0.;
249 base.learn(ec, id-e.k);//inefficient, we should extract final prediction exactly.
250
251 bool won = ec.pred.scalar * simple_temp.label > 0;
252
253 if (won)
254 {
255 if (!e.directions[id].last)
256 left = e.directions[e.directions[id].winner].left == id;
257 else
258 e.tournaments_won.push_back(true);
259 id = e.directions[id].winner;
260 }
261 else
262 {
263 if (!e.directions[id].last)
264 {
265 left = e.directions[e.directions[id].loser].left == id;
266 if (e.directions[id].loser == 0)
267 e.tournaments_won.push_back(false);
268 }
269 else
270 e.tournaments_won.push_back(false);
271 id = e.directions[id].loser;
272 }
273 }
274 while(id != 0);
275
276 if (e.tournaments_won.size() < 1)
277 cout << "badness!" << endl;
278
279 //tournaments_won is a bit vector determining which tournaments the label won.
280 for (size_t i = 0; i < e.tree_height; i++)
281 {
282 for (uint32_t j = 0; j < e.tournaments_won.size()/2; j++)
283 {
284 bool left = e.tournaments_won[j*2];
285 bool right = e.tournaments_won[j*2+1];
286 if (left == right)//no query to do
287 e.tournaments_won[j] = left;
288 else //query to do
289 {
290 if (left)
291 simple_temp.label = -1;
292 else
293 simple_temp.label = 1;
294 simple_temp.weight = (float)(1 << (e.tree_height -i -1));
295 ec.l.simple = simple_temp;
296
297 uint32_t problem_number = e.last_pair + j*(1 << (i+1)) + (1 << i) -1;
298
299 base.learn(ec, problem_number);
300
301 if (ec.pred.scalar > 0.)
302 e.tournaments_won[j] = right;
303 else
304 e.tournaments_won[j] = left;
305 }
306 if (e.tournaments_won.size() %2 == 1)
307 e.tournaments_won[e.tournaments_won.size()/2] = e.tournaments_won[e.tournaments_won.size()-1];
308 e.tournaments_won.end = e.tournaments_won.begin+(1+e.tournaments_won.size())/2;
309 }
310 }
311 }
312
predict(ect & e,base_learner & base,example & ec)313 void predict(ect& e, base_learner& base, example& ec) {
314 MULTICLASS::label_t mc = ec.l.multi;
315 if (mc.label == 0 || (mc.label > e.k && mc.label != (uint32_t)-1))
316 cout << "label " << mc.label << " is not in {1,"<< e.k << "} This won't work right." << endl;
317 ec.pred.multiclass = ect_predict(e, base, ec);
318 ec.l.multi = mc;
319 }
320
learn(ect & e,base_learner & base,example & ec)321 void learn(ect& e, base_learner& base, example& ec)
322 {
323 MULTICLASS::label_t mc = ec.l.multi;
324 predict(e, base, ec);
325 uint32_t pred = ec.pred.multiclass;
326
327 if (mc.label != (uint32_t)-1)
328 ect_train(e, base, ec);
329 ec.l.multi = mc;
330 ec.pred.multiclass = pred;
331 }
332
finish(ect & e)333 void finish(ect& e)
334 {
335 for (size_t l = 0; l < e.all_levels.size(); l++)
336 {
337 for (size_t t = 0; t < e.all_levels[l].size(); t++)
338 e.all_levels[l][t].delete_v();
339 e.all_levels[l].delete_v();
340 }
341 e.final_nodes.delete_v();
342
343 e.up_directions.delete_v();
344
345 e.directions.delete_v();
346
347 e.down_directions.delete_v();
348
349 e.tournaments_won.delete_v();
350 }
351
ect_setup(vw & all)352 base_learner* ect_setup(vw& all)
353 {
354 if (missing_option<size_t, true>(all, "ect", "Error correcting tournament with <k> labels"))
355 return NULL;
356 new_options(all, "Error Correcting Tournament options")
357 ("error", po::value<size_t>()->default_value(0), "error in ECT");
358 add_options(all);
359
360 ect& data = calloc_or_die<ect>();
361 data.k = (int)all.vm["ect"].as<size_t>();
362 data.errors = (uint32_t)all.vm["error"].as<size_t>();
363 //append error flag to options_from_file so it is saved in regressor file later
364 *all.file_options << " --error " << data.errors;
365
366 size_t wpp = create_circuit(data, data.k, data.errors+1);
367
368 learner<ect>& l = init_multiclass_learner(&data, setup_base(all), learn, predict, all.p, wpp);
369 l.set_finish(finish);
370 return make_base(l);
371 }
372