1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved.  Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
6 #include "search_sequencetask.h"
7 #include "vw.h"
8 
// Task registration table: one Search::search_task per task in this file.
// Each is aggregate-initialized positionally with
//   { task name, run, initialize, finish, setup, takedown }
// NOTE(review): this field order is inferred from the function names passed below; the
// struct itself is declared in the search headers — confirm against Search::search_task.
// NULL presumably means "no such hook" for that task.
// NOTE(review): ArgmaxTask allocates its task_data with `new` in initialize() but registers
// NULL for finish, so nothing in this file releases it — confirm whether that is intended.
namespace SequenceTask         { Search::search_task task = { "sequence",          run, initialize, NULL,   NULL,  NULL     }; }
namespace SequenceSpanTask     { Search::search_task task = { "sequencespan",      run, initialize, finish, setup, takedown }; }
namespace ArgmaxTask           { Search::search_task task = { "argmax",            run, initialize, NULL,   NULL,  NULL     }; }
namespace SequenceTask_DemoLDF { Search::search_task task = { "sequence_demoldf",  run, initialize, finish, NULL,  NULL     }; }
13 
14 namespace SequenceTask {
initialize(Search::search & sch,size_t & num_actions,po::variables_map & vm)15   void initialize(Search::search& sch, size_t& num_actions, po::variables_map& vm) {
16     sch.set_options( Search::AUTO_CONDITION_FEATURES  |    // automatically add history features to our examples, please
17                      Search::AUTO_HAMMING_LOSS        |    // please just use hamming loss on individual predictions -- we won't declare loss
18                      Search::EXAMPLES_DONT_CHANGE     |    // we don't do any internal example munging
19                      0);
20   }
21 
run(Search::search & sch,vector<example * > & ec)22   void run(Search::search& sch, vector<example*>& ec) {
23     for (size_t i=0; i<ec.size(); i++) {
24       action oracle     = ec[i]->l.multi.label;
25       size_t prediction = Search::predictor(sch, (ptag)i+1).set_input(*ec[i]).set_oracle(oracle).set_condition_range((ptag)i, sch.get_history_length(), 'p').predict();
26 
27       if (sch.output().good())
28         sch.output() << prediction << ' ';
29     }
30   }
31 }
32 
33 
34 namespace SequenceSpanTask {
35   enum EncodingType { BIO, BILOU };
36 // the format for the BIO encoding is:
37 //     label     description
38 //     1         "O" (out)
39 //     n even    begin X, where X is defined by n/2
//     n odd>1   in X, where X is (n-1)/2
//   thus, valid transitions are:
//     *       -> 1       (anything to OUT)
//     *       -> n even  (anything to BEGIN X)
44 //     n even  -> n+1     (BEGIN X to IN X)
45 //     n odd>1 -> n       (IN X to IN X)
46 // the format for the BILOU (begin, inside, last, out, unit-length) encoding is:
47 //     label     description
48 //     1         out
49 //     n>1: let m=n-2:
50 //       m % 4 == 0    unit-(m div 4)
51 //       m % 4 == 1    begin-(m div 4)
52 //       m % 4 == 2    in-(m div 4)
53 //       m % 4 == 3    last-(m div 4)
//   thus, valid transitions are (targets are labels n; the last column lists the labels n covered by each row):
//     1     -> 1; 2, 6, 10, ...; 3, 7, 11, ...         out to { out, unit-Y, begin-Y }       1
//     m%4=0 -> 1; 2, 6, 10, ...; 3, 7, 11, ...         unit-X to { out, unit-Y, begin-Y }    2, 6, 10, 14, ...
//     m%4=1 -> n+1, n+2                                begin-X to { in-X, last-X }           3, 7, 11, 15, ...
//     m%4=2 -> n, n+1                                  in-X to { in-X, last-X }              4, 8, 12, 16, ...
59 //     m%4=3 -> 1; 2, 6, 10, ...; 3, 7, 11, ...         last-X to { out, unit-Y, begin-Y }    5, 9, 13, 17, ...
60 
bilou_to_bio(action y)61   inline action bilou_to_bio(action y) {
62     return y / 2 + 1;  // out -> out, {unit,begin} -> begin; {in,last} -> in
63   }
64 
convert_bio_to_bilou(vector<example * > ec)65   void convert_bio_to_bilou(vector<example*> ec) {
66     for (size_t n=0; n<ec.size(); n++) {
67       MULTICLASS::label_t& ylab = ec[n]->l.multi;
68       action y = ylab.label;
69       action nexty = (n == ec.size()-1) ? 0 : ec[n+1]->l.multi.label;
70       if (y == 1) { // do nothing
71       } else if (y % 2 == 0) { // this is a begin-X
72         if (nexty != y + 1) // should be unit
73           ylab.label = (y/2 - 1) * 4 + 2;  // from 2 to 2, 4 to 6, 6 to 10, etc.
74         else // should be begin-X
75           ylab.label = (y/2 - 1) * 4 + 3;  // from 2 to 3, 4 to 7, 6 to 11, etc.
76       } else if (y % 2 == 1) { // this is an in-X
77         if (nexty != y) // should be last
78           ylab.label = (y-1) * 2 + 1;  // from 3 to 5, 5 to 9, 7 to 13, etc.
79         else // should be in-X
80           ylab.label = (y-1) * 2;      // from 3 to 4, 5 to 8, 7 to 12, etc.
81       }
82       assert( y == bilou_to_bio(ylab.label) );
83     }
84   }
85 
86   struct task_data {
87     EncodingType encoding;
88     v_array<action> allowed_actions;
89     v_array<action> only_two_allowed;  // used for BILOU encoding
90   };
91 
initialize(Search::search & sch,size_t & num_actions,po::variables_map & vm)92   void initialize(Search::search& sch, size_t& num_actions, po::variables_map& vm) {
93     task_data * my_task_data = new task_data();
94     po::options_description sspan_opts("search sequencespan options");
95     sspan_opts.add_options()("search_span_bilou", "switch to (internal) BILOU encoding instead of BIO encoding");
96     sch.add_program_options(vm, sspan_opts);
97 
98     if (vm.count("search_span_bilou")) {
99       cerr << "switching to BILOU encoding for sequence span labeling" << endl;
100       my_task_data->encoding = BILOU;
101       num_actions = num_actions * 2 - 1;
102     } else
103       my_task_data->encoding = BIO;
104 
105 
106     my_task_data->allowed_actions.erase();
107 
108     if (my_task_data->encoding == BIO) {
109       my_task_data->allowed_actions.push_back(1);
110       for (action l=2; l<num_actions; l+=2)
111         my_task_data->allowed_actions.push_back(l);
112       my_task_data->allowed_actions.push_back(1);  // push back an extra 1 that we can overwrite later if we want
113     } else if (my_task_data->encoding == BILOU) {
114       my_task_data->allowed_actions.push_back(1);
115       for (action l=2; l<num_actions; l+=4) {
116         my_task_data->allowed_actions.push_back(l);
117         my_task_data->allowed_actions.push_back(l+1);
118       }
119       my_task_data->only_two_allowed.push_back(0);
120       my_task_data->only_two_allowed.push_back(0);
121     }
122 
123     sch.set_task_data<task_data>(my_task_data);
124     sch.set_options( Search::AUTO_CONDITION_FEATURES  |    // automatically add history features to our examples, please
125                      Search::AUTO_HAMMING_LOSS        |    // please just use hamming loss on individual predictions -- we won't declare loss
126                      Search::EXAMPLES_DONT_CHANGE     |    // we don't do any internal example munging
127                      0);
128   }
129 
finish(Search::search & sch)130   void finish(Search::search& sch) {
131     task_data * my_task_data = sch.get_task_data<task_data>();
132     my_task_data->allowed_actions.delete_v();
133     my_task_data->only_two_allowed.delete_v();
134     delete my_task_data;
135   }
136 
setup(Search::search & sch,vector<example * > & ec)137   void setup(Search::search& sch, vector<example*>& ec) {
138     task_data * my_task_data = sch.get_task_data<task_data>();
139     if (my_task_data->encoding == BILOU)
140       convert_bio_to_bilou(ec);
141   }
142 
takedown(Search::search & sch,vector<example * > & ec)143   void takedown(Search::search& sch, vector<example*>& ec) {
144     task_data * my_task_data = sch.get_task_data<task_data>();
145 
146     if (my_task_data->encoding == BILOU)
147       for (size_t n=0; n<ec.size(); n++) {
148         MULTICLASS::label_t ylab = ec[n]->l.multi;
149         ylab.label = bilou_to_bio(ylab.label);
150       }
151   }
152 
  /// Label a sequence left to right while enforcing the span-encoding transition constraints.
  /// At each position the allowed actions depend on the previous prediction (see the
  /// transition tables above). When the gold label would be an illegal transition from that
  /// prediction, the oracle is replaced by a legal one. Predictions are printed
  /// space-separated in BIO form; BILOU predictions are mapped back with bilou_to_bio.
  void run(Search::search& sch, vector<example*>& ec) {
    task_data * my_task_data = sch.get_task_data<task_data>();
    action last_prediction = 1;   // the position before the sequence is treated as "out"
    // Shared list built in initialize(); in BIO mode its last slot is scratch space rewritten below.
    v_array<action> * y_allowed = &(my_task_data->allowed_actions);

    for (size_t i=0; i<ec.size(); i++) {
      action oracle = ec[i]->l.multi.label;
      size_t len = y_allowed->size();
      Search::predictor P(sch, (ptag)i+1);
      if (my_task_data->encoding == BIO) {
        // After out: allow out and every begin-X, i.e. everything except the scratch slot.
        // After begin-X (even) or in-X (odd): also allow the matching in-X by writing it into the scratch slot.
        if      (last_prediction == 1)       P.set_allowed(y_allowed->begin, len-1);
        else if (last_prediction % 2 == 0) { (*y_allowed)[len-1] = last_prediction+1; P.set_allowed(*y_allowed); }
        else                               { (*y_allowed)[len-1] = last_prediction;   P.set_allowed(*y_allowed); }
        if ((oracle > 1) && (oracle % 2 == 1) && (last_prediction != oracle) && (last_prediction != oracle-1))
          oracle = 1; // if we are supposed to I-X, but last wasn't B-X or I-X, then say O
      } else if (my_task_data->encoding == BILOU) {
        if ((last_prediction == 1) || ((last_prediction-2) % 4 == 0) || ((last_prediction-2) % 4 == 3)) { // O or unit-X or last-X
          P.set_allowed(my_task_data->allowed_actions);
          // we cannot allow in-X or last-X next
          if ((oracle > 1) && (((oracle-2) % 4 == 2) || ((oracle-2) % 4 == 3)))
            oracle = 1;
        } else { // begin-X or in-X
          // Only in-X / last-X may follow: after begin-X at n these are n+1 and n+2,
          // after in-X at n they are n and n+1. `other` is whichever of the two is not n+1.
          action other = ((last_prediction-2) % 4 == 1) ? (last_prediction+2) : last_prediction;
          P.set_allowed(last_prediction+1);
          P.add_allowed(other);
          // An illegal gold label is replaced by `other`: last-X after begin-X, in-X after in-X.
          if ((oracle != last_prediction+1) && (oracle != other))
            oracle = other;
        }
      }
      last_prediction = P.set_input(*ec[i]).set_condition_range((ptag)i, sch.get_history_length(), 'p').set_oracle(oracle).predict();

      // Always report in BIO form, whatever the internal encoding.
      action printed_prediction = (my_task_data->encoding == BIO) ? last_prediction : bilou_to_bio(last_prediction);

      if (sch.output().good())
        sch.output() << printed_prediction << ' ';
    }
  }
190 }
191 
192 namespace ArgmaxTask {
193   struct task_data {
194     float false_negative_cost;
195     float negative_weight;
196     bool predict_max;
197   };
198 
initialize(Search::search & sch,size_t & num_actions,po::variables_map & vm)199   void initialize(Search::search& sch, size_t& num_actions, po::variables_map& vm) {
200     task_data* my_task_data = new task_data();
201 
202     po::options_description argmax_opts("argmax options");
203     argmax_opts.add_options()
204       ("cost", po::value<float>(&(my_task_data->false_negative_cost))->default_value(10.0), "False Negative Cost")
205       ("negative_weight", po::value<float>(&(my_task_data->negative_weight))->default_value(1), "Relative weight of negative examples")
206       ("max", "Disable structure: just predict the max");
207     sch.add_program_options(vm, argmax_opts);
208 
209     my_task_data->predict_max = vm.count("max") > 0;
210 
211     sch.set_task_data(my_task_data);
212 
213     if (my_task_data->predict_max)
214       sch.set_options( Search::EXAMPLES_DONT_CHANGE );   // we don't do any internal example munging
215     else
216       sch.set_options( Search::AUTO_CONDITION_FEATURES |    // automatically add history features to our examples, please
217                        Search::EXAMPLES_DONT_CHANGE );   // we don't do any internal example munging
218   }
219 
run(Search::search & sch,vector<example * > & ec)220   void run(Search::search& sch, vector<example*>& ec) {
221     task_data * my_task_data = sch.get_task_data<task_data>();
222     uint32_t max_prediction = 1;
223     uint32_t max_label = 1;
224 
225     for(size_t i = 0; i < ec.size(); i++)
226       max_label = max(ec[i]->l.multi.label, max_label);
227 
228     for (ptag i=0; i<ec.size(); i++) {
229       // labels should be 1 or 2, and our output is MAX of all predicted values
230       uint32_t oracle = my_task_data->predict_max ? max_label : ec[i]->l.multi.label;
231       uint32_t prediction = sch.predict(*ec[i], i+1, &oracle, 1, &i, "p");
232 
233       max_prediction = max(prediction, max_prediction);
234     }
235     float loss = 0.;
236     if (max_label > max_prediction)
237       loss = my_task_data->false_negative_cost / my_task_data->negative_weight;
238     else if (max_prediction > max_label)
239       loss = 1.;
240     sch.loss(loss);
241 
242     if (sch.output().good())
243       sch.output() << max_prediction;
244   }
245 }
246 
247 
248 namespace SequenceTask_DemoLDF {  // this is just to debug/show off how to do LDF
249   namespace CS=COST_SENSITIVE;
250   struct task_data {
251     example* ldf_examples;
252     size_t   num_actions;
253   };
254 
initialize(Search::search & sch,size_t & num_actions,po::variables_map & vm)255   void initialize(Search::search& sch, size_t& num_actions, po::variables_map& vm) {
256     CS::wclass default_wclass = { 0., 0, 0., 0. };
257 
258     example* ldf_examples = alloc_examples(sizeof(CS::label), num_actions);
259     for (size_t a=0; a<num_actions; a++) {
260       CS::label& lab = ldf_examples[a].l.cs;
261       CS::cs_label.default_label(&lab);
262       lab.costs.push_back(default_wclass);
263     }
264 
265     task_data* data = &calloc_or_die<task_data>();
266     data->ldf_examples = ldf_examples;
267     data->num_actions  = num_actions;
268 
269     sch.set_task_data<task_data>(data);
270     sch.set_options( Search::AUTO_CONDITION_FEATURES |    // automatically add history features to our examples, please
271                      Search::AUTO_HAMMING_LOSS       |    // please just use hamming loss on individual predictions -- we won't declare loss
272                      Search::IS_LDF                  );   // we generate ldf examples
273   }
274 
finish(Search::search & sch)275   void finish(Search::search& sch) {
276     task_data *data = sch.get_task_data<task_data>();
277     for (size_t a=0; a<data->num_actions; a++)
278       dealloc_example(CS::cs_label.delete_label, data->ldf_examples[a]);
279     free(data->ldf_examples);
280     free(data);
281   }
282 
283 
284   // this is totally bogus for the example -- you'd never actually do this!
my_update_example_indicies(Search::search & sch,bool audit,example * ec,uint32_t mult_amount,uint32_t plus_amount)285   void my_update_example_indicies(Search::search& sch, bool audit, example* ec, uint32_t mult_amount, uint32_t plus_amount) {
286     size_t ss = sch.get_stride_shift();
287     for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
288       for (feature* f = ec->atomics[*i].begin; f != ec->atomics[*i].end; ++f)
289         f->weight_index = (((f->weight_index>>ss) * mult_amount) + plus_amount)<<ss;
290     if (audit)
291       for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
292         if (ec->audit_features[*i].begin != ec->audit_features[*i].end)
293           for (audit_data *f = ec->audit_features[*i].begin; f != ec->audit_features[*i].end; ++f)
294             f->weight_index = (((f->weight_index>>ss) * mult_amount) + plus_amount)<<ss;
295   }
296 
run(Search::search & sch,vector<example * > & ec)297   void run(Search::search& sch, vector<example*>& ec) {
298     task_data *data = sch.get_task_data<task_data>();
299     for (ptag i=0; i<ec.size(); i++) {
300       for (size_t a=0; a<data->num_actions; a++) {
301         if (sch.predictNeedsExample()) { // we can skip this work if `predict` won't actually use the example data
302           VW::copy_example_data(false, &data->ldf_examples[a], ec[i]);  // copy but leave label alone!
303           // now, offset it appropriately for the action id
304           my_update_example_indicies(sch, true, &data->ldf_examples[a], 28904713, 4832917 * (uint32_t)a);
305         }
306 
307         // regardless of whether the example is needed or not, the class info is needed
308         CS::label& lab = data->ldf_examples[a].l.cs;
309         // need to tell search what the action id is, so that it can add history features correctly!
310         lab.costs[0].x = 0.;
311         lab.costs[0].class_index = (uint32_t)a+1;
312         lab.costs[0].partial_prediction = 0.;
313         lab.costs[0].wap_value = 0.;
314       }
315 
316       action oracle  = ec[i]->l.multi.label - 1;
317       action pred_id = Search::predictor(sch, i+1).set_input(data->ldf_examples, data->num_actions).set_oracle(oracle).set_condition_range(i, sch.get_history_length(), 'p').predict();
318       action prediction = pred_id + 1;  // or ldf_examples[pred_id]->ld.costs[0].weight_index
319 
320       if (sch.output().good())
321         sch.output() << prediction << ' ';
322     }
323   }
324 }
325