1 /* 2 CoPyright (c) by respective owners including Yahoo!, Microsoft, and 3 individual contributors. All rights reserved. Released under a BSD (revised) 4 license as described in the file LICENSE. 5 */ 6 #include "search_entityrelationtask.h" 7 #include "vw.h" 8 9 #define R_NONE 10 // label for NONE relation 10 #define LABEL_SKIP 11 // label for SKIP 11 12 namespace EntityRelationTask { Search::search_task task = { "entity_relation", run, initialize, finish, NULL, NULL }; } 13 14 15 namespace EntityRelationTask { 16 namespace CS = COST_SENSITIVE; 17 18 void update_example_indicies(bool audit, example* ec, uint32_t mult_amount, uint32_t plus_amount); 19 //enum SearchOrder { EntityFirst, Mix, Skip }; 20 21 struct task_data { 22 float relation_none_cost; 23 float entity_cost; 24 float relation_cost; 25 float skip_cost; 26 bool constraints; 27 bool allow_skip; 28 v_array<uint32_t> y_allowed_entity; 29 v_array<uint32_t> y_allowed_relation; 30 int search_order; 31 example* ldf_entity; 32 example* ldf_relation; 33 //SearchOrder search_order; 34 }; 35 36 initialize(Search::search & sch,size_t & num_actions,po::variables_map & vm)37 void initialize(Search::search& sch, size_t& num_actions, po::variables_map& vm) { 38 task_data * my_task_data = new task_data(); 39 po::options_description sspan_opts("entity relation options"); 40 sspan_opts.add_options() 41 ("relation_cost", po::value<float>(&(my_task_data->relation_cost))->default_value(1.0), "Relation Cost") 42 ("entity_cost", po::value<float>(&(my_task_data->entity_cost))->default_value(1.0), "Entity Cost") 43 ("constraints", "Use Constraints") 44 ("relation_none_cost", po::value<float>(&(my_task_data->relation_none_cost))->default_value(0.5), "None Relation Cost") 45 ("skip_cost", po::value<float>(&(my_task_data->skip_cost))->default_value(0.01f), "Skip Cost (only used when search_order = skip") 46 ("search_order", po::value<int>(&(my_task_data->search_order))->default_value(0), "Search Order 0: EntityFirst 1: Mix 2: Skip 3: EntityFirst(LDF)" ); 47 sch.add_program_options(vm, sspan_opts); 48 49 // setup entity and relation labels 50 // Entity label 1:E_Other 2:E_Peop 3:E_Org 4:E_Loc 51 // Relation label 5:R_Live_in 6:R_OrgBased_in 7:R_Located_in 8:R_Work_For 9:R_Kill 10:R_None 52 my_task_data->constraints = vm.count("constraints") > 0; 53 54 for(int i=1; i<5; i++) 55 my_task_data->y_allowed_entity.push_back(i); 56 57 for(int i=5; i<11; i++) 58 my_task_data->y_allowed_relation.push_back(i); 59 60 my_task_data->allow_skip = false; 61 62 if(my_task_data->search_order != 3 && my_task_data->search_order != 4 ) { 63 sch.set_options(0); 64 } else { 65 example* ldf_examples = alloc_examples(sizeof(CS::label), 10); 66 CS::wclass default_wclass = { 0., 0, 0., 0. }; 67 for (size_t a=0; a<10; a++) { 68 ldf_examples[a].l.cs.costs.push_back(default_wclass); 69 } 70 my_task_data->ldf_entity = ldf_examples; 71 my_task_data->ldf_relation = ldf_examples+4; 72 sch.set_options(Search::IS_LDF); 73 } 74 75 sch.set_num_learners(2); 76 if(my_task_data->search_order == 4) 77 sch.set_num_learners(3); 78 sch.set_task_data<task_data>(my_task_data); 79 } 80 finish(Search::search & sch)81 void finish(Search::search& sch) { 82 task_data * my_task_data = sch.get_task_data<task_data>(); 83 my_task_data->y_allowed_entity.delete_v(); 84 my_task_data->y_allowed_relation.delete_v(); 85 if(my_task_data->search_order == 3) { 86 for (size_t a=0; a<10; a++) 87 dealloc_example(CS::cs_label.delete_label, my_task_data->ldf_entity[a]); 88 free(my_task_data->ldf_entity); 89 } 90 delete my_task_data; 91 } // if we had task data, we'd want to free it here 92 check_constraints(int ent1_id,int ent2_id,int rel_id)93 bool check_constraints(int ent1_id, int ent2_id, int rel_id){ 94 int valid_ent1_id [] = {2,3,4,2,2}; // encode the valid entity-relation combinations 95 int valid_ent2_id [] = {4,4,4,3,2}; 96 if(rel_id - 5 == 5) 97 return true; 98 if(valid_ent1_id[rel_id-5] == ent1_id && valid_ent2_id[rel_id-5] == ent2_id) 99 return true; 100 return false; 101 } 102 decode_tag(v_array<char> tag,char & type,int & id1,int & id2)103 void decode_tag(v_array<char> tag, char& type, int& id1, int& id2){ 104 string s1; 105 string s2; 106 type = tag[0]; 107 uint32_t idx = 2; 108 while(idx < tag.size() && tag[idx] != '_' && tag[idx] != '\0'){ 109 s1.push_back(tag[idx]); 110 idx++; 111 } 112 id1 = atoi(s1.c_str()); 113 idx++; 114 assert(type == 'R'); 115 while(idx < tag.size() && tag[idx] != '_' && tag[idx] != '\0'){ 116 s2.push_back(tag[idx]); 117 idx++; 118 } 119 id2 = atoi(s2.c_str()); 120 } 121 predict_entity(Search::search & sch,example * ex,v_array<size_t> & predictions,ptag my_tag,bool isLdf=false)122 size_t predict_entity(Search::search&sch, example* ex, v_array<size_t>& predictions, ptag my_tag, bool isLdf=false){ 123 124 task_data* my_task_data = sch.get_task_data<task_data>(); 125 size_t prediction; 126 if(my_task_data->allow_skip){ 127 v_array<uint32_t> star_labels = v_init<uint32_t>(); 128 star_labels.push_back(ex->l.multi.label); 129 star_labels.push_back(LABEL_SKIP); 130 my_task_data->y_allowed_entity.push_back(LABEL_SKIP); 131 prediction = Search::predictor(sch, my_tag).set_input(*ex).set_oracle(star_labels).set_allowed(my_task_data->y_allowed_entity).set_learner_id(1).predict(); 132 my_task_data->y_allowed_entity.pop(); 133 } else { 134 if(isLdf) { 135 for(size_t a=0; a<4; a++){ 136 VW::copy_example_data(false, &my_task_data->ldf_entity[a], ex); 137 update_example_indicies(true, &my_task_data->ldf_entity[a], 28904713, 4832917 * (uint32_t)(a+1)); 138 CS::label& lab = my_task_data->ldf_entity[a].l.cs; 139 lab.costs[0].x = 0.f; 140 lab.costs[0].class_index = (uint32_t)a; 141 lab.costs[0].partial_prediction = 0.f; 142 lab.costs[0].wap_value = 0.f; 143 } 144 prediction = Search::predictor(sch, my_tag).set_input(my_task_data->ldf_entity, 4).set_oracle(ex->l.multi.label-1).set_learner_id(1).predict() + 1; 145 } else { 146 prediction = Search::predictor(sch, my_tag).set_input(*ex).set_oracle(ex->l.multi.label).set_allowed(my_task_data->y_allowed_entity).set_learner_id(0).predict(); 147 } 148 } 149 150 // record loss 151 float loss = 0.0; 152 if(prediction == LABEL_SKIP){ 153 loss = my_task_data->skip_cost; 154 } else if(prediction != ex->l.multi.label) 155 loss= my_task_data->entity_cost; 156 sch.loss(loss); 157 return prediction; 158 } predict_relation(Search::search & sch,example * ex,v_array<size_t> & predictions,ptag my_tag,bool isLdf=false)159 size_t predict_relation(Search::search&sch, example* ex, v_array<size_t>& predictions, ptag my_tag, bool isLdf=false){ 160 char type; 161 int id1, id2; 162 task_data* my_task_data = sch.get_task_data<task_data>(); 163 uint32_t hist[2]; 164 decode_tag(ex->tag, type, id1, id2); 165 v_array<uint32_t> constrained_relation_labels = v_init<uint32_t>(); 166 if(my_task_data->constraints && predictions[id1]!=0 &&predictions[id2]!=0){ 167 hist[0] = (uint32_t)predictions[id1]; 168 hist[1] = (uint32_t)predictions[id2]; 169 } else { 170 hist[0] = 0; 171 } 172 for(size_t j=0; j< my_task_data->y_allowed_relation.size(); j++){ 173 if(!my_task_data->constraints || hist[0] == 0 || check_constraints(hist[0], hist[1], my_task_data->y_allowed_relation[j])){ 174 constrained_relation_labels.push_back(my_task_data->y_allowed_relation[j]); 175 } 176 } 177 178 size_t prediction; 179 if(my_task_data->allow_skip){ 180 v_array<uint32_t> star_labels = v_init<uint32_t>(); 181 star_labels.push_back(ex->l.multi.label); 182 star_labels.push_back(LABEL_SKIP); 183 constrained_relation_labels.push_back(LABEL_SKIP); 184 prediction = Search::predictor(sch, my_tag).set_input(*ex).set_oracle(star_labels).set_allowed(constrained_relation_labels).set_learner_id(2).add_condition(id1, 'a').add_condition(id2, 'b').predict(); 185 constrained_relation_labels.pop(); 186 } else { 187 if(isLdf) { 188 int correct_label = 0; // if correct label is not in the set, use the first one 189 for(size_t a=0; a<constrained_relation_labels.size(); a++){ 190 VW::copy_example_data(false, &my_task_data->ldf_relation[a], ex); 191 update_example_indicies(true, &my_task_data->ldf_relation[a], 28904713, 4832917* (uint32_t)(constrained_relation_labels[a])); 192 CS::label& lab = my_task_data->ldf_relation[a].l.cs; 193 lab.costs[0].x = 0.f; 194 lab.costs[0].class_index = (uint32_t)constrained_relation_labels[a]; 195 lab.costs[0].partial_prediction = 0.f; 196 lab.costs[0].wap_value = 0.f; 197 if(constrained_relation_labels[a] == ex->l.multi.label){ 198 correct_label = (int)a; 199 } 200 } 201 size_t pred_pos = Search::predictor(sch, my_tag).set_input(my_task_data->ldf_relation, constrained_relation_labels.size()).set_oracle(correct_label).set_learner_id(2).predict(); 202 prediction = constrained_relation_labels[pred_pos]; 203 } else { 204 prediction = Search::predictor(sch, my_tag).set_input(*ex).set_oracle(ex->l.multi.label).set_allowed(constrained_relation_labels).set_learner_id(1).predict(); 205 } 206 } 207 208 float loss = 0.0; 209 if(prediction == LABEL_SKIP){ 210 loss = my_task_data->skip_cost; 211 } else if(prediction != ex->l.multi.label) { 212 if(ex->l.multi.label == R_NONE){ 213 loss = my_task_data->relation_none_cost; 214 } else { 215 loss= my_task_data->relation_cost; 216 } 217 } 218 sch.loss(loss); 219 return prediction; 220 } 221 entity_first_decoding(Search::search & sch,vector<example * > ec,v_array<size_t> & predictions,bool isLdf=false)222 void entity_first_decoding(Search::search& sch, vector<example*> ec, v_array<size_t>& predictions, bool isLdf=false) { 223 // ec.size = #entity + #entity*(#entity-1)/2 224 size_t n_ent = (size_t)(sqrt(ec.size()*8+1)-1)/2; 225 // Do entity recognition first 226 for (size_t i=0; i<ec.size(); i++) { 227 if(i< n_ent) 228 predictions[i] = predict_entity(sch, ec[i], predictions, (ptag)i, isLdf); 229 else 230 predictions[i] = predict_relation(sch, ec[i], predictions, (ptag)i, isLdf); 231 } 232 } 233 er_mixed_decoding(Search::search & sch,vector<example * > ec,v_array<size_t> & predictions)234 void er_mixed_decoding(Search::search& sch, vector<example*> ec, v_array<size_t>& predictions) { 235 // ec.size = #entity + #entity*(#entity-1)/2 236 size_t n_ent = (size_t)(sqrt(ec.size()*8+1)-1)/2; 237 for(size_t t=0; t<ec.size(); t++){ 238 // Do entity recognition first 239 size_t count = 0; 240 for (size_t i=0; i<n_ent; i++) { 241 if(count ==t){ 242 predictions[i] = predict_entity(sch, ec[i], predictions, (ptag)i); 243 break; 244 } 245 count++; 246 for(size_t j=0; j<i; j++) { 247 if(count ==t){ 248 uint32_t rel_index = (uint32_t) (n_ent + (2*n_ent-j-1)*j/2 + i-j-1); 249 predictions[rel_index] = predict_relation(sch, ec[rel_index], predictions, rel_index); 250 break; 251 } 252 count++; 253 } 254 } 255 } 256 } 257 er_allow_skip_decoding(Search::search & sch,vector<example * > ec,v_array<size_t> & predictions)258 void er_allow_skip_decoding(Search::search& sch, vector<example*> ec, v_array<size_t>& predictions) { 259 task_data* my_task_data = sch.get_task_data<task_data>(); 260 // ec.size = #entity + #entity*(#entity-1)/2 261 size_t n_ent = (size_t)(sqrt(ec.size()*8+1)-1)/2; 262 263 bool must_predict = false; 264 size_t n_predicts = 0; 265 size_t p_n_predicts = 0; 266 my_task_data->allow_skip = true; 267 268 // loop until all the entity and relation types are predicted 269 for(size_t t=0; ; t++){ 270 uint32_t i = (uint32_t) t % ec.size(); 271 if(n_predicts == ec.size()) 272 break; 273 274 if(predictions[i] == 0){ 275 if(must_predict) { 276 my_task_data->allow_skip = false; 277 } 278 size_t prediction = 0; 279 if(i < n_ent) {// do entity recognition 280 prediction = predict_entity(sch, ec[i], predictions, i); 281 } else { // do relation recognition 282 prediction = predict_relation(sch, ec[i], predictions, i); 283 } 284 285 if(prediction != LABEL_SKIP){ 286 predictions[i] = prediction; 287 n_predicts++; 288 } 289 290 if(must_predict) { 291 my_task_data->allow_skip = true; 292 must_predict = false; 293 } 294 } 295 296 if(i == ec.size()-1) { 297 if(n_predicts == p_n_predicts){ 298 must_predict = true; 299 } 300 p_n_predicts = n_predicts; 301 } 302 } 303 } 304 run(Search::search & sch,vector<example * > & ec)305 void run(Search::search& sch, vector<example*>& ec) { 306 task_data* my_task_data = sch.get_task_data<task_data>(); 307 308 v_array<size_t> predictions = v_init<size_t>(); 309 for(size_t i=0; i<ec.size(); i++){ 310 predictions.push_back(0); 311 } 312 313 switch(my_task_data->search_order) { 314 case 0: 315 entity_first_decoding(sch, ec, predictions, false); 316 break; 317 case 1: 318 er_mixed_decoding(sch, ec, predictions); 319 break; 320 case 2: 321 er_allow_skip_decoding(sch, ec, predictions); 322 break; 323 case 3: 324 entity_first_decoding(sch, ec, predictions, true); //LDF = true 325 break; 326 default: 327 cerr << "search order " << my_task_data->search_order << "is undefined." << endl; 328 } 329 330 331 for(size_t i=0; i<ec.size(); i++){ 332 if (sch.output().good()) 333 sch.output() << predictions[i] << ' '; 334 } 335 } 336 // this is totally bogus for the example -- you'd never actually do this! update_example_indicies(bool audit,example * ec,uint32_t mult_amount,uint32_t plus_amount)337 void update_example_indicies(bool audit, example* ec, uint32_t mult_amount, uint32_t plus_amount) { 338 for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++) 339 for (feature* f = ec->atomics[*i].begin; f != ec->atomics[*i].end; ++f) 340 f->weight_index = ((f->weight_index * mult_amount) + plus_amount); 341 if (audit) 342 for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++) 343 if (ec->audit_features[*i].begin != ec->audit_features[*i].end) 344 for (audit_data *f = ec->audit_features[*i].begin; f != ec->audit_features[*i].end; ++f) 345 f->weight_index = ((f->weight_index * mult_amount) + plus_amount); 346 } 347 } 348