1 /* 2 Copyright (c) by respective owners including Yahoo!, Microsoft, and 3 individual contributors. All rights reserved. Released under a BSD (revised) 4 license as described in the file LICENSE. 5 */ 6 #include <float.h> 7 #include <string.h> 8 #include <math.h> 9 #include "vw.h" 10 #include "rand48.h" 11 #include "reductions.h" 12 #include "gd.h" // for GD::foreach_feature 13 #include "search_sequencetask.h" 14 #include "search_multiclasstask.h" 15 #include "search_dep_parser.h" 16 #include "search_entityrelationtask.h" 17 #include "search_hooktask.h" 18 #include "search_graph.h" 19 #include "csoaa.h" 20 #include "beam.h" 21 22 using namespace LEARNER; 23 using namespace std; 24 namespace CS = COST_SENSITIVE; 25 namespace MC = MULTICLASS; 26 27 namespace Search { 28 search_task* all_tasks[] = { &SequenceTask::task, 29 &SequenceSpanTask::task, 30 &ArgmaxTask::task, 31 &SequenceTask_DemoLDF::task, 32 &MulticlassTask::task, 33 &DepParserTask::task, 34 &EntityRelationTask::task, 35 &HookTask::task, 36 &GraphTask::task, 37 NULL }; // must NULL terminate! 38 39 const bool PRINT_UPDATE_EVERY_EXAMPLE =0; 40 const bool PRINT_UPDATE_EVERY_PASS =0; 41 const bool PRINT_CLOCK_TIME =0; 42 43 string neighbor_feature_space("neighbor"); 44 string condition_feature_space("search_condition"); 45 46 uint32_t AUTO_CONDITION_FEATURES = 1, AUTO_HAMMING_LOSS = 2, EXAMPLES_DONT_CHANGE = 4, IS_LDF = 8; 47 enum SearchState { INITIALIZE, INIT_TEST, INIT_TRAIN, LEARN, GET_TRUTH_STRING }; 48 enum RollMethod { POLICY, ORACLE, MIX_PER_STATE, MIX_PER_ROLL, NO_ROLLOUT }; 49 50 // a data structure to hold conditioning information 51 struct prediction { 52 ptag me; // the id of the current prediction (the one being memoized) 53 size_t cnt; // how many variables are we conditioning on? 54 ptag* tags; // which variables are they? 55 action* acts; // and which actions were taken at each? 
56 uint32_t hash; // a hash of the above 57 }; 58 59 // parameters for auto-conditioning 60 struct auto_condition_settings { 61 size_t max_bias_ngram_length; // add a "bias" feature for each ngram up to and including this length. eg., if it's 1, then you get a single feature for each conditional 62 size_t max_quad_ngram_length; // add bias *times* input features for each ngram up to and including this length 63 float feature_value; // how much weight should the conditional features get? 64 }; 65 66 typedef v_array<action> action_prefix; 67 68 struct search_private { 69 vw* all; 70 71 bool auto_condition_features; // do you want us to automatically add conditioning features? 72 bool auto_hamming_loss; // if you're just optimizing hamming loss, we can do it for you! 73 bool examples_dont_change; // set to true if you don't do any internal example munging 74 bool is_ldf; // user declared ldf 75 76 v_array<int32_t> neighbor_features; // ugly encoding of neighbor feature requirements 77 auto_condition_settings acset; // settings for auto-conditioning 78 size_t history_length; // value of --search_history_length, used by some tasks, default 1 79 80 size_t A; // total number of actions, [1..A]; 0 means ldf 81 size_t num_learners; // total number of learners; 82 bool cb_learner; // do contextual bandit learning on action (was "! 
rollout_all_actions" which was confusing) 83 SearchState state; // current state of learning 84 size_t learn_learner_id; // we allow user to use different learners for different states 85 int mix_per_roll_policy; // for MIX_PER_ROLL, we need to choose a policy to use; this is where it's stored (-2 means "not selected yet") 86 bool no_caching; // turn off caching 87 size_t rollout_num_steps; // how many calls of "loss" before we stop really predicting on rollouts and switch to oracle (0 means "infinite") 88 bool (*label_is_test)(void*); // tell me if the label data from an example is test 89 90 size_t t; // current search step 91 size_t T; // length of root trajectory 92 v_array<example> learn_ec_copy;// copy of example(s) at learn_t 93 example* learn_ec_ref; // reference to example at learn_t, when there's no example munging 94 size_t learn_ec_ref_cnt; // how many are there (for LDF mode only; otherwise 1) 95 v_array<ptag> learn_condition_on; // a copy of the tags used for conditioning at the training position 96 v_array<action> learn_condition_on_act;// the actions taken 97 v_array<char> learn_condition_on_names;// the names of the actions 98 v_array<action> learn_allowed_actions; // which actions were allowed at training time? 99 v_array<action> ptag_to_action;// tag to action mapping for conditioning 100 vector<action> test_action_sequence; // if test-mode was run, what was the corresponding action sequence; it's a vector cuz we might expose it to the library 101 action learn_oracle_action; // store an oracle action for debugging purposes 102 103 polylabel* allowed_actions_cache; 104 105 size_t loss_declared_cnt; // how many times did run declare any loss (implicitly or explicitly)? 106 v_array<action> train_trajectory; // the training trajectory 107 v_array<action> current_trajectory; // the current trajectory; only used in beam search mode 108 size_t learn_t; // what time step are we learning on? 109 size_t learn_a_idx; // what action index are we trying? 
110 bool done_with_all_actions; // set to true when there are no more learn_a_idx to go 111 112 float test_loss; // loss incurred when run INIT_TEST 113 float learn_loss; // loss incurred when run LEARN 114 float train_loss; // loss incurred when run INIT_TRAIN 115 116 bool last_example_was_newline; // used so we know when a block of examples has passed 117 bool hit_new_pass; // have we hit a new pass? 118 119 // if we're printing to stderr we need to remember if we've printed the header yet 120 // (i.e., we do this if we're driving) 121 bool printed_output_header; 122 123 // various strings for different search states 124 bool should_produce_string; 125 stringstream *pred_string; 126 stringstream *truth_string; 127 stringstream *bad_string_stream; 128 129 // parameters controlling interpolation 130 float beta; // interpolation rate 131 float alpha; // parameter used to adapt beta for dagger (see above comment), should be in (0,1) 132 133 RollMethod rollout_method; // 0=policy, 1=oracle, 2=mix_per_state, 3=mix_per_roll 134 RollMethod rollin_method; 135 float subsample_timesteps; // train at every time step or just a (random) subset? 136 137 bool allow_current_policy; // should the current policy be used for training? true for dagger 138 bool adaptive_beta; // used to implement dagger-like algorithms. if true, beta = 1-(1-alpha)^n after n updates, and policy is mixed with oracle as \pi' = (1-beta)\pi^* + beta \pi 139 size_t passes_per_policy; // if we're not in dagger-mode, then we need to know how many passes to train a policy 140 141 uint32_t current_policy; // what policy are we training right now? 
142 143 // various statistics for reporting 144 size_t num_features; 145 uint32_t total_number_of_policies; 146 size_t read_example_last_id; 147 size_t passes_since_new_policy; 148 size_t read_example_last_pass; 149 size_t total_examples_generated; 150 size_t total_predictions_made; 151 size_t total_cache_hits; 152 153 vector<example*> ec_seq; // the collected examples 154 v_hashmap<unsigned char*, action> cache_hash_map; 155 156 // for foreach_feature temporary storage for conditioning 157 uint32_t dat_new_feature_idx; 158 example* dat_new_feature_ec; 159 stringstream dat_new_feature_audit_ss; 160 size_t dat_new_feature_namespace; 161 string* dat_new_feature_feature_space; 162 float dat_new_feature_value; 163 164 // to reduce memory allocation 165 string rawOutputString; 166 stringstream* rawOutputStringStream; 167 CS::label ldf_test_label; 168 v_array<action> condition_on_actions; 169 v_array< pair<size_t,size_t> > timesteps; 170 v_array<float> learn_losses; 171 172 LEARNER::base_learner* base_learner; 173 clock_t start_clock_time; 174 175 example*empty_example; 176 177 Beam::beam< action_prefix > *beam; 178 size_t kbest; // size of kbest list; 1 just means 1best 179 float beam_initial_cost; // when we're doing a subsequent run, how much do we initially pay? 180 action_prefix beam_actions; // on non-initial beam runs, what prefix of actions should we take? 181 float beam_total_cost; 182 183 search_task* task; // your task! 
184 }; 185 186 string audit_feature_space("conditional"); 187 uint32_t conditional_constant = 8290743; 188 random_policy(search_private & priv,bool allow_current,bool allow_optimal,bool advance_prng=true)189 int random_policy(search_private& priv, bool allow_current, bool allow_optimal, bool advance_prng=true) { 190 if (priv.beta >= 1) { 191 if (allow_current) return (int)priv.current_policy; 192 if (priv.current_policy > 0) return (((int)priv.current_policy)-1); 193 if (allow_optimal) return -1; 194 std::cerr << "internal error (bug): no valid policies to choose from! defaulting to current" << std::endl; 195 return (int)priv.current_policy; 196 } 197 198 int num_valid_policies = (int)priv.current_policy + allow_optimal + allow_current; 199 int pid = -1; 200 201 if (num_valid_policies == 0) { 202 std::cerr << "internal error (bug): no valid policies to choose from! defaulting to current" << std::endl; 203 return (int)priv.current_policy; 204 } else if (num_valid_policies == 1) 205 pid = 0; 206 else if (num_valid_policies == 2) 207 pid = (advance_prng ? frand48() : frand48_noadvance()) >= priv.beta; 208 else { 209 // SPEEDUP this up in the case that beta is small! 210 float r = (advance_prng ? 
frand48() : frand48_noadvance()); 211 pid = 0; 212 213 if (r > priv.beta) { 214 r -= priv.beta; 215 while ((r > 0) && (pid < num_valid_policies-1)) { 216 pid ++; 217 r -= priv.beta * powf(1.f - priv.beta, (float)pid); 218 } 219 } 220 } 221 // figure out which policy pid refers to 222 if (allow_optimal && (pid == num_valid_policies-1)) 223 return -1; // this is the optimal policy 224 225 pid = (int)priv.current_policy - pid; 226 if (!allow_current) 227 pid--; 228 229 return pid; 230 } 231 select_learner(search_private & priv,int policy,size_t learner_id)232 int select_learner(search_private& priv, int policy, size_t learner_id) { 233 if (policy<0) return policy; // optimal policy 234 else return (int) (policy*priv.num_learners+learner_id); 235 } 236 237 should_print_update(vw & all,bool hit_new_pass=false)238 bool should_print_update(vw& all, bool hit_new_pass=false) { 239 //uncomment to print out final loss after all examples processed 240 //commented for now so that outputs matches make test 241 //if( parser_done(all.p)) return true; 242 243 if (PRINT_UPDATE_EVERY_EXAMPLE) return true; 244 if (PRINT_UPDATE_EVERY_PASS && hit_new_pass) return true; 245 return (all.sd->weighted_examples >= all.sd->dump_interval) && !all.quiet && !all.bfgs; 246 } 247 248 might_print_update(vw & all)249 bool might_print_update(vw& all) { 250 // basically do should_print_update but check me and the next 251 // example because of off-by-ones 252 253 if (PRINT_UPDATE_EVERY_EXAMPLE) return true; 254 if (PRINT_UPDATE_EVERY_PASS) return true; // SPEEDUP: make this better 255 return (all.sd->weighted_examples + 1. 
>= all.sd->dump_interval) && !all.quiet && !all.bfgs; 256 } 257 must_run_test(vw & all,vector<example * > ec,bool is_test_ex)258 bool must_run_test(vw&all, vector<example*>ec, bool is_test_ex) { 259 return 260 (all.final_prediction_sink.size() > 0) || // if we have to produce output, we need to run this 261 might_print_update(all) || // if we have to print and update to stderr 262 (all.raw_prediction > 0) || // we need raw predictions 263 // or: 264 // it's not quiet AND 265 // current_pass == 0 266 // OR holdout is off 267 // OR it's a test example 268 ( // (! all.quiet) && // had to disable this because of library mode! 269 (! is_test_ex) && 270 ( all.holdout_set_off || // no holdout 271 ec[0]->test_only || 272 (all.current_pass == 0) // we need error rates for progressive cost 273 ) ) 274 ; 275 } 276 clear_seq(vw & all,search_private & priv)277 void clear_seq(vw&all, search_private& priv) { 278 if (priv.ec_seq.size() > 0) 279 for (size_t i=0; i < priv.ec_seq.size(); i++) 280 VW::finish_example(all, priv.ec_seq[i]); 281 priv.ec_seq.clear(); 282 } 283 safediv(float a,float b)284 float safediv(float a,float b) { if (b == 0.f) return 0.f; else return (a/b); } 285 to_short_string(string in,size_t max_len,char * out)286 void to_short_string(string in, size_t max_len, char*out) { 287 for (size_t i=0; i<max_len; i++) 288 out[i] = ((i >= in.length()) || (in[i] == '\n') || (in[i] == '\t')) ? 
' ' : in[i]; 289 290 if (in.length() > max_len) { 291 out[max_len-2] = '.'; 292 out[max_len-1] = '.'; 293 } 294 out[max_len] = 0; 295 } 296 number_to_natural(size_t big,char * c)297 void number_to_natural(size_t big, char* c) { 298 if (big > 9999999999) sprintf(c, "%dg", (int)(big / 1000000000)); 299 else if (big > 9999999) sprintf(c, "%dm", (int)(big / 1000000)); 300 else if (big > 9999) sprintf(c, "%dk", (int)(big / 1000)); 301 else sprintf(c, "%d", (int)(big)); 302 } 303 print_update(search_private & priv)304 void print_update(search_private& priv) { 305 vw& all = *priv.all; 306 if (!priv.printed_output_header && !all.quiet) { 307 const char * header_fmt = "%-10s %-10s %8s%24s %22s %5s %5s %7s %7s %7s %-8s\n"; 308 fprintf(stderr, header_fmt, "average", "since", "instance", "current true", "current predicted", "cur", "cur", "predic", "cache", "examples", ""); 309 fprintf(stderr, header_fmt, "loss", "last", "counter", "output prefix", "output prefix", "pass", "pol", "made", "hits", "gener", "beta"); 310 std::cerr.precision(5); 311 priv.printed_output_header = true; 312 } 313 314 if (!should_print_update(all, priv.hit_new_pass)) 315 return; 316 317 char true_label[21]; 318 char pred_label[21]; 319 to_short_string(priv.truth_string->str(), 20, true_label); 320 to_short_string(priv.pred_string->str() , 20, pred_label); 321 322 float avg_loss = 0.; 323 float avg_loss_since = 0.; 324 bool use_heldout_loss = (!all.holdout_set_off && all.current_pass >= 1) && (all.sd->weighted_holdout_examples > 0); 325 if (use_heldout_loss) { 326 avg_loss = safediv((float)all.sd->holdout_sum_loss, (float)all.sd->weighted_holdout_examples); 327 avg_loss_since = safediv((float)all.sd->holdout_sum_loss_since_last_dump, (float)all.sd->weighted_holdout_examples_since_last_dump); 328 329 all.sd->weighted_holdout_examples_since_last_dump = 0; 330 all.sd->holdout_sum_loss_since_last_dump = 0.0; 331 } else { 332 avg_loss = safediv((float)all.sd->sum_loss, (float)all.sd->weighted_examples); 333 
avg_loss_since = safediv((float)all.sd->sum_loss_since_last_dump, (float) (all.sd->weighted_examples - all.sd->old_weighted_examples)); 334 } 335 336 char inst_cntr[9]; number_to_natural(all.sd->example_number, inst_cntr); 337 char total_pred[8]; number_to_natural(priv.total_predictions_made, total_pred); 338 char total_cach[8]; number_to_natural(priv.total_cache_hits, total_cach); 339 char total_exge[8]; number_to_natural(priv.total_examples_generated, total_exge); 340 341 fprintf(stderr, "%-10.6f %-10.6f %8s [%s] [%s] %5d %5d %7s %7s %7s %-8f", 342 avg_loss, 343 avg_loss_since, 344 inst_cntr, 345 true_label, 346 pred_label, 347 (int)priv.read_example_last_pass, 348 (int)priv.current_policy, 349 total_pred, 350 total_cach, 351 total_exge, 352 priv.beta); 353 354 if (PRINT_CLOCK_TIME) { 355 size_t num_sec = (size_t)(((float)(clock() - priv.start_clock_time)) / CLOCKS_PER_SEC); 356 fprintf(stderr, " %15lusec", num_sec); 357 } 358 359 if (use_heldout_loss) 360 fprintf(stderr, " h"); 361 362 fprintf(stderr, "\n"); 363 fflush(stderr); 364 all.sd->update_dump_interval(all.progress_add, all.progress_arg); 365 } 366 add_new_feature(search_private & priv,float val,uint32_t idx)367 void add_new_feature(search_private& priv, float val, uint32_t idx) { 368 size_t mask = priv.all->reg.weight_mask; 369 size_t ss = priv.all->reg.stride_shift; 370 size_t idx2 = ((idx & mask) >> ss) & mask; 371 feature f = { val * priv.dat_new_feature_value, 372 (uint32_t) (((priv.dat_new_feature_idx + idx2) << ss) ) }; 373 priv.dat_new_feature_ec->atomics[priv.dat_new_feature_namespace].push_back(f); 374 priv.dat_new_feature_ec->sum_feat_sq[priv.dat_new_feature_namespace] += f.x * f.x; 375 if (priv.all->audit) { 376 audit_data a = { NULL, NULL, f.weight_index, f.x, true }; 377 a.space = calloc_or_die<char>(priv.dat_new_feature_feature_space->length()+1); 378 a.feature = calloc_or_die<char>(priv.dat_new_feature_audit_ss.str().length() + 32); 379 strcpy(a.space, 
priv.dat_new_feature_feature_space->c_str()); 380 int num = sprintf(a.feature, "fid=%lu_", (idx & mask) >> ss); 381 strcpy(a.feature+num, priv.dat_new_feature_audit_ss.str().c_str()); 382 priv.dat_new_feature_ec->audit_features[priv.dat_new_feature_namespace].push_back(a); 383 } 384 } 385 del_features_in_top_namespace(search_private & priv,example & ec,size_t ns)386 void del_features_in_top_namespace(search_private& priv, example& ec, size_t ns) { 387 if ((ec.indices.size() == 0) || (ec.indices.last() != ns)) { 388 std::cerr << "internal error (bug): expecting top namespace to be '" << ns << "' but it was "; 389 if (ec.indices.size() == 0) std::cerr << "empty"; 390 else std::cerr << (size_t)ec.indices.last(); 391 std::cerr << endl; 392 throw exception(); 393 } 394 ec.num_features -= ec.atomics[ns].size(); 395 ec.total_sum_feat_sq -= ec.sum_feat_sq[ns]; 396 ec.sum_feat_sq[ns] = 0; 397 ec.indices.decr(); 398 ec.atomics[ns].erase(); 399 if (priv.all->audit) { 400 for (size_t i=0; i<ec.audit_features[ns].size(); i++) 401 if (ec.audit_features[ns][i].alloced) { 402 free(ec.audit_features[ns][i].space); 403 free(ec.audit_features[ns][i].feature); 404 } 405 ec.audit_features[ns].erase(); 406 } 407 } 408 add_neighbor_features(search_private & priv)409 void add_neighbor_features(search_private& priv) { 410 vw& all = *priv.all; 411 if (priv.neighbor_features.size() == 0) return; 412 413 for (size_t n=0; n<priv.ec_seq.size(); n++) { // iterate over every example in the sequence 414 example& me = *priv.ec_seq[n]; 415 for (size_t n_id=0; n_id < priv.neighbor_features.size(); n_id++) { 416 int32_t offset = priv.neighbor_features[n_id] >> 24; 417 size_t ns = priv.neighbor_features[n_id] & 0xFF; 418 419 priv.dat_new_feature_ec = &me; 420 priv.dat_new_feature_value = 1.; 421 priv.dat_new_feature_idx = priv.neighbor_features[n_id] * 13748127; 422 priv.dat_new_feature_namespace = neighbor_namespace; 423 if (priv.all->audit) { 424 priv.dat_new_feature_feature_space = 
&neighbor_feature_space; 425 priv.dat_new_feature_audit_ss.str(""); 426 priv.dat_new_feature_audit_ss << '@' << ((offset > 0) ? '+' : '-') << (char)(abs(offset) + '0'); 427 if (ns != ' ') priv.dat_new_feature_audit_ss << (char)ns; 428 } 429 430 //cerr << "n=" << n << " offset=" << offset << endl; 431 if ((offset < 0) && (n < (uint32_t)(-offset))) // add <s> feature 432 add_new_feature(priv, 1., 925871901 << priv.all->reg.stride_shift); 433 else if (n + offset >= priv.ec_seq.size()) // add </s> feature 434 add_new_feature(priv, 1., 3824917 << priv.all->reg.stride_shift); 435 else { // this is actually a neighbor 436 example& other = *priv.ec_seq[n + offset]; 437 GD::foreach_feature<search_private,add_new_feature>(all.reg.weight_vector, all.reg.weight_mask, other.atomics[ns].begin, other.atomics[ns].end, priv, me.ft_offset); 438 } 439 } 440 441 size_t sz = me.atomics[neighbor_namespace].size(); 442 if ((sz > 0) && (me.sum_feat_sq[neighbor_namespace] > 0.)) { 443 me.indices.push_back(neighbor_namespace); 444 me.total_sum_feat_sq += me.sum_feat_sq[neighbor_namespace]; 445 me.num_features += sz; 446 } else { 447 me.atomics[neighbor_namespace].erase(); 448 if (priv.all->audit) me.audit_features[neighbor_namespace].erase(); 449 } 450 } 451 } 452 del_neighbor_features(search_private & priv)453 void del_neighbor_features(search_private& priv) { 454 if (priv.neighbor_features.size() == 0) return; 455 for (size_t n=0; n<priv.ec_seq.size(); n++) 456 del_features_in_top_namespace(priv, *priv.ec_seq[n], neighbor_namespace); 457 } 458 reset_search_structure(search_private & priv)459 void reset_search_structure(search_private& priv) { 460 // NOTE: make sure do NOT reset priv.learn_a_idx 461 priv.t = 0; 462 priv.loss_declared_cnt = 0; 463 priv.done_with_all_actions = false; 464 priv.test_loss = 0.; 465 priv.learn_loss = 0.; 466 priv.train_loss = 0.; 467 priv.num_features = 0; 468 priv.should_produce_string = false; 469 priv.mix_per_roll_policy = -2; 470 if (priv.adaptive_beta) { 
471 float x = - log1pf(- priv.alpha) * (float)priv.total_examples_generated; 472 static const float log_of_2 = (float)0.6931471805599453; 473 priv.beta = (x <= log_of_2) ? -expm1f(-x) : (1-expf(-x)); // numerical stability 474 //float priv_beta = 1.f - powf(1.f - priv.alpha, (float)priv.total_examples_generated); 475 //assert( fabs(priv_beta - priv.beta) < 1e-2 ); 476 if (priv.beta > 1) priv.beta = 1; 477 } 478 priv.ptag_to_action.erase(); 479 if (priv.beam) 480 priv.current_trajectory.erase(); 481 482 if (! priv.cb_learner) { // was: if rollout_all_actions 483 uint32_t seed = (uint32_t)(priv.read_example_last_id * 147483 + 4831921) * 2147483647; 484 msrand48(seed); 485 } 486 } 487 search_declare_loss(search_private & priv,float loss)488 void search_declare_loss(search_private& priv, float loss) { 489 priv.loss_declared_cnt++; 490 switch (priv.state) { 491 case INIT_TEST: priv.test_loss += loss; break; 492 case INIT_TRAIN: priv.train_loss += loss; break; 493 case LEARN: 494 if ((priv.rollout_num_steps == 0) || (priv.loss_declared_cnt <= priv.rollout_num_steps)) 495 priv.learn_loss += loss; 496 break; 497 default: break; // get rid of the warning about missing cases (danger!) 
498 } 499 } 500 random(size_t max)501 size_t random(size_t max) { return (size_t)(frand48() * (float)max); } array_contains(T target,const T * A,size_t n)502 template<class T> bool array_contains(T target, const T*A, size_t n) { 503 if (A == NULL) return false; 504 for (size_t i=0; i<n; i++) 505 if (A[i] == target) return true; 506 return false; 507 } 508 choose_oracle_action(search_private & priv,size_t ec_cnt,const action * oracle_actions,size_t oracle_actions_cnt,const action * allowed_actions,size_t allowed_actions_cnt,bool add_alternatives_to_beam)509 action choose_oracle_action(search_private& priv, size_t ec_cnt, const action* oracle_actions, size_t oracle_actions_cnt, const action* allowed_actions, size_t allowed_actions_cnt, bool add_alternatives_to_beam) { 510 action ret = ( oracle_actions_cnt > 0) ? oracle_actions[random(oracle_actions_cnt )] : 511 (allowed_actions_cnt > 0) ? allowed_actions[random(allowed_actions_cnt)] : 512 priv.is_ldf ? (action)random(ec_cnt) : 513 (action)(1 + random(ec_cnt)); 514 cdbg << "choose_oracle_action from oracle_actions = ["; for (size_t i=0; i<oracle_actions_cnt; i++) cdbg << " " << oracle_actions[i]; cdbg << " ], ret=" << ret << endl; 515 if (add_alternatives_to_beam) { 516 // first, insert all the oracle actions (other than ret) 517 size_t new_len = priv.current_trajectory.size() + 1; 518 if (oracle_actions_cnt > 1) 519 for (size_t i=0; i<oracle_actions_cnt; i++) 520 if (oracle_actions[i] != ret) { 521 float delta_cost = priv.beam_initial_cost + 1e-6f; 522 action_prefix* px = new v_array<action>; 523 px->resize(new_len+1); 524 px->end = px->begin + new_len + 1; 525 memcpy(px->begin, priv.current_trajectory.begin, sizeof(action) * (new_len-1)); 526 px->begin[new_len-1] = oracle_actions[i]; 527 *((float*)(px->begin+new_len)) = delta_cost; 528 uint32_t px_hash = uniform_hash(px->begin, sizeof(action) * new_len, 3419); 529 //cerr << "insertingA" << endl; 530 if (! 
priv.beam->insert(px, delta_cost, px_hash)) { 531 px->delete_v(); // SPEEDUP: could be more efficient by reusing for next action 532 delete px; 533 } 534 } 535 // now add all the non-oracle actions (other than ret) 536 size_t top = (allowed_actions_cnt > 0) ? allowed_actions_cnt : ec_cnt; 537 for (size_t i = 0; i < top; i++) { 538 size_t a = (allowed_actions_cnt > 0) ? allowed_actions[i] : i; 539 if (a == ret) continue; 540 if (array_contains<action>((action)a, oracle_actions, oracle_actions_cnt)) continue; 541 float delta_cost = priv.beam_initial_cost + 1.f + 1e-6f; // TODO: why is this the right cost? 542 action_prefix* px = new v_array<action>; 543 px->resize(new_len + 1); 544 px->end = px->begin + new_len + 1; 545 memcpy(px->begin, priv.current_trajectory.begin, sizeof(action) * (new_len-1)); 546 px->begin[new_len-1] = (action)a; 547 *((float*)(px->begin+new_len)) = delta_cost; 548 uint32_t px_hash = uniform_hash(px->begin, sizeof(action) * new_len, 3419); 549 //cerr << "insertingB " << i << " " << allowed_actions_cnt << " " << ec_cnt << " " << top << endl; 550 if (! 
priv.beam->insert(px, delta_cost, px_hash)) { 551 px->delete_v(); // SPEEDUP: could be more efficient by reusing for next action 552 delete px; 553 } 554 } 555 } 556 return ret; 557 } 558 add_example_conditioning(search_private & priv,example & ec,const ptag * condition_on,size_t condition_on_cnt,const char * condition_on_names,const action * condition_on_actions)559 void add_example_conditioning(search_private& priv, example& ec, const ptag* condition_on, size_t condition_on_cnt, const char* condition_on_names, const action* condition_on_actions) { 560 if (condition_on_cnt == 0) return; 561 562 uint32_t extra_offset=0; 563 if (priv.is_ldf) 564 if (ec.l.cs.costs.size() > 0) 565 extra_offset = 3849017 * ec.l.cs.costs[0].class_index; 566 567 size_t I = condition_on_cnt; 568 size_t N = max(priv.acset.max_bias_ngram_length, priv.acset.max_quad_ngram_length); 569 for (size_t i=0; i<I; i++) { // position in conditioning 570 uint32_t fid = 71933 + 8491087 * extra_offset; 571 if (priv.all->audit) { 572 priv.dat_new_feature_audit_ss.str(""); 573 priv.dat_new_feature_audit_ss.clear(); 574 priv.dat_new_feature_feature_space = &condition_feature_space; 575 } 576 577 for (size_t n=0; n<N; n++) { // length of ngram 578 if (i + n >= I) break; // no more ngrams 579 // we're going to add features for the ngram condition_on_actions[i .. 
i+N] 580 char name = condition_on_names[i+n]; 581 fid = fid * 328901 + 71933 * ((condition_on_actions[i+n] + 349101) * (name + 38490137)); 582 583 priv.dat_new_feature_ec = &ec; 584 priv.dat_new_feature_idx = fid * quadratic_constant; 585 priv.dat_new_feature_namespace = conditioning_namespace; 586 priv.dat_new_feature_value = priv.acset.feature_value; 587 588 if (priv.all->audit) { 589 if (n > 0) priv.dat_new_feature_audit_ss << ','; 590 if ((33 <= name) && (name <= 126)) priv.dat_new_feature_audit_ss << name; 591 else priv.dat_new_feature_audit_ss << '#' << (int)name; 592 priv.dat_new_feature_audit_ss << '=' << condition_on_actions[i+n]; 593 } 594 595 // add the single bias feature 596 if (n < priv.acset.max_bias_ngram_length) 597 add_new_feature(priv, 1., 4398201 << priv.all->reg.stride_shift); 598 599 // add the quadratic features 600 if (n < priv.acset.max_quad_ngram_length) 601 GD::foreach_feature<search_private,uint32_t,add_new_feature>(*priv.all, ec, priv); 602 } 603 } 604 605 size_t sz = ec.atomics[conditioning_namespace].size(); 606 if ((sz > 0) && (ec.sum_feat_sq[conditioning_namespace] > 0.)) { 607 ec.indices.push_back(conditioning_namespace); 608 ec.total_sum_feat_sq += ec.sum_feat_sq[conditioning_namespace]; 609 ec.num_features += sz; 610 } else { 611 ec.atomics[conditioning_namespace].erase(); 612 if (priv.all->audit) ec.audit_features[conditioning_namespace].erase(); 613 } 614 } 615 del_example_conditioning(search_private & priv,example & ec)616 void del_example_conditioning(search_private& priv, example& ec) { 617 if ((ec.indices.size() > 0) && (ec.indices.last() == conditioning_namespace)) 618 del_features_in_top_namespace(priv, ec, conditioning_namespace); 619 } 620 cs_get_costs_size(bool isCB,polylabel & ld)621 size_t cs_get_costs_size(bool isCB, polylabel& ld) { 622 return isCB ? 
ld.cb.costs.size() 623 : ld.cs.costs.size(); 624 } 625 cs_get_cost_index(bool isCB,polylabel & ld,size_t k)626 uint32_t cs_get_cost_index(bool isCB, polylabel& ld, size_t k) { 627 return isCB ? ld.cb.costs[k].action 628 : ld.cs.costs[k].class_index; 629 } 630 cs_get_cost_partial_prediction(bool isCB,polylabel & ld,size_t k)631 float cs_get_cost_partial_prediction(bool isCB, polylabel& ld, size_t k) { 632 return isCB ? ld.cb.costs[k].partial_prediction 633 : ld.cs.costs[k].partial_prediction; 634 } 635 cs_costs_erase(bool isCB,polylabel & ld)636 void cs_costs_erase(bool isCB, polylabel& ld) { 637 if (isCB) ld.cb.costs.erase(); 638 else ld.cs.costs.erase(); 639 } 640 cs_costs_resize(bool isCB,polylabel & ld,size_t new_size)641 void cs_costs_resize(bool isCB, polylabel& ld, size_t new_size) { 642 if (isCB) ld.cb.costs.resize(new_size); 643 else ld.cs.costs.resize(new_size); 644 } 645 cs_cost_push_back(bool isCB,polylabel & ld,uint32_t index,float value)646 void cs_cost_push_back(bool isCB, polylabel& ld, uint32_t index, float value) { 647 if (isCB) { CB::cb_class cost = { value, index, 0., 0. }; ld.cb.costs.push_back(cost); } 648 else { CS::wclass cost = { value, index, 0., 0. 
}; ld.cs.costs.push_back(cost); } 649 } 650 allowed_actions_to_ld(search_private & priv,size_t ec_cnt,const action * allowed_actions,size_t allowed_actions_cnt)651 polylabel& allowed_actions_to_ld(search_private& priv, size_t ec_cnt, const action* allowed_actions, size_t allowed_actions_cnt) { 652 bool isCB = priv.cb_learner; 653 polylabel& ld = *priv.allowed_actions_cache; 654 uint32_t num_costs = (uint32_t)cs_get_costs_size(isCB, ld); 655 656 if (priv.is_ldf) { // LDF version easier 657 if (num_costs > ec_cnt) 658 cs_costs_resize(isCB, ld, ec_cnt); 659 else if (num_costs < ec_cnt) 660 for (action k = num_costs; k < ec_cnt; k++) 661 cs_cost_push_back(isCB, ld, k, FLT_MAX); 662 663 } else { // non-LDF version 664 if ((allowed_actions == NULL) || (allowed_actions_cnt == 0)) { // any action is allowed 665 if (num_costs != priv.A) { // if there are already A-many actions, they must be the right ones, unless the user did something stupid like putting duplicate allowed_actions... 666 cs_costs_erase(isCB, ld); 667 for (action k = 0; k < priv.A; k++) 668 cs_cost_push_back(isCB, ld, k+1, FLT_MAX); //+1 because MC is 1-based 669 } 670 } else { // we need to peek at allowed_actions 671 cs_costs_erase(isCB, ld); 672 for (size_t i = 0; i < allowed_actions_cnt; i++) 673 cs_cost_push_back(isCB, ld, allowed_actions[i], FLT_MAX); 674 } 675 } 676 677 return ld; 678 } 679 allowed_actions_to_losses(search_private & priv,size_t ec_cnt,const action * allowed_actions,size_t allowed_actions_cnt,const action * oracle_actions,size_t oracle_actions_cnt,v_array<float> & losses)680 void allowed_actions_to_losses(search_private& priv, size_t ec_cnt, const action* allowed_actions, size_t allowed_actions_cnt, const action* oracle_actions, size_t oracle_actions_cnt, v_array<float>& losses) { 681 if (priv.is_ldf) // LDF version easier 682 for (action k=0; k<ec_cnt; k++) 683 losses.push_back( array_contains<action>(k, oracle_actions, oracle_actions_cnt) ? 
0.f : 1.f ); 684 else { // non-LDF 685 if ((allowed_actions == NULL) || (allowed_actions_cnt == 0)) // any action is allowed 686 for (action k=1; k<=priv.A; k++) 687 losses.push_back( array_contains<action>(k, oracle_actions, oracle_actions_cnt) ? 0.f : 1.f ); 688 else 689 for (size_t i=0; i<allowed_actions_cnt; i++) { 690 action k = allowed_actions[i]; 691 losses.push_back( array_contains<action>(k, oracle_actions, oracle_actions_cnt) ? 0.f : 1.f ); 692 } 693 } 694 } 695 single_prediction_notLDF(search_private & priv,example & ec,int policy,const action * allowed_actions,size_t allowed_actions_cnt)696 action single_prediction_notLDF(search_private& priv, example& ec, int policy, const action* allowed_actions, size_t allowed_actions_cnt) { 697 vw& all = *priv.all; 698 polylabel old_label = ec.l; 699 ec.l = allowed_actions_to_ld(priv, 1, allowed_actions, allowed_actions_cnt); 700 701 priv.base_learner->predict(ec, policy); 702 uint32_t act = ec.pred.multiclass; 703 704 // in beam search mode, go through alternatives and add them as back-ups 705 if (priv.beam) { 706 float act_cost = 0; 707 size_t K = cs_get_costs_size(priv.cb_learner, ec.l); 708 for (size_t k = 0; k < K; k++) 709 if (cs_get_cost_index(priv.cb_learner, ec.l, k) == act) { 710 act_cost = cs_get_cost_partial_prediction(priv.cb_learner, ec.l, k); 711 break; 712 } 713 714 priv.beam_total_cost += act_cost; 715 size_t new_len = priv.current_trajectory.size() + 1; 716 for (size_t k = 0; k < K; k++) { 717 action k_act = cs_get_cost_index(priv.cb_learner, ec.l, k); 718 if (k_act == act) continue; // skip the taken action 719 // TODO: delta_cost is correct for prioritizing, but not for full path cost -- for that, we cannot be subtracting off act_cost 720 float delta_cost = cs_get_cost_partial_prediction(priv.cb_learner, ec.l, k) - act_cost + priv.beam_initial_cost; // TODO: is delta_cost the right cost? 
721 // construct the action prefix 722 action_prefix* px = new v_array<action>; 723 *px = v_init<action>(); 724 px->resize(new_len + 1); 725 px->end = px->begin + new_len + 1; 726 memcpy(px->begin, priv.current_trajectory.begin, sizeof(action) * (new_len-1)); 727 px->begin[new_len-1] = k_act; 728 *((float*)(px->begin+new_len)) = delta_cost + act_cost; 729 uint32_t px_hash = uniform_hash(px->begin, sizeof(action) * new_len, 3419); 730 cdbg << "inserting delta_cost=" << delta_cost << " total_cost=" << *((float*)(px->begin+new_len)) << " seq="; 731 for (size_t ii=0; ii<new_len; ii++) cdbg << px->begin[ii] << ' '; cdbg << endl; 732 if (! priv.beam->insert(px, delta_cost, px_hash)) { 733 px->delete_v(); // SPEEDUP: could be more efficient by reusing for next action 734 delete px; 735 } 736 } 737 } 738 739 // generate raw predictions if necessary 740 if ((priv.state == INIT_TEST) && (all.raw_prediction > 0)) { 741 priv.rawOutputStringStream->str(""); 742 for (size_t k = 0; k < cs_get_costs_size(priv.cb_learner, ec.l); k++) { 743 if (k > 0) (*priv.rawOutputStringStream) << ' '; 744 (*priv.rawOutputStringStream) << cs_get_cost_index(priv.cb_learner, ec.l, k) << ':' << cs_get_cost_partial_prediction(priv.cb_learner, ec.l, k); 745 } 746 all.print_text(all.raw_prediction, priv.rawOutputStringStream->str(), ec.tag); 747 } 748 749 ec.l = old_label; 750 751 priv.total_predictions_made++; 752 priv.num_features += ec.num_features; 753 754 return act; 755 } 756 single_prediction_LDF(search_private & priv,example * ecs,size_t ec_cnt,int policy)757 action single_prediction_LDF(search_private& priv, example* ecs, size_t ec_cnt, int policy) { 758 CS::cs_label.default_label(&priv.ldf_test_label); 759 CS::wclass wc = { 0., 1, 0., 0. }; 760 priv.ldf_test_label.costs.push_back(wc); 761 762 // keep track of best (aka chosen) action 763 float best_prediction = 0.; 764 action best_action = 0; 765 766 size_t start_K = (priv.is_ldf && LabelDict::ec_is_example_header(ecs[0])) ? 
1 : 0; 767 768 for (action a= (uint32_t)start_K; a<ec_cnt; a++) { 769 cdbg << "== single_prediction_LDF a=" << a << "==" << endl; 770 if (start_K > 0) 771 LabelDict::add_example_namespaces_from_example(ecs[a], ecs[0]); 772 773 polylabel old_label = ecs[a].l; 774 ecs[a].l.cs = priv.ldf_test_label; 775 priv.base_learner->predict(ecs[a], policy); 776 777 priv.empty_example->in_use = true; 778 priv.base_learner->predict(*priv.empty_example); 779 780 //cerr << "partial_prediction[" << a << "] = " << ecs[a].partial_prediction << endl; 781 782 if ((a == start_K) || (ecs[a].partial_prediction < best_prediction)) { 783 best_prediction = ecs[a].partial_prediction; 784 best_action = a; 785 } 786 787 priv.num_features += ecs[a].num_features; 788 ecs[a].l = old_label; 789 if (start_K > 0) 790 LabelDict::del_example_namespaces_from_example(ecs[a], ecs[0]); 791 } 792 793 if (priv.beam) { 794 priv.beam_total_cost += best_prediction; 795 size_t new_len = priv.current_trajectory.size() + 1; 796 for (size_t k=start_K; k<ec_cnt; k++) { 797 if (k == best_action) continue; 798 float delta_cost = ecs[k].partial_prediction - best_prediction + priv.beam_initial_cost; 799 action_prefix* px = new v_array<action>; 800 *px = v_init<action>(); 801 px->resize(new_len + 1); 802 px->end = px->begin + new_len + 1; 803 memcpy(px->begin, priv.current_trajectory.begin, sizeof(action) * (new_len-1)); 804 px->begin[new_len-1] = (uint32_t)k; // TODO: k or ld[k]? 805 *((float*)(px->begin+new_len)) = delta_cost + best_prediction; 806 uint32_t px_hash = uniform_hash(px->begin, sizeof(action) * new_len, 3419); 807 if (! 
priv.beam->insert(px, delta_cost, px_hash)) { 808 px->delete_v(); // SPEEDUP: could be more efficient by reusing for next action 809 delete px; 810 } 811 } 812 } 813 814 priv.total_predictions_made++; 815 return best_action; 816 } 817 choose_policy(search_private & priv,bool advance_prng=true)818 int choose_policy(search_private& priv, bool advance_prng=true) { 819 RollMethod method = (priv.state == INIT_TEST ) ? POLICY : 820 (priv.state == LEARN ) ? priv.rollout_method : 821 (priv.state == INIT_TRAIN) ? priv.rollin_method : 822 NO_ROLLOUT; // this should never happen 823 switch (method) { 824 case POLICY: 825 return random_policy(priv, priv.allow_current_policy || (priv.state == INIT_TEST), false, advance_prng); 826 827 case ORACLE: 828 return -1; 829 830 case MIX_PER_STATE: 831 return random_policy(priv, priv.allow_current_policy, true, advance_prng); 832 833 case MIX_PER_ROLL: 834 if (priv.mix_per_roll_policy == -2) // then we have to choose one! 835 priv.mix_per_roll_policy = random_policy(priv, priv.allow_current_policy, true, advance_prng); 836 return priv.mix_per_roll_policy; 837 838 case NO_ROLLOUT: 839 default: 840 std::cerr << "internal error (bug): trying to rollin or rollout with NO_ROLLOUT" << endl; 841 throw exception(); 842 } 843 } 844 cdbg_print_array(string str,v_array<T> & A)845 template<class T> void cdbg_print_array(string str, v_array<T>& A) { cdbg << str << " = ["; for (size_t i=0; i<A.size(); i++) cdbg << " " << A[i]; cdbg << " ]" << endl; } cerr_print_array(string str,v_array<T> & A)846 template<class T> void cerr_print_array(string str, v_array<T>& A) { cerr << str << " = ["; for (size_t i=0; i<A.size(); i++) cerr << " " << A[i]; cerr << " ]" << endl; } 847 848 template<class T> ensure_size(v_array<T> & A,size_t sz)849 void ensure_size(v_array<T>& A, size_t sz) { 850 if ((size_t)(A.end_array - A.begin) < sz) 851 A.resize(sz*2+1, true); 852 A.end = A.begin + sz; 853 } 854 push_at(v_array<T> & v,T item,size_t pos)855 template<class T> void 
push_at(v_array<T>& v, T item, size_t pos) { 856 if (v.size() > pos) 857 v.begin[pos] = item; 858 else { 859 if (v.end_array > v.begin + pos) { 860 // there's enough memory, just not enough filler 861 v.begin[pos] = item; 862 v.end = v.begin + pos + 1; 863 } else { 864 // there's not enough memory 865 v.resize(2 * pos + 3, true); 866 v.begin[pos] = item; 867 v.end = v.begin + pos + 1; 868 } 869 } 870 } 871 record_action(search_private & priv,ptag mytag,action a)872 void record_action(search_private& priv, ptag mytag, action a) { 873 if (mytag == 0) return; 874 push_at(priv.ptag_to_action, a, mytag); 875 } 876 cached_item_equivalent(unsigned char * & A,unsigned char * & B)877 bool cached_item_equivalent(unsigned char*& A, unsigned char*& B) { 878 size_t sz_A = *A; 879 size_t sz_B = *B; 880 if (sz_A != sz_B) return false; 881 return memcmp(A, B, sz_A) == 0; 882 } 883 free_key(unsigned char * mem,action a)884 void free_key(unsigned char* mem, action a) { free(mem); } clear_cache_hash_map(search_private & priv)885 void clear_cache_hash_map(search_private& priv) { 886 priv.cache_hash_map.iter(free_key); 887 priv.cache_hash_map.clear(); 888 } 889 890 // returns true if found and do_store is false. if do_store is true, always returns true. 
cached_action_store_or_find(search_private & priv,ptag mytag,const ptag * condition_on,const char * condition_on_names,const action * condition_on_actions,size_t condition_on_cnt,int policy,size_t learner_id,action & a,bool do_store)891 bool cached_action_store_or_find(search_private& priv, ptag mytag, const ptag* condition_on, const char* condition_on_names, const action* condition_on_actions, size_t condition_on_cnt, int policy, size_t learner_id, action &a, bool do_store) { 892 if (priv.no_caching) return do_store; 893 if (mytag == 0) return do_store; // don't attempt to cache when tag is zero 894 895 size_t sz = sizeof(size_t) + sizeof(ptag) + sizeof(int) + sizeof(size_t) + sizeof(size_t) + condition_on_cnt * (sizeof(ptag) + sizeof(action) + sizeof(char)); 896 if (sz % 4 != 0) sz = 4 * (sz / 4 + 1); // make sure sz aligns to 4 so that uniform_hash does the right thing 897 898 unsigned char* item = calloc_or_die<unsigned char>(sz); 899 unsigned char* here = item; 900 *here = (unsigned char)sz; here += sizeof(size_t); 901 *here = mytag; here += sizeof(ptag); 902 *here = policy; here += sizeof(int); 903 *here = (unsigned char)learner_id; here += sizeof(size_t); 904 *here = (unsigned char)condition_on_cnt; here += (unsigned char)sizeof(size_t); 905 for (size_t i=0; i<condition_on_cnt; i++) { 906 *here = condition_on[i]; here += sizeof(ptag); 907 *here = condition_on_actions[i]; here += sizeof(action); 908 *here = condition_on_names[i]; here += sizeof(char); // SPEEDUP: should we align this at 4? 
909 } 910 uint32_t hash = uniform_hash(item, sz, 3419); 911 912 if (do_store) { 913 priv.cache_hash_map.put(item, hash, a); 914 return true; 915 } else { // its a find 916 a = priv.cache_hash_map.get(item, hash); 917 free(item); 918 return a != (action)-1; 919 } 920 } 921 generate_training_example(search_private & priv,v_array<float> & losses,bool add_conditioning=true)922 void generate_training_example(search_private& priv, v_array<float>& losses, bool add_conditioning=true) { 923 // should we really subtract out min-loss? 924 float min_loss = FLT_MAX, max_loss = -FLT_MAX; 925 size_t num_min = 0; 926 for (size_t i=0; i<losses.size(); i++) { 927 if (losses[i] < min_loss) { min_loss = losses[i]; num_min = 1; } 928 else if (losses[i] == min_loss) num_min++; 929 if (losses[i] > max_loss) { max_loss = losses[i]; } 930 } 931 932 int learner = select_learner(priv, priv.current_policy, priv.learn_learner_id); 933 934 if (!priv.is_ldf) { // not LDF 935 // since we're not LDF, it should be the case that ec_ref_cnt == 1 936 // and learn_ec_ref[0] is a pointer to a single example 937 assert(priv.learn_ec_ref_cnt == 1); 938 assert(priv.learn_ec_ref != NULL); 939 940 polylabel labels = allowed_actions_to_ld(priv, priv.learn_ec_ref_cnt, priv.learn_allowed_actions.begin, priv.learn_allowed_actions.size()); 941 cdbg_print_array("learn_allowed_actions", priv.learn_allowed_actions); 942 //bool any_gt_1 = false; 943 for (size_t i=0; i<losses.size(); i++) { 944 losses[i] = losses[i] - min_loss; // TODO: in BEAM mode, subtracting off min_loss seems like a bad idea 945 if (priv.cb_learner) labels.cb.costs[i].cost = losses[i]; 946 else labels.cs.costs[i].x = losses[i]; 947 } 948 949 example& ec = priv.learn_ec_ref[0]; 950 polylabel old_label = ec.l; 951 ec.l = labels; 952 ec.in_use = true; 953 if (add_conditioning) add_example_conditioning(priv, ec, priv.learn_condition_on.begin, priv.learn_condition_on.size(), priv.learn_condition_on_names.begin, priv.learn_condition_on_act.begin); 954 
priv.base_learner->learn(ec, learner); 955 if (add_conditioning) del_example_conditioning(priv, ec); 956 ec.l = old_label; 957 priv.total_examples_generated++; 958 } else { // is LDF 959 assert(losses.size() == priv.learn_ec_ref_cnt); 960 size_t start_K = (priv.is_ldf && LabelDict::ec_is_example_header(priv.learn_ec_ref[0])) ? 1 : 0; 961 for (action a= (uint32_t)start_K; a<priv.learn_ec_ref_cnt; a++) { 962 example& ec = priv.learn_ec_ref[a]; 963 964 CS::label& lab = ec.l.cs; 965 if (lab.costs.size() == 0) { 966 CS::wclass wc = { 0., 1, 0., 0. }; 967 lab.costs.push_back(wc); 968 } 969 lab.costs[0].x = losses[a] - min_loss; 970 //cerr << "cost[" << a << "] = " << losses[a] << " - " << min_loss << " = " << lab.costs[0].x << endl; 971 ec.in_use = true; 972 if (add_conditioning) add_example_conditioning(priv, ec, priv.learn_condition_on.begin, priv.learn_condition_on.size(), priv.learn_condition_on_names.begin, priv.learn_condition_on_act.begin); 973 priv.base_learner->learn(ec, learner); 974 cdbg << "generate_training_example called learn on action a=" << a << ", costs.size=" << lab.costs.size() << " ec=" << &ec << endl; 975 priv.total_examples_generated++; 976 } 977 priv.base_learner->learn(*priv.empty_example, learner); 978 cdbg << "generate_training_example called learn on empty_example" << endl; 979 980 for (action a= (uint32_t)start_K; a<priv.learn_ec_ref_cnt; a++) { 981 example& ec = priv.learn_ec_ref[a]; 982 if (add_conditioning) 983 del_example_conditioning(priv, ec); 984 } 985 } 986 } 987 search_predictNeedsExample(search_private & priv)988 bool search_predictNeedsExample(search_private& priv) { 989 // this is basically copied from the logic of search_predict() 990 switch (priv.state) { 991 case INITIALIZE: return false; 992 case GET_TRUTH_STRING: return false; 993 case INIT_TEST: 994 if (priv.beam && (priv.t < priv.beam_actions.size())) 995 return false; 996 return true; 997 case INIT_TRAIN: 998 if (priv.beam && (priv.t < priv.beam_actions.size())) 999 return 
false; 1000 break; 1001 case LEARN: 1002 if (priv.t < priv.learn_t) return false; 1003 if (priv.t == priv.learn_t) return true; // SPEEDUP: we really only need it on the last learn_a, but this is hard to know... 1004 // t > priv.learn_t 1005 if ((priv.rollout_num_steps > 0) && (priv.loss_declared_cnt >= priv.rollout_num_steps)) return false; // skipping 1006 break; 1007 } 1008 1009 int pol = choose_policy(priv, false); // choose a policy but don't advance prng 1010 return (pol != -1); 1011 } 1012 1013 // note: ec_cnt should be 1 if we are not LDF search_predict(search_private & priv,example * ecs,size_t ec_cnt,ptag mytag,const action * oracle_actions,size_t oracle_actions_cnt,const ptag * condition_on,const char * condition_on_names,const action * allowed_actions,size_t allowed_actions_cnt,size_t learner_id)1014 action search_predict(search_private& priv, example* ecs, size_t ec_cnt, ptag mytag, const action* oracle_actions, size_t oracle_actions_cnt, const ptag* condition_on, const char* condition_on_names, const action* allowed_actions, size_t allowed_actions_cnt, size_t learner_id) { 1015 size_t condition_on_cnt = condition_on_names ? 
strlen(condition_on_names) : 0; 1016 size_t t = priv.t; 1017 priv.t++; 1018 1019 // make sure parameters come in pairs correctly 1020 assert((oracle_actions == NULL) == (oracle_actions_cnt == 0)); 1021 assert((condition_on == NULL) == (condition_on_names == NULL)); 1022 assert((allowed_actions == NULL) == (allowed_actions_cnt == 0)); 1023 1024 // if we're just after the string, choose an oracle action 1025 if (priv.state == GET_TRUTH_STRING) 1026 return choose_oracle_action(priv, ec_cnt, oracle_actions, oracle_actions_cnt, allowed_actions, allowed_actions_cnt, false); 1027 1028 // if we're in LEARN mode and before learn_t, return the train action 1029 if ((priv.state == LEARN) && (t < priv.learn_t)) { 1030 assert(t < priv.train_trajectory.size()); 1031 return priv.train_trajectory[t]; 1032 } 1033 1034 if (priv.beam && (t < priv.beam_actions.size()) && ((priv.state == INIT_TEST) || (priv.state == INIT_TRAIN))) 1035 return priv.beam_actions[t]; 1036 1037 // for LDF, # of valid actions is ec_cnt; otherwise it's either allowed_actions_cnt or A 1038 size_t valid_action_cnt = priv.is_ldf ? ec_cnt : 1039 (allowed_actions_cnt > 0) ? 
allowed_actions_cnt : priv.A; 1040 1041 // if we're in LEARN mode and _at_ learn_t, then: 1042 // - choose the next action 1043 // - decide if we're done 1044 // - if we are, then copy/mark the example ref 1045 if ((priv.state == LEARN) && (t == priv.learn_t)) { 1046 action a = (action)priv.learn_a_idx; 1047 priv.loss_declared_cnt = 0; 1048 1049 priv.learn_a_idx++; 1050 priv.learn_loss = 0.; // don't include "past cost" 1051 1052 // check to see if we're done with available actions 1053 if (priv.learn_a_idx >= valid_action_cnt) { 1054 priv.done_with_all_actions = true; 1055 priv.learn_learner_id = learner_id; 1056 1057 // set reference or copy example(s) 1058 if (oracle_actions_cnt > 0) priv.learn_oracle_action = oracle_actions[0]; 1059 priv.learn_ec_ref_cnt = ec_cnt; 1060 if (priv.examples_dont_change) 1061 priv.learn_ec_ref = ecs; 1062 else { 1063 size_t label_size = priv.is_ldf ? sizeof(CS::label) : sizeof(MC::label_t); 1064 void (*label_copy_fn)(void*,void*) = priv.is_ldf ? CS::cs_label.copy_label : NULL; 1065 1066 ensure_size(priv.learn_ec_copy, ec_cnt); 1067 for (size_t i=0; i<ec_cnt; i++) 1068 VW::copy_example_data(priv.all->audit, priv.learn_ec_copy.begin+i, ecs+i, label_size, label_copy_fn); 1069 1070 priv.learn_ec_ref = priv.learn_ec_copy.begin; 1071 } 1072 1073 // copy conditioning stuff and allowed actions 1074 if (priv.auto_condition_features) { 1075 ensure_size(priv.learn_condition_on, condition_on_cnt); 1076 ensure_size(priv.learn_condition_on_act, condition_on_cnt); 1077 1078 priv.learn_condition_on.end = priv.learn_condition_on.begin + condition_on_cnt; // allow .size() to be used in lieu of _cnt 1079 1080 memcpy(priv.learn_condition_on.begin, condition_on, condition_on_cnt * sizeof(ptag)); 1081 1082 for (size_t i=0; i<condition_on_cnt; i++) 1083 push_at(priv.learn_condition_on_act, ((1 <= condition_on[i]) && (condition_on[i] < priv.ptag_to_action.size())) ? 
priv.ptag_to_action[condition_on[i]] : 0, i); 1084 1085 if (condition_on_names == NULL) { 1086 ensure_size(priv.learn_condition_on_names, 1); 1087 priv.learn_condition_on_names[0] = 0; 1088 } else { 1089 ensure_size(priv.learn_condition_on_names, strlen(condition_on_names)+1); 1090 strcpy(priv.learn_condition_on_names.begin, condition_on_names); 1091 } 1092 } 1093 1094 ensure_size(priv.learn_allowed_actions, allowed_actions_cnt); 1095 memcpy(priv.learn_allowed_actions.begin, allowed_actions, allowed_actions_cnt*sizeof(action)); 1096 cdbg_print_array("in LEARN, learn_allowed_actions", priv.learn_allowed_actions); 1097 } 1098 1099 assert((allowed_actions_cnt == 0) || (a < allowed_actions_cnt)); 1100 return (allowed_actions_cnt > 0) ? allowed_actions[a] : priv.is_ldf ? a : (a+1); 1101 } 1102 1103 if ((priv.state == LEARN) && (t > priv.learn_t) && (priv.rollout_num_steps > 0) && (priv.loss_declared_cnt >= priv.rollout_num_steps)) { 1104 cdbg << "... skipping" << endl; 1105 if (priv.is_ldf) return 0; 1106 else if (allowed_actions_cnt > 0) return allowed_actions[0]; 1107 else return 1; 1108 } 1109 1110 1111 if ((priv.state == INIT_TRAIN) || 1112 (priv.state == INIT_TEST) || 1113 ((priv.state == LEARN) && (t > priv.learn_t))) { 1114 // we actually need to run the policy 1115 1116 int policy = choose_policy(priv); 1117 action a; 1118 1119 cdbg << "executing policy " << policy << endl; 1120 1121 bool gte_here = (priv.state == INIT_TRAIN) && (priv.rollout_method == NO_ROLLOUT) && (oracle_actions_cnt > 0); 1122 1123 if (policy == -1) 1124 a = choose_oracle_action(priv, ec_cnt, oracle_actions, oracle_actions_cnt, allowed_actions, allowed_actions_cnt, priv.beam && (priv.state != INIT_TEST)); 1125 1126 if ((policy >= 0) || gte_here) { 1127 int learner = select_learner(priv, policy, learner_id); 1128 1129 ensure_size(priv.condition_on_actions, condition_on_cnt); 1130 for (size_t i=0; i<condition_on_cnt; i++) 1131 priv.condition_on_actions[i] = ((1 <= condition_on[i]) && 
(condition_on[i] < priv.ptag_to_action.size())) ? priv.ptag_to_action[condition_on[i]] : 0; 1132 1133 if (cached_action_store_or_find(priv, mytag, condition_on, condition_on_names, priv.condition_on_actions.begin, condition_on_cnt, policy, learner_id, a, false)) 1134 // if this succeeded, 'a' has the right action 1135 priv.total_cache_hits++; 1136 else { // we need to predict, and then cache 1137 size_t start_K = (priv.is_ldf && LabelDict::ec_is_example_header(ecs[0])) ? 1 : 0; 1138 if (priv.auto_condition_features) 1139 for (size_t n=start_K; n<ec_cnt; n++) 1140 add_example_conditioning(priv, ecs[n], condition_on, condition_on_cnt, condition_on_names, priv.condition_on_actions.begin); 1141 1142 if (policy >= 0) // only make a prediction if we're going to use the output 1143 a = priv.is_ldf ? single_prediction_LDF(priv, ecs, ec_cnt, learner) 1144 : single_prediction_notLDF(priv, *ecs, learner, allowed_actions, allowed_actions_cnt); 1145 1146 if (gte_here) { 1147 cdbg << "INIT_TRAIN, NO_ROLLOUT, at least one oracle_actions" << endl; 1148 // we can generate a training example _NOW_ because we're not doing rollouts 1149 v_array<float> losses = v_init<float>(); // SPEEDUP: move this to data structure 1150 allowed_actions_to_losses(priv, ec_cnt, allowed_actions, allowed_actions_cnt, oracle_actions, oracle_actions_cnt, losses); 1151 cdbg_print_array("losses", losses); 1152 priv.learn_ec_ref = ecs; 1153 priv.learn_ec_ref_cnt = ec_cnt; 1154 ensure_size(priv.learn_allowed_actions, allowed_actions_cnt); 1155 memcpy(priv.learn_allowed_actions.begin, allowed_actions, allowed_actions_cnt * sizeof(action)); 1156 generate_training_example(priv, losses, false); 1157 losses.delete_v(); 1158 } 1159 1160 if (priv.auto_condition_features) 1161 for (size_t n=start_K; n<ec_cnt; n++) 1162 del_example_conditioning(priv, ecs[n]); 1163 1164 cached_action_store_or_find(priv, mytag, condition_on, condition_on_names, priv.condition_on_actions.begin, condition_on_cnt, policy, learner_id, a, 
true); 1165 } 1166 } 1167 1168 if (priv.state == INIT_TRAIN) 1169 priv.train_trajectory.push_back(a); // note the action for future reference 1170 1171 return a; 1172 } 1173 1174 std::cerr << "error: predict called in unknown state" << endl; 1175 throw exception(); 1176 } 1177 cmp_size_t(const size_t a,const size_t b)1178 inline bool cmp_size_t(const size_t a, const size_t b) { return a < b; } cmp_size_t_pair(const pair<size_t,size_t> & a,const pair<size_t,size_t> & b)1179 inline bool cmp_size_t_pair(const pair<size_t,size_t>& a, const pair<size_t,size_t>& b) { return ((a.first == b.first) && (a.second < b.second)) || (a.first < b.first); } get_training_timesteps(search_private & priv,v_array<pair<size_t,size_t>> & timesteps)1180 void get_training_timesteps(search_private& priv, v_array< pair<size_t,size_t> >& timesteps) { // timesteps are pairs of (beam elem, t) where beam elem == 0 means "default" for non-beam search 1181 timesteps.erase(); 1182 1183 // if there's no subsampling to do, just return [0,T) 1184 if (priv.subsample_timesteps <= 0) 1185 for (size_t t=0; t<priv.T; t++) 1186 timesteps.push_back(pair<size_t,size_t>(0,t)); 1187 1188 // if subsample in (0,1) then pick steps with that probability, but ensuring there's at least one! 1189 else if (priv.subsample_timesteps < 1) { 1190 for (size_t t=0; t<priv.T; t++) 1191 if (frand48() <= priv.subsample_timesteps) 1192 timesteps.push_back(pair<size_t,size_t>(0,t)); 1193 1194 if (timesteps.size() == 0) // ensure at least one 1195 timesteps.push_back(pair<size_t,size_t>(0,(size_t)(frand48() * priv.T))); 1196 } 1197 1198 // finally, if subsample >= 1, then pick (int) that many uniformly at random without replacement; could use an LFSR but why? :P 1199 else { 1200 while ((timesteps.size() < (size_t)priv.subsample_timesteps) && 1201 (timesteps.size() < priv.T)) { 1202 size_t t = (size_t)(frand48() * (float)priv.T); 1203 if (! 
v_array_contains(timesteps, pair<size_t,size_t>(0,t))) 1204 timesteps.push_back(pair<size_t,size_t>(0,t)); 1205 } 1206 std::sort(timesteps.begin, timesteps.end, cmp_size_t_pair); 1207 } 1208 } 1209 get_training_timesteps_beam(search_private & priv,Beam::beam<pair<action_prefix *,string>> & final_beam,v_array<pair<size_t,size_t>> & timesteps)1210 void get_training_timesteps_beam(search_private& priv, Beam::beam< pair<action_prefix*, string> >& final_beam, v_array< pair<size_t,size_t> >& timesteps) { 1211 timesteps.erase(); 1212 if (priv.subsample_timesteps <= 0) { 1213 size_t id = 0; 1214 for (Beam::beam_element< pair<action_prefix*, string> >* item = final_beam.begin(); item != final_beam.end(); ++item) { 1215 //cerr << id << " " << item->active << " " << item->data->first->size() << endl; 1216 if (item->active) 1217 for (size_t t=0; t<item->data->first->size(); t++) 1218 timesteps.push_back( pair<size_t,size_t>( id, t ) ); 1219 id ++; 1220 } 1221 } else { 1222 std::cerr << "error: cannot do subsampling of timesteps with beam search yet!" << endl; 1223 throw exception(); 1224 } 1225 } 1226 free_action_prefix(action_prefix * px)1227 void free_action_prefix(action_prefix* px) { 1228 px->delete_v(); 1229 delete px; 1230 } 1231 free_action_prefix_string_pair(pair<action_prefix *,string> * p)1232 void free_action_prefix_string_pair(pair<action_prefix*,string>* p) { 1233 p->first->delete_v(); 1234 delete p->first; 1235 delete p; 1236 } 1237 final_beam_insert(search_private & priv,Beam::beam<pair<action_prefix *,string>> & beam,float cost,SearchState state)1238 void final_beam_insert(search_private&priv, Beam::beam< pair<action_prefix*,string> >& beam, float cost, SearchState state) { 1239 action_prefix* final = new action_prefix; // TODO: can we memcpy/push_many? 1240 *final = v_init<action>(); 1241 //cerr << "final_beam_insert: cost=" << cost << ", len=" << ((state == INIT_TEST) ? 
priv.test_action_sequence.size() : priv.train_trajectory.size()) << endl; 1242 if (state == INIT_TEST) 1243 for (size_t i=0; i<priv.test_action_sequence.size(); i++) final->push_back(priv.test_action_sequence[i]); 1244 else if (state == INIT_TRAIN) 1245 for (size_t i=0; i<priv.train_trajectory.size(); i++) final->push_back(priv.train_trajectory[i]); 1246 //cerr << " --> ["; for (size_t i=0; i<final->size(); i++) cerr << " " << final->get(i); cerr << " ]" << endl; 1247 pair<action_prefix*,string>* p = priv.should_produce_string ? new pair<action_prefix*,string>(final, priv.pred_string->str()) : new pair<action_prefix*,string>(final, ""); 1248 uint32_t final_hash = uniform_hash(final->begin, sizeof(action)*final->size(), 3419); 1249 if (!beam.insert(p, cost, final_hash)) { 1250 final->delete_v(); 1251 delete final; 1252 delete p; 1253 } 1254 } 1255 beam_predict(search & sch,SearchState state)1256 Beam::beam< pair<action_prefix*, string> >* beam_predict(search& sch, SearchState state) { 1257 search_private& priv = *sch.priv; 1258 vw&all = *priv.all; 1259 bool old_no_caching = priv.no_caching; // caching is incompatible with generating beam rollouts 1260 priv.no_caching = true; 1261 1262 priv.beam->erase(free_action_prefix); 1263 clear_cache_hash_map(priv); 1264 1265 reset_search_structure(priv); 1266 priv.beam_actions.erase(); 1267 priv.state = state; 1268 priv.should_produce_string = (state == INIT_TEST) && (might_print_update(all) || (all.final_prediction_sink.size() > 0) || (all.raw_prediction > 0)); 1269 1270 // do the initial prediction 1271 priv.beam_initial_cost = 0.; 1272 priv.beam_total_cost = 0.; 1273 if (state == INIT_TEST) priv.test_action_sequence.clear(); 1274 else if (state == INIT_TRAIN) priv.train_trajectory.erase(); 1275 if (priv.should_produce_string) priv.pred_string->str(""); 1276 priv.task->run(sch, priv.ec_seq); 1277 if (all.raw_prediction > 0) all.print_text(all.raw_prediction, "end of initial beam prediction", priv.ec_seq[0]->tag); 1278 1279 
size_t final_size = (state == INIT_TEST) ? max(1, priv.kbest) : max(1, priv.beam->get_beam_size()); // at training time, use beam size 1280 //cerr << "final_size = " << final_size << endl; 1281 1282 Beam::beam< pair<action_prefix*, string> >* final_beam = new Beam::beam< pair<action_prefix*, string> >(final_size); 1283 final_beam_insert(priv, *final_beam, priv.beam_total_cost, state); 1284 1285 for (size_t beam_run=1; beam_run<priv.beam->get_beam_size(); beam_run++) { 1286 //cerr << "beam_run=" << beam_run << endl; 1287 priv.beam->compact(free_action_prefix); 1288 Beam::beam_element<action_prefix>* item = priv.beam->pop_best_item(); 1289 if (item != NULL) { 1290 reset_search_structure(priv); 1291 priv.beam_actions.erase(); 1292 priv.state = state; 1293 priv.should_produce_string = (state == INIT_TEST) && (might_print_update(all) || (all.final_prediction_sink.size() > 0) || (all.raw_prediction > 0)); 1294 if (priv.should_produce_string) priv.pred_string->str(""); 1295 priv.beam_initial_cost = *((float*)(item->data->begin+item->data->size()-1)); 1296 priv.beam_total_cost = priv.beam_initial_cost; 1297 push_many(priv.beam_actions, item->data->begin, item->data->size() - 1); 1298 if (state == INIT_TEST) priv.test_action_sequence.clear(); 1299 else if (state == INIT_TRAIN) priv.train_trajectory.erase(); 1300 priv.task->run(sch, priv.ec_seq); 1301 if (all.raw_prediction > 0) all.print_text(all.raw_prediction, "end of next beam prediction", priv.ec_seq[0]->tag); 1302 final_beam_insert(priv, *final_beam, priv.beam_total_cost, state); 1303 } 1304 } 1305 1306 final_beam->compact(free_action_prefix_string_pair); 1307 Beam::beam_element< pair<action_prefix*,string> >* best = final_beam->begin(); 1308 while ((best != final_beam->end()) && !best->active) ++best; 1309 if (best != final_beam->end()) { 1310 // store in beam_actions the actions for this so that subsequent calls to ->_run() produce it! 
1311 priv.beam_actions.erase(); 1312 push_many(priv.beam_actions, best->data->first->begin, best->data->first->size()); 1313 } 1314 1315 // TODO: only if test? 1316 if (all.final_prediction_sink.begin != all.final_prediction_sink.end) { // need to produce prediction output 1317 v_array<char> new_tag = v_init<char>(); 1318 for (; best != final_beam->end(); ++best) 1319 if (best->active) { 1320 new_tag.erase(); 1321 new_tag.resize(50, true); 1322 int len = sprintf(new_tag.begin, "%-10.6f\t", best->cost); 1323 new_tag.end = new_tag.begin + len; 1324 push_many(new_tag, priv.ec_seq[0]->tag.begin, priv.ec_seq[0]->tag.size()); 1325 for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; ++sink) 1326 all.print_text((int)*sink, best->data->second, new_tag); 1327 } 1328 for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; ++sink) 1329 all.print_text((int)*sink, "", priv.ec_seq[0]->tag); 1330 new_tag.delete_v(); 1331 } 1332 1333 priv.no_caching = old_no_caching; 1334 return final_beam; 1335 } 1336 1337 template <bool is_learn> train_single_example(search & sch,bool is_test_ex)1338 void train_single_example(search& sch, bool is_test_ex) { 1339 search_private& priv = *sch.priv; 1340 vw&all = *priv.all; 1341 bool ran_test = false; // we must keep track so that even if we skip test, we still update # of examples seen 1342 1343 clear_cache_hash_map(priv); 1344 Beam::beam< pair<action_prefix*, string> >* final_beam = NULL; 1345 1346 // do an initial test pass to compute output (and loss) 1347 if (must_run_test(all, priv.ec_seq, is_test_ex)) { 1348 cdbg << "======================================== INIT TEST (" << priv.current_policy << "," << priv.read_example_last_pass << ") ========================================" << endl; 1349 1350 ran_test = true; 1351 1352 if (priv.beam) { 1353 final_beam = beam_predict(sch, INIT_TEST); 1354 final_beam->erase(free_action_prefix_string_pair); 1355 delete final_beam; 1356 } 
    // SPEEDUP: in the case of beam, we could skip some of this...
    // do the prediction: run the task once in INIT_TEST state to obtain the
    // test-time action sequence (and a printable prediction string if needed)
    reset_search_structure(priv);
    priv.state = INIT_TEST;
    priv.should_produce_string = might_print_update(all) || (all.final_prediction_sink.size() > 0) || (all.raw_prediction > 0);
    priv.pred_string->str("");
    priv.test_action_sequence.clear();
    priv.task->run(sch, priv.ec_seq);

    // accumulate loss
    if (! is_test_ex)
      all.sd->update(priv.ec_seq[0]->test_only, priv.test_loss, 1.f, priv.num_features);

    // generate output (beam output is handled elsewhere by beam_predict)
    if (!priv.beam)
      for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; ++sink)
        all.print_text((int)*sink, priv.pred_string->str(), priv.ec_seq[0]->tag);

    if (all.raw_prediction > 0)
      all.print_text(all.raw_prediction, "", priv.ec_seq[0]->tag);
  }

  // if we're not training, then we're done!
  if ((!is_learn) || is_test_ex || priv.ec_seq[0]->test_only || (!priv.all->training))
    return;

  // SPEEDUP: if the oracle was never called, we can skip this!

  // do a pass over the data allowing oracle (INIT_TRAIN): this records the
  // rollin trajectory that the LEARN phase below will replay
  cdbg << "======================================== INIT TRAIN (" << priv.current_policy << "," << priv.read_example_last_pass << ") ========================================" << endl;
  //cerr << "training" << endl;
  final_beam = NULL;
  if (priv.beam)
    final_beam = beam_predict(sch, INIT_TRAIN);

  reset_search_structure(priv);
  priv.state = INIT_TRAIN;
  priv.train_trajectory.erase(); // this is where we'll store the training sequence
  priv.task->run(sch, priv.ec_seq);

  if (!ran_test) { // was && !priv.ec_seq[0]->test_only) { but we know it's not test_only
    // the INIT_TEST pass did not run, so account for this example's stats here
    all.sd->weighted_examples += 1.f;
    all.sd->total_features += priv.num_features;
    all.sd->sum_loss += priv.test_loss;
    all.sd->sum_loss_since_last_dump += priv.test_loss;
    all.sd->example_number++;
  }

  // if there's nothing to train on, we're done!
  if ((priv.loss_declared_cnt == 0) || (priv.t == 0) || (priv.rollout_method == NO_ROLLOUT)) { // TODO: make sure NO_ROLLOUT works with beam!
    if (priv.beam) {
      final_beam->erase(free_action_prefix_string_pair);
      delete final_beam;
    }
    return;
  }

  // otherwise, we have some learn'in to do!
  cdbg << "======================================== LEARN (" << priv.current_policy << "," << priv.read_example_last_pass << ") ========================================" << endl;
  priv.T = priv.t;  // freeze the trajectory length observed during INIT_TRAIN
  if (priv.beam) get_training_timesteps_beam(priv, *final_beam, priv.timesteps);
  else           get_training_timesteps(priv, priv.timesteps);
  priv.learn_losses.erase();
  //cdbg_print_array("timesteps", priv.timesteps);
  size_t last_beam_id = 0;
  for (size_t tid=0; tid<priv.timesteps.size(); tid++) {
    size_t bid = priv.timesteps[tid].first;
    //cerr << "timestep = " << priv.timesteps[tid].first << "." << priv.timesteps[tid].second << " [" << tid << "/" << priv.timesteps.size() << "]" << endl;
    // when the beam item changes, reload the rollin trajectory for that beam entry.
    // NOTE(review): last_beam_id is never assigned bid, so the trajectory is
    // re-loaded for every timestep with bid != 0 -- redundant (but idempotent)
    // work; confirm whether `last_beam_id = bid;` was intended here.
    if (bid != last_beam_id) {
      priv.train_trajectory.erase();
      push_many(priv.train_trajectory, final_beam->begin()[bid].data->first->begin, final_beam->begin()[bid].data->first->size());
    }

    priv.learn_a_idx = 0;
    priv.done_with_all_actions = false;
    // for each action, roll out to get a loss; the task sets done_with_all_actions
    // (via the LEARN-state machinery) once every candidate action has been tried
    while (! priv.done_with_all_actions) {
      reset_search_structure(priv);
      priv.beam_actions.erase();
      priv.state = LEARN;
      priv.learn_t = priv.timesteps[tid].second;
      cdbg << "learn_t = " << priv.learn_t << ", learn_a_idx = " << priv.learn_a_idx << endl;
      priv.task->run(sch, priv.ec_seq);
      priv.learn_losses.push_back( priv.learn_loss ); // SPEEDUP: should we just put this in a CS structure from the get-go?
      cdbg_print_array("learn_losses", priv.learn_losses);
    }
    // now we can make a training example out of the per-action losses
    generate_training_example(priv, priv.learn_losses);
    if (! priv.examples_dont_change)
      for (size_t n=0; n<priv.learn_ec_copy.size(); n++) {
        // free the labels we attached to our private example copies
        if (sch.priv->is_ldf) CS::cs_label.delete_label(&priv.learn_ec_copy[n].l.cs);
        else                  MC::mc_label.delete_label(&priv.learn_ec_copy[n].l.multi);
      }
    priv.learn_losses.erase();
  }
  if (priv.beam) {
    final_beam->erase(free_action_prefix_string_pair);
    delete final_beam;
  }
}


