1 /*
  Copyright (c) by respective owners including Yahoo!, Microsoft, and
3   individual contributors. All rights reserved.  Released under a BSD (revised)
4   license as described in the file LICENSE.
5 */
6 #include "search_entityrelationtask.h"
7 #include "vw.h"
8 
9 #define R_NONE 10 // label for NONE relation
10 #define LABEL_SKIP 11 // label for SKIP
11 
12 namespace EntityRelationTask { Search::search_task task = { "entity_relation", run, initialize, finish, NULL, NULL };  }
13 
14 
15 namespace EntityRelationTask {
16   namespace CS = COST_SENSITIVE;
17 
18   void update_example_indicies(bool audit, example* ec, uint32_t mult_amount, uint32_t plus_amount);
19   //enum SearchOrder { EntityFirst, Mix, Skip };
20 
21   struct task_data {
22     float relation_none_cost;
23     float entity_cost;
24     float relation_cost;
25     float skip_cost;
26     bool constraints;
27     bool allow_skip;
28     v_array<uint32_t> y_allowed_entity;
29     v_array<uint32_t> y_allowed_relation;
30     int search_order;
31     example* ldf_entity;
32     example* ldf_relation;
33     //SearchOrder search_order;
34   };
35 
36 
initialize(Search::search & sch,size_t & num_actions,po::variables_map & vm)37   void initialize(Search::search& sch, size_t& num_actions, po::variables_map& vm) {
38     task_data * my_task_data = new task_data();
39     po::options_description sspan_opts("entity relation options");
40     sspan_opts.add_options()
41         ("relation_cost", po::value<float>(&(my_task_data->relation_cost))->default_value(1.0), "Relation Cost")
42         ("entity_cost", po::value<float>(&(my_task_data->entity_cost))->default_value(1.0), "Entity Cost")
43         ("constraints", "Use Constraints")
44         ("relation_none_cost", po::value<float>(&(my_task_data->relation_none_cost))->default_value(0.5), "None Relation Cost")
45         ("skip_cost", po::value<float>(&(my_task_data->skip_cost))->default_value(0.01f), "Skip Cost (only used when search_order = skip")
46         ("search_order", po::value<int>(&(my_task_data->search_order))->default_value(0), "Search Order 0: EntityFirst 1: Mix 2: Skip 3: EntityFirst(LDF)" );
47     sch.add_program_options(vm, sspan_opts);
48 
49     // setup entity and relation labels
50     // Entity label 1:E_Other 2:E_Peop 3:E_Org 4:E_Loc
51     // Relation label 5:R_Live_in 6:R_OrgBased_in 7:R_Located_in 8:R_Work_For 9:R_Kill 10:R_None
52     my_task_data->constraints = vm.count("constraints") > 0;
53 
54     for(int i=1; i<5; i++)
55       my_task_data->y_allowed_entity.push_back(i);
56 
57     for(int i=5; i<11; i++)
58       my_task_data->y_allowed_relation.push_back(i);
59 
60     my_task_data->allow_skip = false;
61 
62     if(my_task_data->search_order != 3 && my_task_data->search_order != 4 ) {
63       sch.set_options(0);
64     } else {
65       example* ldf_examples = alloc_examples(sizeof(CS::label), 10);
66       CS::wclass default_wclass = { 0., 0, 0., 0. };
67       for (size_t a=0; a<10; a++) {
68         ldf_examples[a].l.cs.costs.push_back(default_wclass);
69       }
70       my_task_data->ldf_entity = ldf_examples;
71       my_task_data->ldf_relation = ldf_examples+4;
72       sch.set_options(Search::IS_LDF);
73     }
74 
75     sch.set_num_learners(2);
76     if(my_task_data->search_order == 4)
77       sch.set_num_learners(3);
78     sch.set_task_data<task_data>(my_task_data);
79   }
80 
finish(Search::search & sch)81   void finish(Search::search& sch) {
82     task_data * my_task_data = sch.get_task_data<task_data>();
83     my_task_data->y_allowed_entity.delete_v();
84     my_task_data->y_allowed_relation.delete_v();
85     if(my_task_data->search_order == 3) {
86       for (size_t a=0; a<10; a++)
87         dealloc_example(CS::cs_label.delete_label, my_task_data->ldf_entity[a]);
88       free(my_task_data->ldf_entity);
89     }
90     delete my_task_data;
91   }    // if we had task data, we'd want to free it here
92 
check_constraints(int ent1_id,int ent2_id,int rel_id)93   bool check_constraints(int ent1_id, int ent2_id, int rel_id){
94     int valid_ent1_id [] = {2,3,4,2,2}; // encode the valid entity-relation combinations
95     int valid_ent2_id [] = {4,4,4,3,2};
96     if(rel_id - 5 == 5)
97       return true;
98     if(valid_ent1_id[rel_id-5] == ent1_id && valid_ent2_id[rel_id-5] == ent2_id)
99       return true;
100     return false;
101   }
102 
decode_tag(v_array<char> tag,char & type,int & id1,int & id2)103   void decode_tag(v_array<char> tag, char& type, int& id1, int& id2){
104     string s1;
105     string s2;
106     type = tag[0];
107     uint32_t idx = 2;
108     while(idx < tag.size() && tag[idx] != '_' && tag[idx] != '\0'){
109       s1.push_back(tag[idx]);
110       idx++;
111     }
112     id1 = atoi(s1.c_str());
113     idx++;
114     assert(type == 'R');
115     while(idx < tag.size() && tag[idx] != '_' && tag[idx] != '\0'){
116       s2.push_back(tag[idx]);
117       idx++;
118     }
119     id2 = atoi(s2.c_str());
120   }
121 
predict_entity(Search::search & sch,example * ex,v_array<size_t> & predictions,ptag my_tag,bool isLdf=false)122   size_t predict_entity(Search::search&sch, example* ex, v_array<size_t>& predictions, ptag my_tag, bool isLdf=false){
123 
124     task_data* my_task_data = sch.get_task_data<task_data>();
125     size_t prediction;
126     if(my_task_data->allow_skip){
127       v_array<uint32_t> star_labels = v_init<uint32_t>();
128       star_labels.push_back(ex->l.multi.label);
129       star_labels.push_back(LABEL_SKIP);
130       my_task_data->y_allowed_entity.push_back(LABEL_SKIP);
131       prediction = Search::predictor(sch, my_tag).set_input(*ex).set_oracle(star_labels).set_allowed(my_task_data->y_allowed_entity).set_learner_id(1).predict();
132       my_task_data->y_allowed_entity.pop();
133     } else {
134       if(isLdf) {
135         for(size_t a=0; a<4; a++){
136           VW::copy_example_data(false, &my_task_data->ldf_entity[a], ex);
137           update_example_indicies(true, &my_task_data->ldf_entity[a], 28904713, 4832917 * (uint32_t)(a+1));
138           CS::label& lab = my_task_data->ldf_entity[a].l.cs;
139           lab.costs[0].x = 0.f;
140           lab.costs[0].class_index = (uint32_t)a;
141           lab.costs[0].partial_prediction = 0.f;
142           lab.costs[0].wap_value = 0.f;
143         }
144         prediction = Search::predictor(sch, my_tag).set_input(my_task_data->ldf_entity, 4).set_oracle(ex->l.multi.label-1).set_learner_id(1).predict() + 1;
145       } else {
146         prediction = Search::predictor(sch, my_tag).set_input(*ex).set_oracle(ex->l.multi.label).set_allowed(my_task_data->y_allowed_entity).set_learner_id(0).predict();
147       }
148     }
149 
150     // record loss
151     float loss = 0.0;
152     if(prediction == LABEL_SKIP){
153       loss = my_task_data->skip_cost;
154     } else if(prediction !=  ex->l.multi.label)
155       loss= my_task_data->entity_cost;
156     sch.loss(loss);
157     return prediction;
158   }
predict_relation(Search::search & sch,example * ex,v_array<size_t> & predictions,ptag my_tag,bool isLdf=false)159   size_t predict_relation(Search::search&sch, example* ex, v_array<size_t>& predictions, ptag my_tag, bool isLdf=false){
160     char type;
161     int id1, id2;
162     task_data* my_task_data = sch.get_task_data<task_data>();
163     uint32_t hist[2];
164     decode_tag(ex->tag, type, id1, id2);
165     v_array<uint32_t> constrained_relation_labels = v_init<uint32_t>();
166     if(my_task_data->constraints && predictions[id1]!=0 &&predictions[id2]!=0){
167       hist[0] = (uint32_t)predictions[id1];
168       hist[1] = (uint32_t)predictions[id2];
169     } else {
170       hist[0] = 0;
171     }
172     for(size_t j=0; j< my_task_data->y_allowed_relation.size(); j++){
173       if(!my_task_data->constraints || hist[0] == 0  || check_constraints(hist[0], hist[1], my_task_data->y_allowed_relation[j])){
174         constrained_relation_labels.push_back(my_task_data->y_allowed_relation[j]);
175       }
176     }
177 
178     size_t prediction;
179     if(my_task_data->allow_skip){
180       v_array<uint32_t> star_labels = v_init<uint32_t>();
181       star_labels.push_back(ex->l.multi.label);
182       star_labels.push_back(LABEL_SKIP);
183       constrained_relation_labels.push_back(LABEL_SKIP);
184       prediction = Search::predictor(sch, my_tag).set_input(*ex).set_oracle(star_labels).set_allowed(constrained_relation_labels).set_learner_id(2).add_condition(id1, 'a').add_condition(id2, 'b').predict();
185       constrained_relation_labels.pop();
186     } else {
187       if(isLdf) {
188         int correct_label = 0; // if correct label is not in the set, use the first one
189         for(size_t a=0; a<constrained_relation_labels.size(); a++){
190           VW::copy_example_data(false, &my_task_data->ldf_relation[a], ex);
191           update_example_indicies(true, &my_task_data->ldf_relation[a], 28904713, 4832917* (uint32_t)(constrained_relation_labels[a]));
192           CS::label& lab = my_task_data->ldf_relation[a].l.cs;
193           lab.costs[0].x = 0.f;
194           lab.costs[0].class_index = (uint32_t)constrained_relation_labels[a];
195           lab.costs[0].partial_prediction = 0.f;
196           lab.costs[0].wap_value = 0.f;
197           if(constrained_relation_labels[a] == ex->l.multi.label){
198             correct_label = (int)a;
199           }
200         }
201         size_t pred_pos = Search::predictor(sch, my_tag).set_input(my_task_data->ldf_relation, constrained_relation_labels.size()).set_oracle(correct_label).set_learner_id(2).predict();
202         prediction = constrained_relation_labels[pred_pos];
203       } else {
204         prediction = Search::predictor(sch, my_tag).set_input(*ex).set_oracle(ex->l.multi.label).set_allowed(constrained_relation_labels).set_learner_id(1).predict();
205       }
206     }
207 
208     float loss = 0.0;
209     if(prediction == LABEL_SKIP){
210       loss = my_task_data->skip_cost;
211     } else if(prediction !=  ex->l.multi.label) {
212       if(ex->l.multi.label == R_NONE){
213         loss = my_task_data->relation_none_cost;
214       } else {
215         loss= my_task_data->relation_cost;
216       }
217     }
218     sch.loss(loss);
219     return prediction;
220   }
221 
entity_first_decoding(Search::search & sch,vector<example * > ec,v_array<size_t> & predictions,bool isLdf=false)222   void entity_first_decoding(Search::search& sch, vector<example*> ec, v_array<size_t>& predictions, bool isLdf=false) {
223     // ec.size = #entity + #entity*(#entity-1)/2
224     size_t n_ent = (size_t)(sqrt(ec.size()*8+1)-1)/2;
225     // Do entity recognition first
226     for (size_t i=0; i<ec.size(); i++) {
227       if(i< n_ent)
228         predictions[i] = predict_entity(sch, ec[i], predictions, (ptag)i, isLdf);
229       else
230         predictions[i] = predict_relation(sch, ec[i], predictions, (ptag)i, isLdf);
231     }
232   }
233 
er_mixed_decoding(Search::search & sch,vector<example * > ec,v_array<size_t> & predictions)234   void er_mixed_decoding(Search::search& sch, vector<example*> ec, v_array<size_t>& predictions) {
235     // ec.size = #entity + #entity*(#entity-1)/2
236     size_t n_ent = (size_t)(sqrt(ec.size()*8+1)-1)/2;
237     for(size_t t=0; t<ec.size(); t++){
238       // Do entity recognition first
239       size_t count = 0;
240       for (size_t i=0; i<n_ent; i++) {
241         if(count ==t){
242           predictions[i] = predict_entity(sch, ec[i], predictions, (ptag)i);
243           break;
244         }
245         count++;
246         for(size_t j=0; j<i; j++) {
247           if(count ==t){
248             uint32_t rel_index = (uint32_t) (n_ent + (2*n_ent-j-1)*j/2 + i-j-1);
249             predictions[rel_index] = predict_relation(sch, ec[rel_index], predictions, rel_index);
250             break;
251           }
252           count++;
253         }
254       }
255     }
256   }
257 
er_allow_skip_decoding(Search::search & sch,vector<example * > ec,v_array<size_t> & predictions)258   void er_allow_skip_decoding(Search::search& sch, vector<example*> ec, v_array<size_t>& predictions) {
259     task_data* my_task_data = sch.get_task_data<task_data>();
260     // ec.size = #entity + #entity*(#entity-1)/2
261     size_t n_ent = (size_t)(sqrt(ec.size()*8+1)-1)/2;
262 
263     bool must_predict = false;
264     size_t n_predicts = 0;
265     size_t p_n_predicts = 0;
266     my_task_data->allow_skip = true;
267 
268     // loop until all the entity and relation types are predicted
269     for(size_t t=0; ; t++){
270       uint32_t i = (uint32_t) t % ec.size();
271       if(n_predicts == ec.size())
272         break;
273 
274       if(predictions[i] == 0){
275         if(must_predict) {
276           my_task_data->allow_skip = false;
277         }
278         size_t prediction = 0;
279         if(i < n_ent) {// do entity recognition
280           prediction = predict_entity(sch, ec[i], predictions, i);
281         } else { // do relation recognition
282           prediction = predict_relation(sch, ec[i], predictions, i);
283         }
284 
285         if(prediction != LABEL_SKIP){
286           predictions[i] = prediction;
287           n_predicts++;
288         }
289 
290         if(must_predict) {
291           my_task_data->allow_skip = true;
292           must_predict = false;
293         }
294       }
295 
296       if(i == ec.size()-1) {
297         if(n_predicts == p_n_predicts){
298           must_predict = true;
299         }
300         p_n_predicts = n_predicts;
301       }
302     }
303   }
304 
run(Search::search & sch,vector<example * > & ec)305   void run(Search::search& sch, vector<example*>& ec) {
306     task_data* my_task_data = sch.get_task_data<task_data>();
307 
308     v_array<size_t> predictions = v_init<size_t>();
309     for(size_t i=0; i<ec.size(); i++){
310       predictions.push_back(0);
311     }
312 
313     switch(my_task_data->search_order) {
314       case 0:
315         entity_first_decoding(sch, ec, predictions, false);
316         break;
317       case 1:
318         er_mixed_decoding(sch, ec, predictions);
319         break;
320       case 2:
321         er_allow_skip_decoding(sch, ec, predictions);
322         break;
323       case 3:
324         entity_first_decoding(sch, ec, predictions, true); //LDF = true
325         break;
326       default:
327         cerr << "search order " << my_task_data->search_order << "is undefined." << endl;
328     }
329 
330 
331     for(size_t i=0; i<ec.size(); i++){
332       if (sch.output().good())
333         sch.output() << predictions[i] << ' ';
334     }
335   }
336   // this is totally bogus for the example -- you'd never actually do this!
update_example_indicies(bool audit,example * ec,uint32_t mult_amount,uint32_t plus_amount)337   void update_example_indicies(bool audit, example* ec, uint32_t mult_amount, uint32_t plus_amount) {
338     for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
339       for (feature* f = ec->atomics[*i].begin; f != ec->atomics[*i].end; ++f)
340         f->weight_index = ((f->weight_index * mult_amount) + plus_amount);
341     if (audit)
342       for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
343         if (ec->audit_features[*i].begin != ec->audit_features[*i].end)
344           for (audit_data *f = ec->audit_features[*i].begin; f != ec->audit_features[*i].end; ++f)
345             f->weight_index = ((f->weight_index * mult_amount) + plus_amount);
346   }
347 }
348