// Run one full train/predict cycle over the buffered example sequence:
// optionally produce the truth string for progress output, add neighbor
// features, train/predict on the sequence, then strip the neighbor features.
template <bool is_learn>
void do_actual_learning(vw&all, search& sch) {
  search_private& priv = *sch.priv;

  if (priv.ec_seq.size() == 0)
    return;  // nothing to do :)

  // an example sequence counts as "test" if ANY example in it has a test label
  bool is_test_ex = false;
  for (size_t i=0; i<priv.ec_seq.size(); i++)
    if (priv.label_is_test(&priv.ec_seq[i]->l)) { is_test_ex = true; break; }

  if (priv.task->run_setup) priv.task->run_setup(sch, priv.ec_seq);

  // if we're going to have to print to the screen, generate the "truth" string
  cdbg << "======================================== GET TRUTH STRING (" << priv.current_policy << "," << priv.read_example_last_pass << ") ========================================" << endl;
  if (might_print_update(all)) {
    if (is_test_ex)
      priv.truth_string->str("**test**");
    else {
      // run the task once in GET_TRUTH_STRING state; predictions follow the oracle
      reset_search_structure(*sch.priv);
      priv.beam_actions.erase();
      priv.state = GET_TRUTH_STRING;
      priv.should_produce_string = true;
      priv.truth_string->str("");
      priv.task->run(sch, priv.ec_seq);
    }
  }

  add_neighbor_features(priv);
  train_single_example<is_learn>(sch, is_test_ex);
  del_neighbor_features(priv);

  if (priv.task->run_takedown) priv.task->run_takedown(sch, priv.ec_seq);
}

// Reduction entry point: buffer examples until a newline example arrives (or
// the example ring is nearly full), then learn/predict on the whole sequence.
template <bool is_learn>
void search_predict_or_learn(search& sch, base_learner& base, example& ec) {
  search_private& priv = *sch.priv;
  vw* all = priv.all;
  priv.base_learner = &base;
  bool is_real_example = true;

  if (example_is_newline(ec) || priv.ec_seq.size() >= all->p->ring_size - 2) {
    if (priv.ec_seq.size() >= all->p->ring_size - 2) // -2 to give some wiggle room
      std::cerr << "warning: length of sequence at " << ec.example_counter << " exceeds ring size; breaking apart" << std::endl;

    do_actual_learning<is_learn>(*all, sch);

    priv.hit_new_pass = false;
    priv.last_example_was_newline = true;
    is_real_example = false;  // newline examples don't count toward example ids
  } else {
    if (priv.last_example_was_newline)
      priv.ec_seq.clear();
    priv.ec_seq.push_back(&ec);
    priv.last_example_was_newline = false;
  }

  if (is_real_example)
    priv.read_example_last_id = ec.example_counter;
}

// Called at the end of each pass over the data; advances to a new policy every
// passes_per_policy passes and records the trained-policy count for the model file.
void end_pass(search& sch) {
  search_private& priv = *sch.priv;
  vw* all = priv.all;
  priv.hit_new_pass = true;
  priv.read_example_last_pass++;
  priv.passes_since_new_policy++;

  if (priv.passes_since_new_policy >= priv.passes_per_policy) {
    priv.passes_since_new_policy = 0;
    if(all->training)
      priv.current_policy++;
    if (priv.current_policy > priv.total_number_of_policies) {
      std::cerr << "internal error (bug): too many policies; not advancing" << std::endl;
      priv.current_policy = priv.total_number_of_policies;
    }
    //reset search_trained_nb_policies in options_from_file so it is saved to regressor file later
    std::stringstream ss;
    ss << priv.current_policy;
    VW::cmd_string_replace_value(all->file_options,"--search_trained_nb_policies", ss.str());
  }
}

// Print progress and release examples once a complete sequence has been consumed.
// NOTE(review): non-newline examples are not finished here; presumably
// clear_seq() returns the buffered ec_seq examples to the pool -- confirm.
void finish_example(vw& all, search& sch, example& ec) {
  if (ec.end_pass || example_is_newline(ec) || sch.priv->ec_seq.size() >= all.p->ring_size - 2) {
    print_update(*sch.priv);
    VW::finish_example(all, &ec);
    clear_seq(all, *sch.priv);
  }
}
void end_examples(search& sch) { 1553 search_private& priv = *sch.priv; 1554 vw* all = priv.all; 1555 1556 do_actual_learning<true>(*all, sch); 1557 1558 if( all->training ) { 1559 std::stringstream ss1; 1560 std::stringstream ss2; 1561 ss1 << ((priv.passes_since_new_policy == 0) ? priv.current_policy : (priv.current_policy+1)); 1562 //use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --search_trained_nb_policies 1563 VW::cmd_string_replace_value(all->file_options,"--search_trained_nb_policies", ss1.str()); 1564 ss2 << priv.total_number_of_policies; 1565 //use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --search_total_nb_policies 1566 VW::cmd_string_replace_value(all->file_options,"--search_total_nb_policies", ss2.str()); 1567 } 1568 } 1569 mc_label_is_test(void * lab)1570 bool mc_label_is_test(void* lab) { 1571 if (MC::label_is_test((MC::label_t*)lab) > 0) 1572 return true; 1573 else 1574 return false; 1575 } 1576 search_initialize(vw * all,search & sch)1577 void search_initialize(vw* all, search& sch) { 1578 search_private& priv = *sch.priv; 1579 priv.all = all; 1580 1581 priv.auto_condition_features = false; 1582 priv.auto_hamming_loss = false; 1583 priv.examples_dont_change = false; 1584 priv.is_ldf = false; 1585 1586 priv.label_is_test = mc_label_is_test; 1587 1588 priv.A = 1; 1589 priv.num_learners = 1; 1590 priv.cb_learner = false; 1591 priv.state = INITIALIZE; 1592 priv.learn_learner_id = 0; 1593 priv.mix_per_roll_policy = -2; 1594 1595 priv.t = 0; 1596 priv.T = 0; 1597 priv.learn_ec_ref = NULL; 1598 priv.learn_ec_ref_cnt = 0; 1599 //priv.allowed_actions_cache = NULL; 1600 1601 priv.loss_declared_cnt = 0; 1602 priv.learn_t = 0; 1603 priv.learn_a_idx = 0; 1604 priv.done_with_all_actions = false; 1605 1606 priv.test_loss = 0.; 1607 priv.learn_loss = 0.; 1608 priv.train_loss = 0.; 1609 1610 priv.last_example_was_newline = false; 1611 priv.hit_new_pass = false; 1612 
1613 priv.printed_output_header = false; 1614 1615 priv.should_produce_string = false; 1616 priv.pred_string = new stringstream(); 1617 priv.truth_string = new stringstream(); 1618 priv.bad_string_stream = new stringstream(); 1619 priv.bad_string_stream->clear(priv.bad_string_stream->badbit); 1620 1621 priv.beta = 0.5; 1622 priv.alpha = 1e-10f; 1623 1624 priv.rollout_method = MIX_PER_ROLL; 1625 priv.rollin_method = MIX_PER_ROLL; 1626 priv.subsample_timesteps = 0.; 1627 1628 priv.allow_current_policy = true; 1629 priv.adaptive_beta = true; 1630 priv.passes_per_policy = 1; //this should be set to the same value as --passes for dagger 1631 1632 priv.current_policy = 0; 1633 1634 priv.num_features = 0; 1635 priv.total_number_of_policies = 1; 1636 priv.read_example_last_id = 0; 1637 priv.passes_per_policy = 0; 1638 priv.read_example_last_pass = 0; 1639 priv.total_examples_generated = 0; 1640 priv.total_predictions_made = 0; 1641 priv.total_cache_hits = 0; 1642 1643 priv.history_length = 1; 1644 priv.acset.max_bias_ngram_length = 1; 1645 priv.acset.max_quad_ngram_length = 0; 1646 priv.acset.feature_value = 1.; 1647 1648 priv.cache_hash_map.set_default_value((action)-1); 1649 priv.cache_hash_map.set_equivalent(cached_item_equivalent); 1650 1651 priv.task = NULL; 1652 sch.task_data = NULL; 1653 1654 priv.empty_example = alloc_examples(sizeof(CS::label), 1); 1655 CS::cs_label.default_label(&priv.empty_example->l.cs); 1656 priv.empty_example->in_use = true; 1657 1658 priv.rawOutputStringStream = new stringstream(priv.rawOutputString); 1659 } 1660 search_finish(search & sch)1661 void search_finish(search& sch) { 1662 search_private& priv = *sch.priv; 1663 cdbg << "search_finish" << endl; 1664 1665 clear_cache_hash_map(priv); 1666 1667 delete priv.truth_string; 1668 delete priv.pred_string; 1669 delete priv.bad_string_stream; 1670 priv.neighbor_features.delete_v(); 1671 priv.timesteps.delete_v(); 1672 priv.learn_losses.delete_v(); 1673 priv.condition_on_actions.delete_v(); 
1674 priv.learn_allowed_actions.delete_v(); 1675 priv.ldf_test_label.costs.delete_v(); 1676 1677 if (priv.beam) { 1678 priv.beam->erase(free_action_prefix); 1679 delete priv.beam; 1680 } 1681 priv.beam_actions.delete_v(); 1682 1683 if (priv.cb_learner) 1684 priv.allowed_actions_cache->cb.costs.delete_v(); 1685 else 1686 priv.allowed_actions_cache->cs.costs.delete_v(); 1687 1688 priv.train_trajectory.delete_v(); 1689 priv.current_trajectory.delete_v(); 1690 priv.ptag_to_action.delete_v(); 1691 1692 dealloc_example(CS::cs_label.delete_label, *(priv.empty_example)); 1693 free(priv.empty_example); 1694 1695 priv.ec_seq.clear(); 1696 1697 // destroy copied examples if we needed them 1698 if (! priv.examples_dont_change) { 1699 void (*delete_label)(void*) = priv.is_ldf ? CS::cs_label.delete_label : MC::mc_label.delete_label; 1700 for(example*ec = priv.learn_ec_copy.begin; ec!=priv.learn_ec_copy.end; ++ec) 1701 dealloc_example(delete_label, *ec); 1702 priv.learn_ec_copy.delete_v(); 1703 } 1704 priv.learn_condition_on_names.delete_v(); 1705 priv.learn_condition_on.delete_v(); 1706 priv.learn_condition_on_act.delete_v(); 1707 1708 if (priv.task->finish != NULL) { 1709 priv.task->finish(sch); 1710 } 1711 1712 free(priv.allowed_actions_cache); 1713 delete priv.rawOutputStringStream; 1714 delete sch.priv; 1715 } 1716 ensure_param(float & v,float lo,float hi,float def,const char * string)1717 void ensure_param(float &v, float lo, float hi, float def, const char* string) { 1718 if ((v < lo) || (v > hi)) { 1719 std::cerr << string << endl; 1720 v = def; 1721 } 1722 } 1723 string_equal(string a,string b)1724 bool string_equal(string a, string b) { return a.compare(b) == 0; } float_equal(float a,float b)1725 bool float_equal(float a, float b) { return fabs(a-b) < 1e-6; } uint32_equal(uint32_t a,uint32_t b)1726 bool uint32_equal(uint32_t a, uint32_t b) { return a==b; } size_equal(size_t a,size_t b)1727 bool size_equal(size_t a, size_t b) { return a==b; } 1728 check_option(T & ret,vw 
& all,po::variables_map & vm,const char * opt_name,bool default_to_cmdline,bool (* equal)(T,T),const char * mismatch_error_string,const char * required_error_string)1729 template<class T> void check_option(T& ret, vw&all, po::variables_map& vm, const char* opt_name, bool default_to_cmdline, bool(*equal)(T,T), const char* mismatch_error_string, const char* required_error_string) { 1730 if (vm.count(opt_name)) { 1731 ret = vm[opt_name].as<T>(); 1732 *all.file_options << " --" << opt_name << " " << ret; 1733 } else if (strlen(required_error_string)>0) { 1734 std::cerr << required_error_string << endl; 1735 if (! vm.count("help")) 1736 throw exception(); 1737 } 1738 } 1739 check_option(bool & ret,vw & all,po::variables_map & vm,const char * opt_name,bool default_to_cmdline,const char * mismatch_error_string)1740 void check_option(bool& ret, vw&all, po::variables_map& vm, const char* opt_name, bool default_to_cmdline, const char* mismatch_error_string) { 1741 if (vm.count(opt_name)) { 1742 ret = true; 1743 *all.file_options << " --" << opt_name; 1744 } else 1745 ret = false; 1746 } 1747 handle_condition_options(vw & vw,auto_condition_settings & acset)1748 void handle_condition_options(vw& vw, auto_condition_settings& acset) { 1749 new_options(vw, "Search Auto-conditioning Options") 1750 ("search_max_bias_ngram_length", po::value<size_t>(), "add a \"bias\" feature for each ngram up to and including this length. eg., if it's 1 (default), then you get a single feature for each conditional") 1751 ("search_max_quad_ngram_length", po::value<size_t>(), "add bias *times* input features for each ngram up to and including this length (def: 0)") 1752 ("search_condition_feature_value", po::value<float> (), "how much weight should the conditional features get? 
(def: 1.)"); 1753 add_options(vw); 1754 1755 po::variables_map& vm = vw.vm; 1756 1757 check_option<size_t>(acset.max_bias_ngram_length, vw, vm, "search_max_bias_ngram_length", false, size_equal, 1758 "warning: you specified a different value for --search_max_bias_ngram_length than the one loaded from regressor. proceeding with loaded value: ", ""); 1759 1760 check_option<size_t>(acset.max_quad_ngram_length, vw, vm, "search_max_quad_ngram_length", false, size_equal, 1761 "warning: you specified a different value for --search_max_quad_ngram_length than the one loaded from regressor. proceeding with loaded value: ", ""); 1762 1763 check_option<float> (acset.feature_value, vw, vm, "search_condition_feature_value", false, float_equal, 1764 "warning: you specified a different value for --search_condition_feature_value than the one loaded from regressor. proceeding with loaded value: ", ""); 1765 } 1766 read_allowed_transitions(action A,const char * filename)1767 v_array<CS::label> read_allowed_transitions(action A, const char* filename) { 1768 FILE *f = fopen(filename, "r"); 1769 if (f == NULL) { 1770 std::cerr << "error: could not read file " << filename << " (" << strerror(errno) << "); assuming all transitions are valid" << endl; 1771 throw exception(); 1772 } 1773 1774 bool* bg = (bool*)malloc((A+1)*(A+1) * sizeof(bool)); 1775 int rd,from,to,count=0; 1776 while ((rd = fscanf(f, "%d:%d", &from, &to)) > 0) { 1777 if ((from < 0) || (from > (int)A)) { std::cerr << "warning: ignoring transition from " << from << " because it's out of the range [0," << A << "]" << endl; } 1778 if ((to < 0) || (to > (int)A)) { std::cerr << "warning: ignoring transition to " << to << " because it's out of the range [0," << A << "]" << endl; } 1779 bg[from * (A+1) + to] = true; 1780 count++; 1781 } 1782 fclose(f); 1783 1784 v_array<CS::label> allowed = v_init<CS::label>(); 1785 1786 for (size_t from=0; from<A; from++) { 1787 v_array<CS::wclass> costs = v_init<CS::wclass>(); 1788 1789 for 
(size_t to=0; to<A; to++) 1790 if (bg[from * (A+1) + to]) { 1791 CS::wclass c = { FLT_MAX, (action)to, 0., 0. }; 1792 costs.push_back(c); 1793 } 1794 1795 CS::label ld = { costs }; 1796 allowed.push_back(ld); 1797 } 1798 free(bg); 1799 1800 std::cerr << "read " << count << " allowed transitions from " << filename << endl; 1801 1802 return allowed; 1803 } 1804 1805 parse_neighbor_features(string & nf_string,search & sch)1806 void parse_neighbor_features(string& nf_string, search&sch) { 1807 search_private& priv = *sch.priv; 1808 priv.neighbor_features.erase(); 1809 size_t len = nf_string.length(); 1810 if (len == 0) return; 1811 1812 char * cstr = new char [len+1]; 1813 strcpy(cstr, nf_string.c_str()); 1814 1815 char * p = strtok(cstr, ","); 1816 v_array<substring> cmd = v_init<substring>(); 1817 while (p != 0) { 1818 cmd.erase(); 1819 substring me = { p, p+strlen(p) }; 1820 tokenize(':', me, cmd, true); 1821 1822 int32_t posn = 0; 1823 char ns = ' '; 1824 if (cmd.size() == 1) { 1825 posn = int_of_substring(cmd[0]); 1826 ns = ' '; 1827 } else if (cmd.size() == 2) { 1828 posn = int_of_substring(cmd[0]); 1829 ns = (cmd[1].end > cmd[1].begin) ? cmd[1].begin[0] : ' '; 1830 } else { 1831 std::cerr << "warning: ignoring malformed neighbor specification: '" << p << "'" << endl; 1832 } 1833 int32_t enc = (posn << 24) | (ns & 0xFF); 1834 priv.neighbor_features.push_back(enc); 1835 1836 p = strtok(NULL, ","); 1837 } 1838 cmd.delete_v(); 1839 1840 delete[] cstr; 1841 } 1842 setup(vw & all)1843 base_learner* setup(vw&all) { 1844 if (missing_option<size_t, false>(all, "search", 1845 "Use learning to search, argument=maximum action id or 0 for LDF")) 1846 return NULL; 1847 new_options(all, "Search Options") 1848 ("search_task", po::value<string>(), "the search task (use \"--search_task list\" to get a list of available tasks)") 1849 ("search_interpolation", po::value<string>(), "at what level should interpolation happen? 
[*data|policy]") 1850 ("search_rollout", po::value<string>(), "how should rollouts be executed? [policy|oracle|*mix_per_state|mix_per_roll|none]") 1851 ("search_rollin", po::value<string>(), "how should past trajectories be generated? [policy|oracle|*mix_per_state|mix_per_roll]") 1852 1853 ("search_passes_per_policy", po::value<size_t>(), "number of passes per policy (only valid for search_interpolation=policy) [def=1]") 1854 ("search_beta", po::value<float>(), "interpolation rate for policies (only valid for search_interpolation=policy) [def=0.5]") 1855 1856 ("search_alpha", po::value<float>(), "annealed beta = 1-(1-alpha)^t (only valid for search_interpolation=data) [def=1e-10]") 1857 1858 ("search_total_nb_policies", po::value<size_t>(), "if we are going to train the policies through multiple separate calls to vw, we need to specify this parameter and tell vw how many policies are eventually going to be trained") 1859 1860 ("search_trained_nb_policies", po::value<size_t>(), "the number of trained policies in a file") 1861 1862 ("search_allowed_transitions",po::value<string>(),"read file of allowed transitions [def: all transitions are allowed]") 1863 ("search_subsample_time", po::value<float>(), "instead of training at all timesteps, use a subset. if value in (0,1), train on a random v%. if v>=1, train on precisely v steps per example") 1864 ("search_neighbor_features", po::value<string>(), "copy features from neighboring lines. 
argument looks like: '-1:a,+2' meaning copy previous line namespace a and next next line from namespace _unnamed_, where ',' separates them") 1865 ("search_rollout_num_steps", po::value<size_t>(), "how many calls of \"loss\" before we stop really predicting on rollouts and switch to oracle (def: 0 means \"infinite\")") 1866 ("search_history_length", po::value<size_t>(), "some tasks allow you to specify how much history their depend on; specify that here [def: 1]") 1867 1868 ("search_no_caching", "turn off the built-in caching ability (makes things slower, but technically more safe)") 1869 ("search_beam", po::value<size_t>(), "use beam search (arg = beam size, default 0 = no beam)") 1870 ("search_kbest", po::value<size_t>(), "size of k-best list to produce (must be <= beam size)") 1871 ; 1872 add_options(all); 1873 po::variables_map& vm = all.vm; 1874 1875 bool has_hook_task = false; 1876 for (size_t i=0; i<all.args.size()-1; i++) 1877 if (all.args[i] == "--search_task" && all.args[i+1] == "hook") 1878 has_hook_task = true; 1879 if (has_hook_task) 1880 for (int i = (int)all.args.size()-2; i >= 0; i--) 1881 if (all.args[i] == "--search_task" && all.args[i+1] != "hook") 1882 all.args.erase(all.args.begin() + i, all.args.begin() + i + 2); 1883 1884 search& sch = calloc_or_die<search>(); 1885 sch.priv = new search_private(); 1886 search_initialize(&all, sch); 1887 search_private& priv = *sch.priv; 1888 1889 std::string task_string; 1890 std::string interpolation_string = "data"; 1891 std::string rollout_string = "mix_per_state"; 1892 std::string rollin_string = "mix_per_state"; 1893 1894 check_option<string>(task_string, all, vm, "search_task", false, string_equal, 1895 "warning: specified --search_task different than the one loaded from regressor. 
using loaded value of: ", 1896 "error: you must specify a task using --search_task"); 1897 1898 check_option<string>(interpolation_string, all, vm, "search_interpolation", false, string_equal, 1899 "warning: specified --search_interpolation different than the one loaded from regressor. using loaded value of: ", ""); 1900 1901 if (vm.count("search_passes_per_policy")) priv.passes_per_policy = vm["search_passes_per_policy"].as<size_t>(); 1902 1903 if (vm.count("search_alpha")) priv.alpha = vm["search_alpha" ].as<float>(); 1904 if (vm.count("search_beta")) priv.beta = vm["search_beta" ].as<float>(); 1905 1906 if (vm.count("search_subsample_time")) priv.subsample_timesteps = vm["search_subsample_time"].as<float>(); 1907 if (vm.count("search_no_caching")) priv.no_caching = true; 1908 if (vm.count("search_rollout_num_steps")) priv.rollout_num_steps = vm["search_rollout_num_steps"].as<size_t>(); 1909 1910 if (vm.count("search_beam")) 1911 priv.beam = new Beam::beam<action_prefix>(vm["search_beam"].as<size_t>()); // TODO: pruning, kbest, equivalence testing 1912 else 1913 priv.beam = NULL; 1914 1915 priv.kbest = 1; 1916 if (vm.count("search_kbest")) { 1917 priv.kbest = max(1, vm["search_kbest"].as<size_t>()); 1918 if (priv.kbest > priv.beam->get_beam_size()) { 1919 std::cerr << "warning: kbest set greater than beam size; shrinking back to " << priv.beam->get_beam_size() << endl; 1920 priv.kbest = priv.beam->get_beam_size(); 1921 } 1922 } 1923 1924 priv.A = vm["search"].as<size_t>(); 1925 1926 string neighbor_features_string; 1927 check_option<string>(neighbor_features_string, all, vm, "search_neighbor_features", false, string_equal, 1928 "warning: you specified a different feature structure with --search_neighbor_features than the one loaded from predictor. 
using loaded value of: ", ""); 1929 parse_neighbor_features(neighbor_features_string, sch); 1930 1931 if (interpolation_string.compare("data") == 0) { // run as dagger 1932 priv.adaptive_beta = true; 1933 priv.allow_current_policy = true; 1934 priv.passes_per_policy = all.numpasses; 1935 if (priv.current_policy > 1) priv.current_policy = 1; 1936 } else if (interpolation_string.compare("policy") == 0) { 1937 } else { 1938 std::cerr << "error: --search_interpolation must be 'data' or 'policy'" << endl; 1939 throw exception(); 1940 } 1941 1942 if (vm.count("search_rollout")) rollout_string = vm["search_rollout"].as<string>(); 1943 if (vm.count("search_rollin" )) rollin_string = vm["search_rollin" ].as<string>(); 1944 1945 if (rollout_string.compare("policy") == 0) priv.rollout_method = POLICY; 1946 else if (rollout_string.compare("oracle") == 0) priv.rollout_method = ORACLE; 1947 else if (rollout_string.compare("mix_per_state") == 0) priv.rollout_method = MIX_PER_STATE; 1948 else if (rollout_string.compare("mix_per_roll") == 0) priv.rollout_method = MIX_PER_ROLL; 1949 else if (rollout_string.compare("none") == 0) { priv.rollout_method = NO_ROLLOUT; priv.no_caching = true; std::cerr << "no rollout!" 
<< endl; } 1950 else { 1951 std::cerr << "error: --search_rollout must be 'policy', 'oracle', 'mix_per_state', 'mix_per_roll' or 'none'" << endl; 1952 throw exception(); 1953 } 1954 1955 if (rollin_string.compare("policy") == 0) priv.rollin_method = POLICY; 1956 else if (rollin_string.compare("oracle") == 0) priv.rollin_method = ORACLE; 1957 else if (rollin_string.compare("mix_per_state") == 0) priv.rollin_method = MIX_PER_STATE; 1958 else if (rollin_string.compare("mix_per_roll") == 0) priv.rollin_method = MIX_PER_ROLL; 1959 else { 1960 std::cerr << "error: --search_rollin must be 'policy', 'oracle', 'mix_per_state' or 'mix_per_roll'" << endl; 1961 throw exception(); 1962 } 1963 1964 check_option<size_t>(priv.A, all, vm, "search", false, size_equal, 1965 "warning: you specified a different number of actions through --search than the one loaded from predictor. using loaded value of: ", ""); 1966 1967 check_option<size_t>(priv.history_length, all, vm, "search_history_length", false, size_equal, 1968 "warning: you specified a different history length through --search_history_length than the one loaded from predictor. using loaded value of: ", ""); 1969 1970 //check if the base learner is contextual bandit, in which case, we dont rollout all actions. 
1971 priv.allowed_actions_cache = &calloc_or_die<polylabel>(); 1972 if (vm.count("cb")) { 1973 priv.cb_learner = true; 1974 CB::cb_label.default_label(priv.allowed_actions_cache); 1975 } else { 1976 priv.cb_learner = false; 1977 CS::cs_label.default_label(priv.allowed_actions_cache); 1978 } 1979 1980 //if we loaded a regressor with -i option, --search_trained_nb_policies contains the number of trained policies in the file 1981 // and --search_total_nb_policies contains the total number of policies in the file 1982 if (vm.count("search_total_nb_policies")) 1983 priv.total_number_of_policies = (uint32_t)vm["search_total_nb_policies"].as<size_t>(); 1984 1985 ensure_param(priv.beta , 0.0, 1.0, 0.5, "warning: search_beta must be in (0,1); resetting to 0.5"); 1986 ensure_param(priv.alpha, 0.0, 1.0, 1e-10f, "warning: search_alpha must be in (0,1); resetting to 1e-10"); 1987 1988 //compute total number of policies we will have at end of training 1989 // we add current_policy for cases where we start from an initial set of policies loaded through -i option 1990 uint32_t tmp_number_of_policies = priv.current_policy; 1991 if( all.training ) 1992 tmp_number_of_policies += (int)ceil(((float)all.numpasses) / ((float)priv.passes_per_policy)); 1993 1994 //the user might have specified the number of policies that will eventually be trained through multiple vw calls, 1995 //so only set total_number_of_policies to computed value if it is larger 1996 cdbg << "current_policy=" << priv.current_policy << " tmp_number_of_policies=" << tmp_number_of_policies << " total_number_of_policies=" << priv.total_number_of_policies << endl; 1997 if( tmp_number_of_policies > priv.total_number_of_policies ) { 1998 priv.total_number_of_policies = tmp_number_of_policies; 1999 if( priv.current_policy > 0 ) //we loaded a file but total number of policies didn't match what is needed for training 2000 std::cerr << "warning: you're attempting to train more classifiers than was allocated initially. 
Likely to cause bad performance." << endl; 2001 } 2002 2003 //current policy currently points to a new policy we would train 2004 //if we are not training and loaded a bunch of policies for testing, we need to subtract 1 from current policy 2005 //so that we only use those loaded when testing (as run_prediction is called with allow_current to true) 2006 if( !all.training && priv.current_policy > 0 ) 2007 priv.current_policy--; 2008 2009 std::stringstream ss1, ss2; 2010 ss1 << priv.current_policy; VW::cmd_string_replace_value(all.file_options,"--search_trained_nb_policies", ss1.str()); 2011 ss2 << priv.total_number_of_policies; VW::cmd_string_replace_value(all.file_options,"--search_total_nb_policies", ss2.str()); 2012 2013 cdbg << "search current_policy = " << priv.current_policy << " total_number_of_policies = " << priv.total_number_of_policies << endl; 2014 2015 if (task_string.compare("list") == 0) { 2016 std::cerr << endl << "available search tasks:" << endl; 2017 for (search_task** mytask = all_tasks; *mytask != NULL; mytask++) 2018 std::cerr << " " << (*mytask)->task_name << endl; 2019 std::cerr << endl; 2020 exit(0); 2021 } 2022 for (search_task** mytask = all_tasks; *mytask != NULL; mytask++) 2023 if (task_string.compare((*mytask)->task_name) == 0) { 2024 priv.task = *mytask; 2025 sch.task_name = (*mytask)->task_name; 2026 break; 2027 } 2028 if (priv.task == NULL) { 2029 if (! 
vm.count("help")) { 2030 std::cerr << "fail: unknown task for --search_task '" << task_string << "'; use --search_task list to get a list" << endl; 2031 throw exception(); 2032 } 2033 } 2034 all.p->emptylines_separate_examples = true; 2035 2036 if (count(all.args.begin(), all.args.end(),"--csoaa") == 0 2037 && count(all.args.begin(), all.args.end(),"--csoaa_ldf") == 0 2038 && count(all.args.begin(), all.args.end(),"--wap_ldf") == 0 2039 && count(all.args.begin(), all.args.end(),"--cb") == 0) 2040 { 2041 all.args.push_back("--csoaa"); 2042 stringstream ss; 2043 ss << vm["search"].as<size_t>(); 2044 all.args.push_back(ss.str()); 2045 } 2046 base_learner* base = setup_base(all); 2047 2048 // default to OAA labels unless the task wants to override this (which they can do in initialize) 2049 all.p->lp = MC::mc_label; 2050 if (priv.task) 2051 priv.task->initialize(sch, priv.A, vm); 2052 2053 if (vm.count("search_allowed_transitions")) read_allowed_transitions((action)priv.A, vm["search_allowed_transitions"].as<string>().c_str()); 2054 2055 // set up auto-history if they want it 2056 if (priv.auto_condition_features) { 2057 handle_condition_options(all, priv.acset); 2058 2059 // turn off auto-condition if it's irrelevant 2060 if (((priv.acset.max_bias_ngram_length == 0) && (priv.acset.max_quad_ngram_length == 0)) || 2061 (priv.acset.feature_value == 0.f)) { 2062 std::cerr << "warning: turning off AUTO_CONDITION_FEATURES because settings make it useless" << endl; 2063 priv.auto_condition_features = false; 2064 } 2065 } 2066 2067 if (!priv.allow_current_policy) // if we're not dagger 2068 all.check_holdout_every_n_passes = priv.passes_per_policy; 2069 2070 all.searchstr = &sch; 2071 2072 priv.start_clock_time = clock(); 2073 2074 learner<search>& l = init_learner(&sch, base, 2075 search_predict_or_learn<true>, 2076 search_predict_or_learn<false>, 2077 priv.total_number_of_policies); 2078 l.set_finish_example(finish_example); 2079 l.set_end_examples(end_examples); 2080 
l.set_finish(search_finish); 2081 l.set_end_pass(end_pass); 2082 2083 return make_base(l); 2084 } 2085 action_hamming_loss(action a,const action * A,size_t sz)2086 float action_hamming_loss(action a, const action* A, size_t sz) { 2087 if (sz == 0) return 0.; // latent variables have zero loss 2088 for (size_t i=0; i<sz; i++) 2089 if (a == A[i]) return 0.; 2090 return 1.; 2091 } 2092 2093 // the interface: is_ldf()2094 bool search::is_ldf() { return this->priv->is_ldf; } 2095 predict(example & ec,ptag mytag,const action * oracle_actions,size_t oracle_actions_cnt,const ptag * condition_on,const char * condition_on_names,const action * allowed_actions,size_t allowed_actions_cnt,size_t learner_id)2096 action search::predict(example& ec, ptag mytag, const action* oracle_actions, size_t oracle_actions_cnt, const ptag* condition_on, const char* condition_on_names, const action* allowed_actions, size_t allowed_actions_cnt, size_t learner_id) { 2097 action a = search_predict(*this->priv, &ec, 1, mytag, oracle_actions, oracle_actions_cnt, condition_on, condition_on_names, allowed_actions, allowed_actions_cnt, learner_id); 2098 if (priv->beam) priv->current_trajectory.push_back(a); 2099 if (priv->state == INIT_TEST) priv->test_action_sequence.push_back(a); 2100 if (mytag != 0) push_at(priv->ptag_to_action, a, mytag); 2101 if (this->priv->auto_hamming_loss) 2102 loss(action_hamming_loss(a, oracle_actions, oracle_actions_cnt)); 2103 cdbg << "predict returning " << a << endl; 2104 return a; 2105 } 2106 predictLDF(example * ecs,size_t ec_cnt,ptag mytag,const action * oracle_actions,size_t oracle_actions_cnt,const ptag * condition_on,const char * condition_on_names,size_t learner_id)2107 action search::predictLDF(example* ecs, size_t ec_cnt, ptag mytag, const action* oracle_actions, size_t oracle_actions_cnt, const ptag* condition_on, const char* condition_on_names, size_t learner_id) { 2108 action a = search_predict(*this->priv, ecs, ec_cnt, mytag, oracle_actions, 
                           oracle_actions_cnt, condition_on, condition_on_names, NULL, 0, learner_id);
  if (priv->beam) priv->current_trajectory.push_back(a);
  if (priv->state == INIT_TEST) priv->test_action_sequence.push_back(a);
  // for LDF the returned action indexes into ecs; memoize the chosen example's
  // class index (not the raw index) under mytag
  if ((mytag != 0) && ecs[a].l.cs.costs.size() > 0)
    push_at(priv->ptag_to_action, ecs[a].l.cs.costs[0].class_index, mytag);
  if (this->priv->auto_hamming_loss)
    loss(action_hamming_loss(a, oracle_actions, oracle_actions_cnt));
  cdbg << "predict returning " << a << endl;
  return a;
}

// declare this roll's loss
void search::loss(float loss) { search_declare_loss(*this->priv, loss); }

// true iff the upcoming predict() will actually inspect its example(s)
bool search::predictNeedsExample() { return search_predictNeedsExample(*this->priv); }

// stream tasks should write human-readable output to; a silenced (badbit)
// stream is returned when no string output is needed this roll
stringstream& search::output() {
  if      (!this->priv->should_produce_string    ) return *(this->priv->bad_string_stream);
  else if ( this->priv->state == GET_TRUTH_STRING) return *(this->priv->truth_string);
  else                                             return *(this->priv->pred_string);
}

// tasks set AUTO_* behavior flags here; only legal from task initialize()
void search::set_options(uint32_t opts) {
  if (this->priv->state != INITIALIZE) {
    std::cerr << "error: task cannot set options except in initialize function!" << endl;
    throw exception();
  }
  if ((opts & AUTO_CONDITION_FEATURES) != 0) this->priv->auto_condition_features = true;
  if ((opts & AUTO_HAMMING_LOSS)       != 0) this->priv->auto_hamming_loss = true;
  if ((opts & EXAMPLES_DONT_CHANGE)    != 0) this->priv->examples_dont_change = true;
  if ((opts & IS_LDF)                  != 0) this->priv->is_ldf = true;
}

// override the default (multiclass) label parser and test-label predicate;
// only legal from task initialize()
void search::set_label_parser(label_parser&lp, bool (*is_test)(void*)) {
  if (this->priv->state != INITIALIZE) {
    std::cerr << "error: task cannot set label parser except in initialize function!" << endl;
    throw exception();
  }
  this->priv->all->p->lp = lp;
  this->priv->label_is_test = is_test;
}

// copy the most recent test-time action sequence into V
void search::get_test_action_sequence(vector<action>& V) {
  V.clear();
  for (size_t i=0; i<this->priv->test_action_sequence.size(); i++)
    V.push_back(this->priv->test_action_sequence[i]);
}


void search::set_num_learners(size_t num_learners) { this->priv->num_learners = num_learners; }

// allow tasks to register their own command-line options
void search::add_program_options(po::variables_map& vw, po::options_description& opts) { add_options( *this->priv->all, opts ); }

size_t   search::get_mask()           { return this->priv->all->reg.weight_mask;}
size_t   search::get_stride_shift()   { return this->priv->all->reg.stride_shift;}
uint32_t search::get_history_length() { return (uint32_t)this->priv->history_length; }

vw& search::get_vw_pointer_unsafe() { return *this->priv->all; }

// predictor implementation: a builder-style wrapper around search::predict
predictor::predictor(search& sch, ptag my_tag) : is_ldf(false), my_tag(my_tag), ec(NULL), ec_cnt(0), ec_alloced(false), oracle_is_pointer(false), allowed_is_pointer(false), learner_id(0), sch(sch) {
  oracle_actions     = v_init<action>();
  condition_on_tags  = v_init<ptag>();
  condition_on_names = v_init<char>();
  allowed_actions    = v_init<action>();
}

// free examples only if we allocated them ourselves (set_input_length path)
// NOTE(review): the non-LDF branch passes NULL as the delete_label function --
// presumably dealloc_example tolerates NULL and skips label deletion; confirm.
void predictor::free_ec() {
  if (ec_alloced) {
    if (is_ldf)
      for (size_t i=0; i<ec_cnt; i++)
        dealloc_example(CS::cs_label.delete_label, ec[i]);
    else
      dealloc_example(NULL, *ec);
    free(ec);
  }
}

predictor::~predictor() {
  // the *_is_pointer flags mean the v_array wraps caller-owned memory we must not free
  if (! oracle_is_pointer)  oracle_actions.delete_v();
  if (! allowed_is_pointer) allowed_actions.delete_v();
  free_ec();
  condition_on_tags.delete_v();
  condition_on_names.delete_v();
}

// use a single caller-owned example (non-LDF)
predictor& predictor::set_input(example&input_example) {
  free_ec();
  is_ldf = false;
  ec = &input_example;
  ec_cnt = 1;
  ec_alloced = false;
  return *this;
}

// use a caller-owned array of examples (LDF)
predictor& predictor::set_input(example*input_example, size_t input_length) {
  free_ec();
  is_ldf = true;
  ec = input_example;
  ec_cnt = input_length;
  ec_alloced = false;
  return *this;
}

// allocate our own array of input_length examples, to be filled via set_input_at
void predictor::set_input_length(size_t input_length) {
  is_ldf = true;
  if (ec_alloced) ec = (example*)realloc(ec, input_length * sizeof(example));
  else            ec = calloc_or_die<example>(input_length);
  ec_cnt = input_length;
  ec_alloced = true;
}

// deep-copy ex into slot posn of the array allocated by set_input_length
void predictor::set_input_at(size_t posn, example&ex) {
  if (!ec_alloced) { std::cerr << "call to set_input_at without previous call to set_input_length" << endl; throw exception(); }
  if (posn >= ec_cnt) { std::cerr << "call to set_input_at with too large a position" << endl; throw exception(); }
  VW::copy_example_data(false, ec+posn, &ex, CS::cs_label.label_size, CS::cs_label.copy_label); // TODO: the false is "audit"
}

// replace A's storage with a freshly-allocated copy of size new_size.
// NOTE(review): the old pointer is not freed here -- presumably it is
// caller-owned memory (see the *_is_pointer flags); confirm no leak when the
// array previously owned its storage.
void predictor::make_new_pointer(v_array<action>& A, size_t new_size) {
  size_t old_size = A.size();
  action* old_pointer = A.begin;
  A.begin = calloc_or_die<action>(new_size);
  A.end = A.begin + new_size;
  A.end_array = A.end;
  memcpy(A.begin, old_pointer, old_size * sizeof(action));
}

predictor& predictor::add_to(v_array<action>& A, bool& A_is_ptr, action a, bool
clear_first) { 2233 if (A_is_ptr) { // we need to make our own memory 2234 if (clear_first) 2235 A.end = A.begin; 2236 size_t new_size = clear_first ? 1 : (A.size() + 1); 2237 make_new_pointer(A, new_size); 2238 A_is_ptr = false; 2239 A[new_size-1] = a; 2240 } else { // we've already allocated our own memory 2241 if (clear_first) A.erase(); 2242 A.push_back(a); 2243 } 2244 return *this; 2245 } 2246 add_to(v_array<action> & A,bool & A_is_ptr,action * a,size_t action_count,bool clear_first)2247 predictor& predictor::add_to(v_array<action>&A, bool& A_is_ptr, action*a, size_t action_count, bool clear_first) { 2248 size_t old_size = A.size(); 2249 if (old_size > 0) { 2250 if (A_is_ptr) { // we need to make our own memory 2251 if (clear_first) { 2252 A.end = A.begin; 2253 old_size = 0; 2254 } 2255 size_t new_size = old_size + action_count; 2256 make_new_pointer(A, new_size); 2257 A_is_ptr = false; 2258 memcpy(A.begin + old_size, a, action_count * sizeof(action)); 2259 } else { // we already have our own memory 2260 if (clear_first) A.erase(); 2261 push_many<action>(A, a, action_count); 2262 } 2263 } else { // old_size == 0, clear_first is irrelevant 2264 if (! 
A_is_ptr) 2265 A.delete_v(); // avoid memory leak 2266 2267 A.begin = a; 2268 A.end = a + action_count; 2269 A.end_array = A.end; 2270 A_is_ptr = true; 2271 } 2272 return *this; 2273 } 2274 erase_oracles()2275 predictor& predictor::erase_oracles() { if (oracle_is_pointer) oracle_actions.end = oracle_actions.begin; else oracle_actions.erase(); return *this; } add_oracle(action a)2276 predictor& predictor::add_oracle(action a) { return add_to(oracle_actions, oracle_is_pointer, a, false); } add_oracle(action * a,size_t action_count)2277 predictor& predictor::add_oracle(action*a, size_t action_count) { return add_to(oracle_actions, oracle_is_pointer, a, action_count, false); } add_oracle(v_array<action> & a)2278 predictor& predictor::add_oracle(v_array<action>& a) { return add_to(oracle_actions, oracle_is_pointer, a.begin, a.size(), false); } 2279 set_oracle(action a)2280 predictor& predictor::set_oracle(action a) { return add_to(oracle_actions, oracle_is_pointer, a, true); } set_oracle(action * a,size_t action_count)2281 predictor& predictor::set_oracle(action*a, size_t action_count) { return add_to(oracle_actions, oracle_is_pointer, a, action_count, true); } set_oracle(v_array<action> & a)2282 predictor& predictor::set_oracle(v_array<action>& a) { return add_to(oracle_actions, oracle_is_pointer, a.begin, a.size(), true); } 2283 erase_alloweds()2284 predictor& predictor::erase_alloweds() { if (allowed_is_pointer) allowed_actions.end = allowed_actions.begin; else allowed_actions.erase(); return *this; } add_allowed(action a)2285 predictor& predictor::add_allowed(action a) { return add_to(allowed_actions, allowed_is_pointer, a, false); } add_allowed(action * a,size_t action_count)2286 predictor& predictor::add_allowed(action*a, size_t action_count) { return add_to(allowed_actions, allowed_is_pointer, a, action_count, false); } add_allowed(v_array<action> & a)2287 predictor& predictor::add_allowed(v_array<action>& a) { return add_to(allowed_actions, allowed_is_pointer, 
a.begin, a.size(), false); } 2288 set_allowed(action a)2289 predictor& predictor::set_allowed(action a) { return add_to(allowed_actions, allowed_is_pointer, a, true); } set_allowed(action * a,size_t action_count)2290 predictor& predictor::set_allowed(action*a, size_t action_count) { return add_to(allowed_actions, allowed_is_pointer, a, action_count, true); } set_allowed(v_array<action> & a)2291 predictor& predictor::set_allowed(v_array<action>& a) { return add_to(allowed_actions, allowed_is_pointer, a.begin, a.size(), true); } 2292 add_condition(ptag tag,char name)2293 predictor& predictor::add_condition(ptag tag, char name) { condition_on_tags.push_back(tag); condition_on_names.push_back(name); return *this; } set_condition(ptag tag,char name)2294 predictor& predictor::set_condition(ptag tag, char name) { condition_on_tags.erase(); condition_on_names.erase(); return add_condition(tag, name); } 2295 add_condition_range(ptag hi,ptag count,char name0)2296 predictor& predictor::add_condition_range(ptag hi, ptag count, char name0) { 2297 if (count == 0) return *this; 2298 for (ptag i=0; i<count; i++) { 2299 if (i > hi) break; 2300 char name = name0 + i; 2301 condition_on_tags.push_back(hi-i); 2302 condition_on_names.push_back(name); 2303 } 2304 return *this; 2305 } set_condition_range(ptag hi,ptag count,char name0)2306 predictor& predictor::set_condition_range(ptag hi, ptag count, char name0) { condition_on_tags.erase(); condition_on_names.erase(); return add_condition_range(hi, count, name0); } 2307 set_learner_id(size_t id)2308 predictor& predictor::set_learner_id(size_t id) { learner_id = id; return *this; } 2309 set_tag(ptag tag)2310 predictor& predictor::set_tag(ptag tag) { my_tag = tag; return *this; } 2311 predict()2312 action predictor::predict() { 2313 const action* orA = oracle_actions.size() == 0 ? NULL : oracle_actions.begin; 2314 const ptag* cOn = condition_on_names.size() == 0 ? 
NULL : condition_on_tags.begin; 2315 const char* cNa = NULL; 2316 if (condition_on_names.size() > 0) { 2317 condition_on_names.push_back((char)0); // null terminate 2318 cNa = condition_on_names.begin; 2319 } 2320 const action* alA = (allowed_actions.size() == 0) ? NULL : allowed_actions.begin; 2321 action p = is_ldf ? sch.predictLDF(ec, ec_cnt, my_tag, orA, oracle_actions.size(), cOn, cNa, learner_id) 2322 : sch.predict(*ec, my_tag, orA, oracle_actions.size(), cOn, cNa, alA, allowed_actions.size(), learner_id); 2323 2324 if (condition_on_names.size() > 0) 2325 condition_on_names.pop(); // un-null-terminate 2326 return p; 2327 } 2328 } 2329 2330 // ./vw --search 5 -k -c --search_task sequence -d test_seq --passes 10 -f test_seq.model --holdout_off 2331 // ./vw -i test_seq.model -t -d test_seq --search_beam 2 -p /dev/stdout -r /dev/stdout 2332 2333 // ./vw --search 5 --csoaa_ldf m -k -c --search_task sequence_demoldf -d test_seq --passes 10 -f test_seq.model --holdout_off --search_history_length 0 2334 // ./vw -i test_seq.model -t -d test_seq -p /dev/stdout -r /dev/stdout 2335 2336 // ./vw --search 5 -k -c --search_task sequence -d test_seq --passes 10 -f test_seq.model --holdout_off --search_beam 2 2337 2338 2339 2340 // TODO: raw predictions in LDF mode 2341 2342 // TODO: there's a bug in which if holdout is on, loss isn't computed properly 2343