1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved.  Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
6 #include <float.h>
7 #include <string.h>
8 #include <math.h>
9 #include "vw.h"
10 #include "rand48.h"
11 #include "reductions.h"
12 #include "gd.h" // for GD::foreach_feature
13 #include "search_sequencetask.h"
14 #include "search_multiclasstask.h"
15 #include "search_dep_parser.h"
16 #include "search_entityrelationtask.h"
17 #include "search_hooktask.h"
18 #include "search_graph.h"
19 #include "csoaa.h"
20 #include "beam.h"
21 
22 using namespace LEARNER;
23 using namespace std;
24 namespace CS = COST_SENSITIVE;
25 namespace MC = MULTICLASS;
26 
27 namespace Search {
  // Registry of every built-in search task. The array ends with a NULL
  // sentinel (see the trailing comment); code that walks it presumably stops
  // there, so new tasks must be added before the NULL entry.
  search_task* all_tasks[] = { &SequenceTask::task,
                               &SequenceSpanTask::task,
                               &ArgmaxTask::task,
                               &SequenceTask_DemoLDF::task,
                               &MulticlassTask::task,
                               &DepParserTask::task,
                               &EntityRelationTask::task,
                               &HookTask::task,
                               &GraphTask::task,
                               NULL };   // must NULL terminate!
38 
  // Debugging switches for the progress table written by print_update().
  const bool PRINT_UPDATE_EVERY_EXAMPLE =0;  // force a progress line after every example
  const bool PRINT_UPDATE_EVERY_PASS =0;     // force a progress line when a new pass is hit
  const bool PRINT_CLOCK_TIME =0;            // append elapsed clock() seconds to each progress line

  // Audit "space" names for generated features (neighbor / auto-conditioning).
  string   neighbor_feature_space("neighbor");
  string   condition_feature_space("search_condition");

  // Option flags; distinct powers of two, so presumably combined with bitwise OR.
  // They mirror the auto_condition_features / auto_hamming_loss /
  // examples_dont_change / is_ldf fields of search_private.
  uint32_t AUTO_CONDITION_FEATURES = 1, AUTO_HAMMING_LOSS = 2, EXAMPLES_DONT_CHANGE = 4, IS_LDF = 8;
  // Which phase of search is currently running (stored in search_private::state).
  enum SearchState { INITIALIZE, INIT_TEST, INIT_TRAIN, LEARN, GET_TRUTH_STRING };
  // How roll-in / roll-out actions are chosen (search_private::rollin_method / rollout_method).
  enum RollMethod { POLICY, ORACLE, MIX_PER_STATE, MIX_PER_ROLL, NO_ROLLOUT };
49 
  // a data structure to hold conditioning information
  // (one memoized prediction keyed by the tags/actions it conditions on;
  //  tags and acts are parallel arrays of length cnt)
  struct prediction {
    ptag    me;     // the id of the current prediction (the one being memoized)
    size_t  cnt;    // how many variables are we conditioning on?
    ptag*   tags;   // which variables are they?
    action* acts;   // and which actions were taken at each?
    uint32_t hash;  // a hash of the above
  };
58 
  // parameters for auto-conditioning (read by add_example_conditioning)
  struct auto_condition_settings {
    size_t max_bias_ngram_length;   // add a "bias" feature for each ngram up to and including this length. eg., if it's 1, then you get a single feature for each conditional
    size_t max_quad_ngram_length;   // add bias *times* input features for each ngram up to and including this length
    float  feature_value;           // how much weight should the conditional features get? (multiplies every generated feature value)
  };
65 
  // A sequence of actions. Used as the item type of the beam; prefixes inserted
  // by choose_oracle_action carry one extra trailing slot that holds a float cost.
  typedef v_array<action> action_prefix;

  // All mutable state of the search reduction. The functions in this file
  // receive it as a search_private& argument.
  struct search_private {
    vw* all;

    bool auto_condition_features;  // do you want us to automatically add conditioning features?
    bool auto_hamming_loss;        // if you're just optimizing hamming loss, we can do it for you!
    bool examples_dont_change;     // set to true if you don't do any internal example munging
    bool is_ldf;                   // user declared ldf

    v_array<int32_t> neighbor_features; // ugly encoding of neighbor feature requirements: (offset << 24) | namespace, decoded in add_neighbor_features
    auto_condition_settings acset; // settings for auto-conditioning
    size_t history_length;         // value of --search_history_length, used by some tasks, default 1

    size_t A;                      // total number of actions, [1..A]; 0 means ldf
    size_t num_learners;           // total number of learners;
    bool cb_learner;               // do contextual bandit learning on action (was "! rollout_all_actions" which was confusing)
    SearchState state;             // current state of learning
    size_t learn_learner_id;       // we allow user to use different learners for different states
    int mix_per_roll_policy;       // for MIX_PER_ROLL, we need to choose a policy to use; this is where it's stored (-2 means "not selected yet")
    bool no_caching;               // turn off caching
    size_t rollout_num_steps;      // how many calls of "loss" before we stop really predicting on rollouts and switch to oracle (0 means "infinite")
    bool (*label_is_test)(void*);  // tell me if the label data from an example is test

    size_t t;                      // current search step
    size_t T;                      // length of root trajectory
    v_array<example> learn_ec_copy;// copy of example(s) at learn_t
    example* learn_ec_ref;         // reference to example at learn_t, when there's no example munging
    size_t learn_ec_ref_cnt;       // how many are there (for LDF mode only; otherwise 1)
    v_array<ptag> learn_condition_on;      // a copy of the tags used for conditioning at the training position
    v_array<action> learn_condition_on_act;// the actions taken
    v_array<char>   learn_condition_on_names;// the names of the actions
    v_array<action> learn_allowed_actions; // which actions were allowed at training time?
    v_array<action> ptag_to_action;// tag to action mapping for conditioning
    vector<action> test_action_sequence; // if test-mode was run, what was the corresponding action sequence; it's a vector cuz we might expose it to the library
    action learn_oracle_action;    // store an oracle action for debugging purposes

    polylabel* allowed_actions_cache;

    size_t loss_declared_cnt;      // how many times did run declare any loss (implicitly or explicitly)?
    v_array<action> train_trajectory; // the training trajectory
    v_array<action> current_trajectory;  // the current trajectory; only used in beam search mode
    size_t learn_t;                // what time step are we learning on?
    size_t learn_a_idx;            // what action index are we trying?
    bool done_with_all_actions;    // set to true when there are no more learn_a_idx to go

    float test_loss;               // loss incurred when run INIT_TEST
    float learn_loss;              // loss incurred when run LEARN
    float train_loss;              // loss incurred when run INIT_TRAIN

    bool last_example_was_newline; // used so we know when a block of examples has passed
    bool hit_new_pass;             // have we hit a new pass?

    // if we're printing to stderr we need to remember if we've printed the header yet
    // (i.e., we do this if we're driving)
    bool printed_output_header;

    // various strings for different search states
    bool should_produce_string;
    stringstream *pred_string;
    stringstream *truth_string;
    stringstream *bad_string_stream;

    // parameters controlling interpolation
    float  beta;                   // interpolation rate
    float  alpha;                  // parameter used to adapt beta for dagger (see the adaptive_beta comment below), should be in (0,1)

    RollMethod rollout_method;     // 0=policy, 1=oracle, 2=mix_per_state, 3=mix_per_roll, 4=no_rollout
    RollMethod rollin_method;
    float subsample_timesteps;     // train at every time step or just a (random) subset?

    bool   allow_current_policy;   // should the current policy be used for training? true for dagger
    bool   adaptive_beta;          // used to implement dagger-like algorithms. if true, beta = 1-(1-alpha)^n after n updates, and policy is mixed with oracle as \pi' = (1-beta)\pi^* + beta \pi
    size_t passes_per_policy;      // if we're not in dagger-mode, then we need to know how many passes to train a policy

    uint32_t current_policy;       // what policy are we training right now?

    // various statistics for reporting
    size_t num_features;
    uint32_t total_number_of_policies;
    size_t read_example_last_id;
    size_t passes_since_new_policy;
    size_t read_example_last_pass;
    size_t total_examples_generated;
    size_t total_predictions_made;
    size_t total_cache_hits;

    vector<example*> ec_seq;  // the collected examples
    v_hashmap<unsigned char*, action> cache_hash_map;

    // for foreach_feature temporary storage for conditioning
    // (read by the add_new_feature callback: target example, namespace, index offset, value scale, audit text)
    uint32_t dat_new_feature_idx;
    example* dat_new_feature_ec;
    stringstream dat_new_feature_audit_ss;
    size_t dat_new_feature_namespace;
    string* dat_new_feature_feature_space;
    float dat_new_feature_value;

    // to reduce memory allocation
    string rawOutputString;
    stringstream* rawOutputStringStream;
    CS::label ldf_test_label;
    v_array<action> condition_on_actions;
    v_array< pair<size_t,size_t> > timesteps;
    v_array<float> learn_losses;

    LEARNER::base_learner* base_learner;
    clock_t start_clock_time;

    example*empty_example;

    Beam::beam< action_prefix > *beam;
    size_t kbest;            // size of kbest list; 1 just means 1best
    float beam_initial_cost; // when we're doing a subsequent run, how much do we initially pay?
    action_prefix beam_actions; // on non-initial beam runs, what prefix of actions should we take?
    float beam_total_cost;

    search_task* task;    // your task!
  };
185 
  // Audit namespace name and hashing constant for conditional features.
  // NOTE(review): neither is referenced in this part of the file; presumably used further below.
  string   audit_feature_space("conditional");
  uint32_t conditional_constant = 8290743;
188 
random_policy(search_private & priv,bool allow_current,bool allow_optimal,bool advance_prng=true)189   int random_policy(search_private& priv, bool allow_current, bool allow_optimal, bool advance_prng=true) {
190     if (priv.beta >= 1) {
191       if (allow_current) return (int)priv.current_policy;
192       if (priv.current_policy > 0) return (((int)priv.current_policy)-1);
193       if (allow_optimal) return -1;
194       std::cerr << "internal error (bug): no valid policies to choose from!  defaulting to current" << std::endl;
195       return (int)priv.current_policy;
196     }
197 
198     int num_valid_policies = (int)priv.current_policy + allow_optimal + allow_current;
199     int pid = -1;
200 
201     if (num_valid_policies == 0) {
202       std::cerr << "internal error (bug): no valid policies to choose from!  defaulting to current" << std::endl;
203       return (int)priv.current_policy;
204     } else if (num_valid_policies == 1)
205       pid = 0;
206     else if (num_valid_policies == 2)
207       pid = (advance_prng ? frand48() : frand48_noadvance()) >= priv.beta;
208     else {
209       // SPEEDUP this up in the case that beta is small!
210       float r = (advance_prng ? frand48() : frand48_noadvance());
211       pid = 0;
212 
213       if (r > priv.beta) {
214         r -= priv.beta;
215         while ((r > 0) && (pid < num_valid_policies-1)) {
216           pid ++;
217           r -= priv.beta * powf(1.f - priv.beta, (float)pid);
218         }
219       }
220     }
221     // figure out which policy pid refers to
222     if (allow_optimal && (pid == num_valid_policies-1))
223       return -1; // this is the optimal policy
224 
225     pid = (int)priv.current_policy - pid;
226     if (!allow_current)
227       pid--;
228 
229     return pid;
230   }
231 
select_learner(search_private & priv,int policy,size_t learner_id)232   int select_learner(search_private& priv, int policy, size_t learner_id) {
233     if (policy<0) return policy;  // optimal policy
234     else          return (int) (policy*priv.num_learners+learner_id);
235   }
236 
237 
should_print_update(vw & all,bool hit_new_pass=false)238   bool should_print_update(vw& all, bool hit_new_pass=false) {
239     //uncomment to print out final loss after all examples processed
240     //commented for now so that outputs matches make test
241     //if( parser_done(all.p)) return true;
242 
243     if (PRINT_UPDATE_EVERY_EXAMPLE) return true;
244     if (PRINT_UPDATE_EVERY_PASS && hit_new_pass) return true;
245     return (all.sd->weighted_examples >= all.sd->dump_interval) && !all.quiet && !all.bfgs;
246   }
247 
248 
might_print_update(vw & all)249   bool might_print_update(vw& all) {
250     // basically do should_print_update but check me and the next
251     // example because of off-by-ones
252 
253     if (PRINT_UPDATE_EVERY_EXAMPLE) return true;
254     if (PRINT_UPDATE_EVERY_PASS) return true;  // SPEEDUP: make this better
255     return (all.sd->weighted_examples + 1. >= all.sd->dump_interval) && !all.quiet && !all.bfgs;
256   }
257 
must_run_test(vw & all,vector<example * > ec,bool is_test_ex)258   bool must_run_test(vw&all, vector<example*>ec, bool is_test_ex) {
259     return
260         (all.final_prediction_sink.size() > 0) ||   // if we have to produce output, we need to run this
261         might_print_update(all) ||                  // if we have to print and update to stderr
262         (all.raw_prediction > 0) ||                 // we need raw predictions
263         // or:
264         //   it's not quiet AND
265         //     current_pass == 0
266         //     OR holdout is off
267         //     OR it's a test example
268         ( //   (! all.quiet) &&  // had to disable this because of library mode!
269           (! is_test_ex) &&
270           ( all.holdout_set_off ||                    // no holdout
271             ec[0]->test_only ||
272             (all.current_pass == 0)                   // we need error rates for progressive cost
273             ) )
274         ;
275   }
276 
clear_seq(vw & all,search_private & priv)277   void clear_seq(vw&all, search_private& priv) {
278     if (priv.ec_seq.size() > 0)
279       for (size_t i=0; i < priv.ec_seq.size(); i++)
280         VW::finish_example(all, priv.ec_seq[i]);
281     priv.ec_seq.clear();
282   }
283 
safediv(float a,float b)284   float safediv(float a,float b) { if (b == 0.f) return 0.f; else return (a/b); }
285 
to_short_string(string in,size_t max_len,char * out)286   void to_short_string(string in, size_t max_len, char*out) {
287     for (size_t i=0; i<max_len; i++)
288       out[i] = ((i >= in.length()) || (in[i] == '\n') || (in[i] == '\t')) ? ' ' : in[i];
289 
290     if (in.length() > max_len) {
291       out[max_len-2] = '.';
292       out[max_len-1] = '.';
293     }
294     out[max_len] = 0;
295   }
296 
number_to_natural(size_t big,char * c)297   void number_to_natural(size_t big, char* c) {
298     if      (big > 9999999999) sprintf(c, "%dg", (int)(big / 1000000000));
299     else if (big >    9999999) sprintf(c, "%dm", (int)(big /    1000000));
300     else if (big >       9999) sprintf(c, "%dk", (int)(big /       1000));
301     else                       sprintf(c, "%d",  (int)(big));
302   }
303 
  // Write progress information to stderr.
  //
  // The first call prints the two-line column header, unless all.quiet is set;
  // this happens even when no data line is due. A data line is printed only when
  // should_print_update() allows it.
  //
  // Average loss comes from the holdout statistics when all of these hold:
  //   - holdout is enabled,
  //   - current_pass >= 1,
  //   - weighted_holdout_examples > 0.
  // In that case the line gets a trailing " h". Otherwise the training
  // statistics are used.
  //
  // Side effects:
  //   - the holdout "since last dump" counters are reset in the holdout case;
  //   - the dump interval is advanced through update_dump_interval().
  void print_update(search_private& priv) {
    vw& all = *priv.all;
    if (!priv.printed_output_header && !all.quiet) {
      const char * header_fmt = "%-10s %-10s %8s%24s %22s %5s %5s  %7s  %7s  %7s  %-8s\n";
      fprintf(stderr, header_fmt, "average", "since", "instance", "current true",  "current predicted", "cur",  "cur", "predic", "cache", "examples", "");
      fprintf(stderr, header_fmt, "loss",    "last",  "counter",  "output prefix",  "output prefix",    "pass", "pol", "made",    "hits",  "gener", "beta");
      // NOTE(review): sets std::cerr precision, but the table itself is written with fprintf
      std::cerr.precision(5);
      priv.printed_output_header = true;
    }

    if (!should_print_update(all, priv.hit_new_pass))
      return;

    // 20 visible characters + NUL (to_short_string writes max_len+1 bytes)
    char true_label[21];
    char pred_label[21];
    to_short_string(priv.truth_string->str(), 20, true_label);
    to_short_string(priv.pred_string->str() , 20, pred_label);

    float avg_loss = 0.;
    float avg_loss_since = 0.;
    bool use_heldout_loss = (!all.holdout_set_off && all.current_pass >= 1) && (all.sd->weighted_holdout_examples > 0);
    if (use_heldout_loss) {
      avg_loss       = safediv((float)all.sd->holdout_sum_loss, (float)all.sd->weighted_holdout_examples);
      avg_loss_since = safediv((float)all.sd->holdout_sum_loss_since_last_dump, (float)all.sd->weighted_holdout_examples_since_last_dump);

      // start a fresh "since last" window for the holdout statistics
      all.sd->weighted_holdout_examples_since_last_dump = 0;
      all.sd->holdout_sum_loss_since_last_dump = 0.0;
    } else {
      avg_loss       = safediv((float)all.sd->sum_loss, (float)all.sd->weighted_examples);
      avg_loss_since = safediv((float)all.sd->sum_loss_since_last_dump, (float) (all.sd->weighted_examples - all.sd->old_weighted_examples));
    }

    // NOTE(review): these fixed-size buffers assume number_to_natural's output fits (<= 7 chars + NUL)
    char inst_cntr[9];  number_to_natural(all.sd->example_number, inst_cntr);
    char total_pred[8]; number_to_natural(priv.total_predictions_made, total_pred);
    char total_cach[8]; number_to_natural(priv.total_cache_hits, total_cach);
    char total_exge[8]; number_to_natural(priv.total_examples_generated, total_exge);

    fprintf(stderr, "%-10.6f %-10.6f %8s  [%s] [%s] %5d %5d  %7s  %7s  %7s  %-8f",
            avg_loss,
            avg_loss_since,
            inst_cntr,
            true_label,
            pred_label,
            (int)priv.read_example_last_pass,
            (int)priv.current_policy,
            total_pred,
            total_cach,
            total_exge,
            priv.beta);

    if (PRINT_CLOCK_TIME) {
      // NOTE(review): %lu with a size_t argument is only portable where size_t == unsigned long
      size_t num_sec = (size_t)(((float)(clock() - priv.start_clock_time)) / CLOCKS_PER_SEC);
      fprintf(stderr, " %15lusec", num_sec);
    }

    if (use_heldout_loss)
      fprintf(stderr, " h");

    fprintf(stderr, "\n");
    fflush(stderr);
    all.sd->update_dump_interval(all.progress_add, all.progress_arg);
  }
366 
add_new_feature(search_private & priv,float val,uint32_t idx)367   void add_new_feature(search_private& priv, float val, uint32_t idx) {
368     size_t mask = priv.all->reg.weight_mask;
369     size_t ss   = priv.all->reg.stride_shift;
370     size_t idx2 = ((idx & mask) >> ss) & mask;
371     feature f = { val * priv.dat_new_feature_value,
372                   (uint32_t) (((priv.dat_new_feature_idx + idx2) << ss) ) };
373     priv.dat_new_feature_ec->atomics[priv.dat_new_feature_namespace].push_back(f);
374     priv.dat_new_feature_ec->sum_feat_sq[priv.dat_new_feature_namespace] += f.x * f.x;
375     if (priv.all->audit) {
376       audit_data a = { NULL, NULL, f.weight_index, f.x, true };
377       a.space   = calloc_or_die<char>(priv.dat_new_feature_feature_space->length()+1);
378       a.feature = calloc_or_die<char>(priv.dat_new_feature_audit_ss.str().length() + 32);
379       strcpy(a.space, priv.dat_new_feature_feature_space->c_str());
380       int num = sprintf(a.feature, "fid=%lu_", (idx & mask) >> ss);
381       strcpy(a.feature+num, priv.dat_new_feature_audit_ss.str().c_str());
382       priv.dat_new_feature_ec->audit_features[priv.dat_new_feature_namespace].push_back(a);
383     }
384   }
385 
del_features_in_top_namespace(search_private & priv,example & ec,size_t ns)386   void del_features_in_top_namespace(search_private& priv, example& ec, size_t ns) {
387     if ((ec.indices.size() == 0) || (ec.indices.last() != ns)) {
388       std::cerr << "internal error (bug): expecting top namespace to be '" << ns << "' but it was ";
389       if (ec.indices.size() == 0) std::cerr << "empty";
390       else std::cerr << (size_t)ec.indices.last();
391       std::cerr << endl;
392       throw exception();
393     }
394     ec.num_features -= ec.atomics[ns].size();
395     ec.total_sum_feat_sq -= ec.sum_feat_sq[ns];
396     ec.sum_feat_sq[ns] = 0;
397     ec.indices.decr();
398     ec.atomics[ns].erase();
399     if (priv.all->audit) {
400       for (size_t i=0; i<ec.audit_features[ns].size(); i++)
401         if (ec.audit_features[ns][i].alloced) {
402           free(ec.audit_features[ns][i].space);
403           free(ec.audit_features[ns][i].feature);
404         }
405       ec.audit_features[ns].erase();
406     }
407   }
408 
  // For every example in ec_seq, copy features from neighbouring examples into
  // neighbor_namespace, as described by priv.neighbor_features.
  //
  // Each entry of neighbor_features packs:
  //   - a signed offset in its high bits (value >> 24);
  //   - a source namespace in its low byte (value & 0xFF).
  // If the neighbour at that offset lies before the start or past the end of the
  // sequence, a single <s> or </s> indicator feature is added instead.
  //
  // neighbor_namespace is pushed onto me.indices only when it ends up non-empty
  // with positive sum_feat_sq. Otherwise its storage is erased and it is not pushed.
  void add_neighbor_features(search_private& priv) {
    vw& all = *priv.all;
    if (priv.neighbor_features.size() == 0) return;

    for (size_t n=0; n<priv.ec_seq.size(); n++) {  // iterate over every example in the sequence
      example& me = *priv.ec_seq[n];
      for (size_t n_id=0; n_id < priv.neighbor_features.size(); n_id++) {
        int32_t offset = priv.neighbor_features[n_id] >> 24;
        size_t  ns     = priv.neighbor_features[n_id] & 0xFF;

        // configure the add_new_feature callback for this (example, neighbor spec)
        priv.dat_new_feature_ec = &me;
        priv.dat_new_feature_value = 1.;
        priv.dat_new_feature_idx = priv.neighbor_features[n_id] * 13748127;
        priv.dat_new_feature_namespace = neighbor_namespace;
        if (priv.all->audit) {
          priv.dat_new_feature_feature_space = &neighbor_feature_space;
          priv.dat_new_feature_audit_ss.str("");
          // NOTE(review): the single-digit encoding below is only readable for |offset| <= 9
          priv.dat_new_feature_audit_ss << '@' << ((offset > 0) ? '+' : '-') << (char)(abs(offset) + '0');
          if (ns != ' ') priv.dat_new_feature_audit_ss << (char)ns;
        }

        //cerr << "n=" << n << " offset=" << offset << endl;
        if ((offset < 0) && (n < (uint32_t)(-offset))) // add <s> feature
          add_new_feature(priv, 1., 925871901 << priv.all->reg.stride_shift);
        // for negative offsets n + offset is computed modulo size_t; this is safe
        // because the <s> branch above already handled n < -offset
        else if (n + offset >= priv.ec_seq.size()) // add </s> feature
          add_new_feature(priv, 1., 3824917 << priv.all->reg.stride_shift);
        else { // this is actually a neighbor
          example& other = *priv.ec_seq[n + offset];
          GD::foreach_feature<search_private,add_new_feature>(all.reg.weight_vector, all.reg.weight_mask, other.atomics[ns].begin, other.atomics[ns].end, priv, me.ft_offset);
        }
      }

      // only register the namespace if it actually carries signal
      size_t sz = me.atomics[neighbor_namespace].size();
      if ((sz > 0) && (me.sum_feat_sq[neighbor_namespace] > 0.)) {
        me.indices.push_back(neighbor_namespace);
        me.total_sum_feat_sq += me.sum_feat_sq[neighbor_namespace];
        me.num_features += sz;
      } else {
        me.atomics[neighbor_namespace].erase();
        if (priv.all->audit) me.audit_features[neighbor_namespace].erase();
      }
    }
  }
452 
del_neighbor_features(search_private & priv)453   void del_neighbor_features(search_private& priv) {
454     if (priv.neighbor_features.size() == 0) return;
455     for (size_t n=0; n<priv.ec_seq.size(); n++)
456       del_features_in_top_namespace(priv, *priv.ec_seq[n], neighbor_namespace);
457   }
458 
reset_search_structure(search_private & priv)459   void reset_search_structure(search_private& priv) {
460     // NOTE: make sure do NOT reset priv.learn_a_idx
461     priv.t = 0;
462     priv.loss_declared_cnt = 0;
463     priv.done_with_all_actions = false;
464     priv.test_loss = 0.;
465     priv.learn_loss = 0.;
466     priv.train_loss = 0.;
467     priv.num_features = 0;
468     priv.should_produce_string = false;
469     priv.mix_per_roll_policy = -2;
470     if (priv.adaptive_beta) {
471       float x = - log1pf(- priv.alpha) * (float)priv.total_examples_generated;
472       static const float log_of_2 = (float)0.6931471805599453;
473       priv.beta = (x <= log_of_2) ? -expm1f(-x) : (1-expf(-x)); // numerical stability
474       //float priv_beta = 1.f - powf(1.f - priv.alpha, (float)priv.total_examples_generated);
475       //assert( fabs(priv_beta - priv.beta) < 1e-2 );
476       if (priv.beta > 1) priv.beta = 1;
477     }
478     priv.ptag_to_action.erase();
479     if (priv.beam)
480       priv.current_trajectory.erase();
481 
482     if (! priv.cb_learner) { // was: if rollout_all_actions
483       uint32_t seed = (uint32_t)(priv.read_example_last_id * 147483 + 4831921) * 2147483647;
484       msrand48(seed);
485     }
486   }
487 
search_declare_loss(search_private & priv,float loss)488   void search_declare_loss(search_private& priv, float loss) {
489     priv.loss_declared_cnt++;
490     switch (priv.state) {
491       case INIT_TEST:  priv.test_loss  += loss; break;
492       case INIT_TRAIN: priv.train_loss += loss; break;
493       case LEARN:
494         if ((priv.rollout_num_steps == 0) || (priv.loss_declared_cnt <= priv.rollout_num_steps))
495           priv.learn_loss += loss;
496         break;
497       default: break; // get rid of the warning about missing cases (danger!)
498     }
499   }
500 
random(size_t max)501   size_t random(size_t max) { return (size_t)(frand48() * (float)max); }
array_contains(T target,const T * A,size_t n)502   template<class T> bool array_contains(T target, const T*A, size_t n) {
503     if (A == NULL) return false;
504     for (size_t i=0; i<n; i++)
505       if (A[i] == target) return true;
506     return false;
507   }
508 
choose_oracle_action(search_private & priv,size_t ec_cnt,const action * oracle_actions,size_t oracle_actions_cnt,const action * allowed_actions,size_t allowed_actions_cnt,bool add_alternatives_to_beam)509   action choose_oracle_action(search_private& priv, size_t ec_cnt, const action* oracle_actions, size_t oracle_actions_cnt, const action* allowed_actions, size_t allowed_actions_cnt, bool add_alternatives_to_beam) {
510     action ret = ( oracle_actions_cnt > 0) ?  oracle_actions[random(oracle_actions_cnt )] :
511                  (allowed_actions_cnt > 0) ? allowed_actions[random(allowed_actions_cnt)] :
512                  priv.is_ldf ? (action)random(ec_cnt) :
513                  (action)(1 + random(ec_cnt));
514     cdbg << "choose_oracle_action from oracle_actions = ["; for (size_t i=0; i<oracle_actions_cnt; i++) cdbg << " " << oracle_actions[i]; cdbg << " ], ret=" << ret << endl;
515     if (add_alternatives_to_beam) {
516       // first, insert all the oracle actions (other than ret)
517       size_t new_len = priv.current_trajectory.size() + 1;
518       if (oracle_actions_cnt > 1)
519         for (size_t i=0; i<oracle_actions_cnt; i++)
520           if (oracle_actions[i] != ret) {
521             float delta_cost = priv.beam_initial_cost + 1e-6f;
522             action_prefix* px = new v_array<action>;
523             px->resize(new_len+1);
524             px->end = px->begin + new_len + 1;
525             memcpy(px->begin, priv.current_trajectory.begin, sizeof(action) * (new_len-1));
526             px->begin[new_len-1] = oracle_actions[i];
527             *((float*)(px->begin+new_len)) = delta_cost;
528             uint32_t px_hash = uniform_hash(px->begin, sizeof(action) * new_len, 3419);
529             //cerr << "insertingA" << endl;
530             if (! priv.beam->insert(px, delta_cost, px_hash)) {
531               px->delete_v();  // SPEEDUP: could be more efficient by reusing for next action
532               delete px;
533             }
534           }
535       // now add all the non-oracle actions (other than ret)
536       size_t top = (allowed_actions_cnt > 0) ? allowed_actions_cnt : ec_cnt;
537       for (size_t i = 0; i < top; i++) {
538         size_t a = (allowed_actions_cnt > 0) ? allowed_actions[i] : i;
539         if (a == ret) continue;
540         if (array_contains<action>((action)a, oracle_actions, oracle_actions_cnt)) continue;
541         float delta_cost = priv.beam_initial_cost + 1.f + 1e-6f;  // TODO: why is this the right cost?
542         action_prefix* px = new v_array<action>;
543         px->resize(new_len + 1);
544         px->end = px->begin + new_len + 1;
545         memcpy(px->begin, priv.current_trajectory.begin, sizeof(action) * (new_len-1));
546         px->begin[new_len-1] = (action)a;
547         *((float*)(px->begin+new_len)) = delta_cost;
548         uint32_t px_hash = uniform_hash(px->begin, sizeof(action) * new_len, 3419);
549         //cerr << "insertingB " << i << " " << allowed_actions_cnt << " " << ec_cnt << " " << top << endl;
550         if (! priv.beam->insert(px, delta_cost, px_hash)) {
551           px->delete_v();  // SPEEDUP: could be more efficient by reusing for next action
552           delete px;
553         }
554       }
555     }
556     return ret;
557   }
558 
  // Add auto-conditioning features to ec, in conditioning_namespace, built from
  // previous decisions. condition_on_names and condition_on_actions are
  // parallel arrays of length condition_on_cnt.
  //
  // For each start position i, and each ngram length up to
  // max(max_bias_ngram_length, max_quad_ngram_length):
  //   - a running 32-bit hash `fid` identifies the ngram;
  //   - a single "bias" feature is added while n < max_bias_ngram_length;
  //   - the ngram is crossed with ec's existing features (via
  //     GD::foreach_feature) while n < max_quad_ngram_length.
  // In LDF mode the hash is salted with the first cost's class_index.
  //
  // The namespace is pushed onto ec.indices only if it ends up non-empty with
  // positive sum_feat_sq.
  // NOTE(review): the condition_on (ptag) argument is not read in this body.
  void add_example_conditioning(search_private& priv, example& ec, const ptag* condition_on, size_t condition_on_cnt, const char* condition_on_names, const action* condition_on_actions) {
    if (condition_on_cnt == 0) return;

    // in LDF mode, distinguish features of different actions' examples
    uint32_t extra_offset=0;
    if (priv.is_ldf)
      if (ec.l.cs.costs.size() > 0)
        extra_offset = 3849017 * ec.l.cs.costs[0].class_index;

    size_t I = condition_on_cnt;
    size_t N = max(priv.acset.max_bias_ngram_length, priv.acset.max_quad_ngram_length);
    for (size_t i=0; i<I; i++) { // position in conditioning
      uint32_t fid = 71933 + 8491087 * extra_offset;  // unsigned arithmetic: wraps by design
      if (priv.all->audit) {
        priv.dat_new_feature_audit_ss.str("");
        priv.dat_new_feature_audit_ss.clear();
        priv.dat_new_feature_feature_space = &condition_feature_space;
      }

      for (size_t n=0; n<N; n++) { // length of ngram
        if (i + n >= I) break; // no more ngrams
        // we're going to add features for the ngram condition_on_actions[i .. i+n]
        char name = condition_on_names[i+n];
        fid = fid * 328901 + 71933 * ((condition_on_actions[i+n] + 349101) * (name + 38490137));

        // configure the add_new_feature callback
        priv.dat_new_feature_ec  = &ec;
        priv.dat_new_feature_idx = fid * quadratic_constant;
        priv.dat_new_feature_namespace = conditioning_namespace;
        priv.dat_new_feature_value = priv.acset.feature_value;

        if (priv.all->audit) {
          if (n > 0) priv.dat_new_feature_audit_ss << ',';
          // printable ASCII names are shown verbatim, others as '#<code>'
          if ((33 <= name) && (name <= 126)) priv.dat_new_feature_audit_ss << name;
          else priv.dat_new_feature_audit_ss << '#' << (int)name;
          priv.dat_new_feature_audit_ss << '=' << condition_on_actions[i+n];
        }

        // add the single bias feature
        if (n < priv.acset.max_bias_ngram_length)
          add_new_feature(priv, 1., 4398201 << priv.all->reg.stride_shift);

        // add the quadratic features
        // (conditioning_namespace is not in ec.indices yet; presumably it is therefore not iterated here)
        if (n < priv.acset.max_quad_ngram_length)
          GD::foreach_feature<search_private,uint32_t,add_new_feature>(*priv.all, ec, priv);
      }
    }

    // only register the namespace if it actually carries signal
    size_t sz = ec.atomics[conditioning_namespace].size();
    if ((sz > 0) && (ec.sum_feat_sq[conditioning_namespace] > 0.)) {
      ec.indices.push_back(conditioning_namespace);
      ec.total_sum_feat_sq += ec.sum_feat_sq[conditioning_namespace];
      ec.num_features += sz;
    } else {
      ec.atomics[conditioning_namespace].erase();
      if (priv.all->audit) ec.audit_features[conditioning_namespace].erase();
    }
  }
615 
del_example_conditioning(search_private & priv,example & ec)616   void del_example_conditioning(search_private& priv, example& ec) {
617     if ((ec.indices.size() > 0) && (ec.indices.last() == conditioning_namespace))
618       del_features_in_top_namespace(priv, ec, conditioning_namespace);
619   }
620 
cs_get_costs_size(bool isCB,polylabel & ld)621   size_t cs_get_costs_size(bool isCB, polylabel& ld) {
622     return isCB ? ld.cb.costs.size()
623                 : ld.cs.costs.size();
624   }
625 
cs_get_cost_index(bool isCB,polylabel & ld,size_t k)626   uint32_t cs_get_cost_index(bool isCB, polylabel& ld, size_t k) {
627     return isCB ? ld.cb.costs[k].action
628                 : ld.cs.costs[k].class_index;
629   }
630 
cs_get_cost_partial_prediction(bool isCB,polylabel & ld,size_t k)631   float cs_get_cost_partial_prediction(bool isCB, polylabel& ld, size_t k) {
632     return isCB ? ld.cb.costs[k].partial_prediction
633                 : ld.cs.costs[k].partial_prediction;
634   }
635 
cs_costs_erase(bool isCB,polylabel & ld)636   void cs_costs_erase(bool isCB, polylabel& ld) {
637     if (isCB) ld.cb.costs.erase();
638     else      ld.cs.costs.erase();
639   }
640 
cs_costs_resize(bool isCB,polylabel & ld,size_t new_size)641   void cs_costs_resize(bool isCB, polylabel& ld, size_t new_size) {
642     if (isCB) ld.cb.costs.resize(new_size);
643     else      ld.cs.costs.resize(new_size);
644   }
645 
cs_cost_push_back(bool isCB,polylabel & ld,uint32_t index,float value)646   void cs_cost_push_back(bool isCB, polylabel& ld, uint32_t index, float value) {
647     if (isCB) { CB::cb_class cost = { value, index, 0., 0. }; ld.cb.costs.push_back(cost); }
648     else      { CS::wclass   cost = { value, index, 0., 0. }; ld.cs.costs.push_back(cost); }
649   }
650 
allowed_actions_to_ld(search_private & priv,size_t ec_cnt,const action * allowed_actions,size_t allowed_actions_cnt)651   polylabel& allowed_actions_to_ld(search_private& priv, size_t ec_cnt, const action* allowed_actions, size_t allowed_actions_cnt) {
652     bool isCB = priv.cb_learner;
653     polylabel& ld = *priv.allowed_actions_cache;
654     uint32_t num_costs = (uint32_t)cs_get_costs_size(isCB, ld);
655 
656     if (priv.is_ldf) {  // LDF version easier
657       if (num_costs > ec_cnt)
658         cs_costs_resize(isCB, ld, ec_cnt);
659       else if (num_costs < ec_cnt)
660         for (action k = num_costs; k < ec_cnt; k++)
661           cs_cost_push_back(isCB, ld, k, FLT_MAX);
662 
663     } else { // non-LDF version
664       if ((allowed_actions == NULL) || (allowed_actions_cnt == 0)) { // any action is allowed
665         if (num_costs != priv.A) {  // if there are already A-many actions, they must be the right ones, unless the user did something stupid like putting duplicate allowed_actions...
666           cs_costs_erase(isCB, ld);
667           for (action k = 0; k < priv.A; k++)
668             cs_cost_push_back(isCB, ld, k+1, FLT_MAX);  //+1 because MC is 1-based
669         }
670       } else { // we need to peek at allowed_actions
671         cs_costs_erase(isCB, ld);
672         for (size_t i = 0; i < allowed_actions_cnt; i++)
673           cs_cost_push_back(isCB, ld, allowed_actions[i], FLT_MAX);
674       }
675     }
676 
677     return ld;
678   }
679 
allowed_actions_to_losses(search_private & priv,size_t ec_cnt,const action * allowed_actions,size_t allowed_actions_cnt,const action * oracle_actions,size_t oracle_actions_cnt,v_array<float> & losses)680   void allowed_actions_to_losses(search_private& priv, size_t ec_cnt, const action* allowed_actions, size_t allowed_actions_cnt, const action* oracle_actions, size_t oracle_actions_cnt, v_array<float>& losses) {
681     if (priv.is_ldf)  // LDF version easier
682       for (action k=0; k<ec_cnt; k++)
683         losses.push_back( array_contains<action>(k, oracle_actions, oracle_actions_cnt) ? 0.f : 1.f );
684     else { // non-LDF
685       if ((allowed_actions == NULL) || (allowed_actions_cnt == 0))  // any action is allowed
686         for (action k=1; k<=priv.A; k++)
687           losses.push_back( array_contains<action>(k, oracle_actions, oracle_actions_cnt) ? 0.f : 1.f );
688       else
689         for (size_t i=0; i<allowed_actions_cnt; i++) {
690           action k = allowed_actions[i];
691           losses.push_back( array_contains<action>(k, oracle_actions, oracle_actions_cnt) ? 0.f : 1.f );
692         }
693     }
694   }
695 
single_prediction_notLDF(search_private & priv,example & ec,int policy,const action * allowed_actions,size_t allowed_actions_cnt)696   action single_prediction_notLDF(search_private& priv, example& ec, int policy, const action* allowed_actions, size_t allowed_actions_cnt) {
697     vw& all = *priv.all;
698     polylabel old_label = ec.l;
699     ec.l = allowed_actions_to_ld(priv, 1, allowed_actions, allowed_actions_cnt);
700 
701     priv.base_learner->predict(ec, policy);
702     uint32_t act = ec.pred.multiclass;
703 
704     // in beam search mode, go through alternatives and add them as back-ups
705     if (priv.beam) {
706       float act_cost = 0;
707       size_t K = cs_get_costs_size(priv.cb_learner, ec.l);
708       for (size_t k = 0; k < K; k++)
709         if (cs_get_cost_index(priv.cb_learner, ec.l, k) == act) {
710           act_cost = cs_get_cost_partial_prediction(priv.cb_learner, ec.l, k);
711           break;
712         }
713 
714       priv.beam_total_cost += act_cost;
715       size_t new_len = priv.current_trajectory.size() + 1;
716       for (size_t k = 0; k < K; k++) {
717         action k_act = cs_get_cost_index(priv.cb_learner, ec.l, k);
718         if (k_act == act) continue;  // skip the taken action
719         // TODO: delta_cost is correct for prioritizing, but not for full path cost -- for that, we cannot be subtracting off act_cost
720         float delta_cost = cs_get_cost_partial_prediction(priv.cb_learner, ec.l, k) - act_cost + priv.beam_initial_cost;   // TODO: is delta_cost the right cost?
721         // construct the action prefix
722         action_prefix* px = new v_array<action>;
723 	*px = v_init<action>();
724         px->resize(new_len + 1);
725         px->end = px->begin + new_len + 1;
726         memcpy(px->begin, priv.current_trajectory.begin, sizeof(action) * (new_len-1));
727         px->begin[new_len-1] = k_act;
728         *((float*)(px->begin+new_len)) = delta_cost + act_cost;
729         uint32_t px_hash = uniform_hash(px->begin, sizeof(action) * new_len, 3419);
730         cdbg << "inserting delta_cost=" << delta_cost << " total_cost=" << *((float*)(px->begin+new_len)) << " seq=";
731         for (size_t ii=0; ii<new_len; ii++) cdbg << px->begin[ii] << ' '; cdbg << endl;
732         if (! priv.beam->insert(px, delta_cost, px_hash)) {
733           px->delete_v();  // SPEEDUP: could be more efficient by reusing for next action
734           delete px;
735         }
736       }
737     }
738 
739     // generate raw predictions if necessary
740     if ((priv.state == INIT_TEST) && (all.raw_prediction > 0)) {
741       priv.rawOutputStringStream->str("");
742       for (size_t k = 0; k < cs_get_costs_size(priv.cb_learner, ec.l); k++) {
743         if (k > 0) (*priv.rawOutputStringStream) << ' ';
744         (*priv.rawOutputStringStream) << cs_get_cost_index(priv.cb_learner, ec.l, k) << ':' << cs_get_cost_partial_prediction(priv.cb_learner, ec.l, k);
745       }
746       all.print_text(all.raw_prediction, priv.rawOutputStringStream->str(), ec.tag);
747     }
748 
749     ec.l = old_label;
750 
751     priv.total_predictions_made++;
752     priv.num_features += ec.num_features;
753 
754     return act;
755   }
756 
single_prediction_LDF(search_private & priv,example * ecs,size_t ec_cnt,int policy)757   action single_prediction_LDF(search_private& priv, example* ecs, size_t ec_cnt, int policy) {
758     CS::cs_label.default_label(&priv.ldf_test_label);
759     CS::wclass wc = { 0., 1, 0., 0. };
760     priv.ldf_test_label.costs.push_back(wc);
761 
762     // keep track of best (aka chosen) action
763     float  best_prediction = 0.;
764     action best_action = 0;
765 
766     size_t start_K = (priv.is_ldf && LabelDict::ec_is_example_header(ecs[0])) ? 1 : 0;
767 
768     for (action a= (uint32_t)start_K; a<ec_cnt; a++) {
769       cdbg << "== single_prediction_LDF a=" << a << "==" << endl;
770       if (start_K > 0)
771         LabelDict::add_example_namespaces_from_example(ecs[a], ecs[0]);
772 
773       polylabel old_label = ecs[a].l;
774       ecs[a].l.cs = priv.ldf_test_label;
775       priv.base_learner->predict(ecs[a], policy);
776 
777       priv.empty_example->in_use = true;
778       priv.base_learner->predict(*priv.empty_example);
779 
780       //cerr << "partial_prediction[" << a << "] = " << ecs[a].partial_prediction << endl;
781 
782       if ((a == start_K) || (ecs[a].partial_prediction < best_prediction)) {
783         best_prediction = ecs[a].partial_prediction;
784         best_action     = a;
785       }
786 
787       priv.num_features += ecs[a].num_features;
788       ecs[a].l = old_label;
789       if (start_K > 0)
790         LabelDict::del_example_namespaces_from_example(ecs[a], ecs[0]);
791     }
792 
793     if (priv.beam) {
794       priv.beam_total_cost += best_prediction;
795       size_t new_len = priv.current_trajectory.size() + 1;
796       for (size_t k=start_K; k<ec_cnt; k++) {
797         if (k == best_action) continue;
798         float delta_cost = ecs[k].partial_prediction - best_prediction + priv.beam_initial_cost;
799         action_prefix* px = new v_array<action>;
800 	*px = v_init<action>();
801         px->resize(new_len + 1);
802         px->end = px->begin + new_len + 1;
803         memcpy(px->begin, priv.current_trajectory.begin, sizeof(action) * (new_len-1));
804         px->begin[new_len-1] = (uint32_t)k;  // TODO: k or ld[k]?
805         *((float*)(px->begin+new_len)) = delta_cost + best_prediction;
806         uint32_t px_hash = uniform_hash(px->begin, sizeof(action) * new_len, 3419);
807         if (! priv.beam->insert(px, delta_cost, px_hash)) {
808           px->delete_v();  // SPEEDUP: could be more efficient by reusing for next action
809           delete px;
810         }
811       }
812     }
813 
814     priv.total_predictions_made++;
815     return best_action;
816   }
817 
choose_policy(search_private & priv,bool advance_prng=true)818   int choose_policy(search_private& priv, bool advance_prng=true) {
819     RollMethod method = (priv.state == INIT_TEST ) ? POLICY :
820                         (priv.state == LEARN     ) ? priv.rollout_method :
821                         (priv.state == INIT_TRAIN) ? priv.rollin_method :
822                         NO_ROLLOUT;   // this should never happen
823     switch (method) {
824       case POLICY:
825         return random_policy(priv, priv.allow_current_policy || (priv.state == INIT_TEST), false, advance_prng);
826 
827       case ORACLE:
828         return -1;
829 
830       case MIX_PER_STATE:
831         return random_policy(priv, priv.allow_current_policy, true, advance_prng);
832 
833       case MIX_PER_ROLL:
834         if (priv.mix_per_roll_policy == -2) // then we have to choose one!
835           priv.mix_per_roll_policy = random_policy(priv, priv.allow_current_policy, true, advance_prng);
836         return priv.mix_per_roll_policy;
837 
838     case NO_ROLLOUT:
839     default:
840         std::cerr << "internal error (bug): trying to rollin or rollout with NO_ROLLOUT" << endl;
841         throw exception();
842     }
843   }
844 
cdbg_print_array(string str,v_array<T> & A)845   template<class T> void cdbg_print_array(string str, v_array<T>& A) { cdbg << str << " = ["; for (size_t i=0; i<A.size(); i++) cdbg << " " << A[i]; cdbg << " ]" << endl; }
cerr_print_array(string str,v_array<T> & A)846   template<class T> void cerr_print_array(string str, v_array<T>& A) { cerr << str << " = ["; for (size_t i=0; i<A.size(); i++) cerr << " " << A[i]; cerr << " ]" << endl; }
847 
848   template<class T>
ensure_size(v_array<T> & A,size_t sz)849   void ensure_size(v_array<T>& A, size_t sz) {
850     if ((size_t)(A.end_array - A.begin) < sz)
851       A.resize(sz*2+1, true);
852     A.end = A.begin + sz;
853   }
854 
push_at(v_array<T> & v,T item,size_t pos)855   template<class T> void push_at(v_array<T>& v, T item, size_t pos) {
856     if (v.size() > pos)
857       v.begin[pos] = item;
858     else {
859       if (v.end_array > v.begin + pos) {
860         // there's enough memory, just not enough filler
861         v.begin[pos] = item;
862         v.end = v.begin + pos + 1;
863       } else {
864         // there's not enough memory
865         v.resize(2 * pos + 3, true);
866         v.begin[pos] = item;
867         v.end = v.begin + pos + 1;
868       }
869     }
870   }
871 
record_action(search_private & priv,ptag mytag,action a)872   void record_action(search_private& priv, ptag mytag, action a) {
873     if (mytag == 0) return;
874     push_at(priv.ptag_to_action, a, mytag);
875   }
876 
cached_item_equivalent(unsigned char * & A,unsigned char * & B)877   bool cached_item_equivalent(unsigned char*& A, unsigned char*& B) {
878     size_t sz_A = *A;
879     size_t sz_B = *B;
880     if (sz_A != sz_B) return false;
881     return memcmp(A, B, sz_A) == 0;
882   }
883 
free_key(unsigned char * mem,action a)884   void free_key(unsigned char* mem, action a) { free(mem); }
clear_cache_hash_map(search_private & priv)885   void clear_cache_hash_map(search_private& priv) {
886     priv.cache_hash_map.iter(free_key);
887     priv.cache_hash_map.clear();
888   }
889 
890   // returns true if found and do_store is false. if do_store is true, always returns true.
cached_action_store_or_find(search_private & priv,ptag mytag,const ptag * condition_on,const char * condition_on_names,const action * condition_on_actions,size_t condition_on_cnt,int policy,size_t learner_id,action & a,bool do_store)891   bool cached_action_store_or_find(search_private& priv, ptag mytag, const ptag* condition_on, const char* condition_on_names, const action* condition_on_actions, size_t condition_on_cnt, int policy, size_t learner_id, action &a, bool do_store) {
892     if (priv.no_caching) return do_store;
893     if (mytag == 0) return do_store; // don't attempt to cache when tag is zero
894 
895     size_t sz  = sizeof(size_t) + sizeof(ptag) + sizeof(int) + sizeof(size_t) + sizeof(size_t) + condition_on_cnt * (sizeof(ptag) + sizeof(action) + sizeof(char));
896     if (sz % 4 != 0) sz = 4 * (sz / 4 + 1); // make sure sz aligns to 4 so that uniform_hash does the right thing
897 
898     unsigned char* item = calloc_or_die<unsigned char>(sz);
899     unsigned char* here = item;
900     *here = (unsigned char)sz; here += sizeof(size_t);
901     *here = mytag;             here += sizeof(ptag);
902     *here = policy;            here += sizeof(int);
903     *here = (unsigned char)learner_id;        here += sizeof(size_t);
904     *here = (unsigned char)condition_on_cnt;  here += (unsigned char)sizeof(size_t);
905     for (size_t i=0; i<condition_on_cnt; i++) {
906       *here = condition_on[i];         here += sizeof(ptag);
907       *here = condition_on_actions[i]; here += sizeof(action);
908       *here = condition_on_names[i];   here += sizeof(char);  // SPEEDUP: should we align this at 4?
909     }
910     uint32_t hash = uniform_hash(item, sz, 3419);
911 
912     if (do_store) {
913       priv.cache_hash_map.put(item, hash, a);
914       return true;
915     } else { // its a find
916       a = priv.cache_hash_map.get(item, hash);
917       free(item);
918       return a != (action)-1;
919     }
920   }
921 
generate_training_example(search_private & priv,v_array<float> & losses,bool add_conditioning=true)922   void generate_training_example(search_private& priv, v_array<float>& losses, bool add_conditioning=true) {
923     // should we really subtract out min-loss?
924     float min_loss = FLT_MAX, max_loss = -FLT_MAX;
925     size_t num_min = 0;
926     for (size_t i=0; i<losses.size(); i++) {
927       if (losses[i] < min_loss) { min_loss = losses[i]; num_min = 1; }
928       else if (losses[i] == min_loss) num_min++;
929       if (losses[i] > max_loss) { max_loss = losses[i]; }
930     }
931 
932     int learner = select_learner(priv, priv.current_policy, priv.learn_learner_id);
933 
934     if (!priv.is_ldf) {   // not LDF
935       // since we're not LDF, it should be the case that ec_ref_cnt == 1
936       // and learn_ec_ref[0] is a pointer to a single example
937       assert(priv.learn_ec_ref_cnt == 1);
938       assert(priv.learn_ec_ref != NULL);
939 
940       polylabel labels = allowed_actions_to_ld(priv, priv.learn_ec_ref_cnt, priv.learn_allowed_actions.begin, priv.learn_allowed_actions.size());
941       cdbg_print_array("learn_allowed_actions", priv.learn_allowed_actions);
942       //bool any_gt_1 = false;
943       for (size_t i=0; i<losses.size(); i++) {
944         losses[i] = losses[i] - min_loss;  // TODO: in BEAM mode, subtracting off min_loss seems like a bad idea
945         if (priv.cb_learner) labels.cb.costs[i].cost = losses[i];
946         else                 labels.cs.costs[i].x    = losses[i];
947       }
948 
949       example& ec = priv.learn_ec_ref[0];
950       polylabel old_label = ec.l;
951       ec.l = labels;
952       ec.in_use = true;
953       if (add_conditioning) add_example_conditioning(priv, ec, priv.learn_condition_on.begin, priv.learn_condition_on.size(), priv.learn_condition_on_names.begin, priv.learn_condition_on_act.begin);
954       priv.base_learner->learn(ec, learner);
955       if (add_conditioning) del_example_conditioning(priv, ec);
956       ec.l = old_label;
957       priv.total_examples_generated++;
958     } else {              // is  LDF
959       assert(losses.size() == priv.learn_ec_ref_cnt);
960       size_t start_K = (priv.is_ldf && LabelDict::ec_is_example_header(priv.learn_ec_ref[0])) ? 1 : 0;
961       for (action a= (uint32_t)start_K; a<priv.learn_ec_ref_cnt; a++) {
962         example& ec = priv.learn_ec_ref[a];
963 
964         CS::label& lab = ec.l.cs;
965         if (lab.costs.size() == 0) {
966           CS::wclass wc = { 0., 1, 0., 0. };
967           lab.costs.push_back(wc);
968         }
969         lab.costs[0].x = losses[a] - min_loss;
970         //cerr << "cost[" << a << "] = " << losses[a] << " - " << min_loss << " = " << lab.costs[0].x << endl;
971         ec.in_use = true;
972         if (add_conditioning) add_example_conditioning(priv, ec, priv.learn_condition_on.begin, priv.learn_condition_on.size(), priv.learn_condition_on_names.begin, priv.learn_condition_on_act.begin);
973         priv.base_learner->learn(ec, learner);
974         cdbg << "generate_training_example called learn on action a=" << a << ", costs.size=" << lab.costs.size() << " ec=" << &ec << endl;
975         priv.total_examples_generated++;
976       }
977       priv.base_learner->learn(*priv.empty_example, learner);
978       cdbg << "generate_training_example called learn on empty_example" << endl;
979 
980       for (action a= (uint32_t)start_K; a<priv.learn_ec_ref_cnt; a++) {
981         example& ec = priv.learn_ec_ref[a];
982         if (add_conditioning)
983           del_example_conditioning(priv, ec);
984       }
985     }
986   }
987 
search_predictNeedsExample(search_private & priv)988   bool search_predictNeedsExample(search_private& priv) {
989     // this is basically copied from the logic of search_predict()
990     switch (priv.state) {
991       case INITIALIZE: return false;
992       case GET_TRUTH_STRING: return false;
993       case INIT_TEST:
994         if (priv.beam && (priv.t < priv.beam_actions.size()))
995           return false;
996         return true;
997       case INIT_TRAIN:
998         if (priv.beam && (priv.t < priv.beam_actions.size()))
999           return false;
1000         break;
1001       case LEARN:
1002         if (priv.t < priv.learn_t) return false;
1003         if (priv.t == priv.learn_t) return true;  // SPEEDUP: we really only need it on the last learn_a, but this is hard to know...
1004         // t > priv.learn_t
1005         if ((priv.rollout_num_steps > 0) && (priv.loss_declared_cnt >= priv.rollout_num_steps)) return false; // skipping
1006         break;
1007     }
1008 
1009     int pol = choose_policy(priv, false); // choose a policy but don't advance prng
1010     return (pol != -1);
1011   }
1012 
  // Main per-step prediction routine used by search tasks; behavior depends on priv.state.
  //   ecs, ec_cnt:            example(s) for this step
  //   mytag:                  tag of this prediction; part of the cache key (0 disables caching)
  //   oracle_actions(_cnt):   reference action(s); NULL/0 if none
  //   condition_on, condition_on_names: ptags of earlier predictions to condition on, with one
  //                           name character each (count = strlen(condition_on_names))
  //   allowed_actions(_cnt):  candidate actions for non-LDF; NULL/0 means all priv.A actions
  //   learner_id:             passed to select_learner / stored in priv.learn_learner_id
  // Returns the action taken at this step. priv.t is advanced on every call.
  //   GET_TRUTH_STRING                      -> oracle action
  //   LEARN, t <  learn_t                   -> replay priv.train_trajectory[t]
  //   INIT_TEST/INIT_TRAIN, t in beam prefix -> priv.beam_actions[t]
  //   LEARN, t == learn_t                   -> the learn_a_idx-th candidate; on the last candidate,
  //                                            snapshot examples/conditioning/allowed actions for
  //                                            generate_training_example
  //   LEARN, t >  learn_t, rollout budget spent -> a placeholder action
  //   INIT_TRAIN/INIT_TEST/LEARN(t > learn_t) -> run the chosen policy (oracle or learned, cached)
  //   anything else                         -> prints an error and throws
  // note: ec_cnt should be 1 if we are not LDF
  action search_predict(search_private& priv, example* ecs, size_t ec_cnt, ptag mytag, const action* oracle_actions, size_t oracle_actions_cnt, const ptag* condition_on, const char* condition_on_names, const action* allowed_actions, size_t allowed_actions_cnt, size_t learner_id) {
    size_t condition_on_cnt = condition_on_names ? strlen(condition_on_names) : 0;
    size_t t = priv.t;
    priv.t++;

    // make sure parameters come in pairs correctly
    assert((oracle_actions  == NULL) == (oracle_actions_cnt  == 0));
    assert((condition_on    == NULL) == (condition_on_names  == NULL));
    assert((allowed_actions == NULL) == (allowed_actions_cnt == 0));

    // if we're just after the string, choose an oracle action
    if (priv.state == GET_TRUTH_STRING)
      return choose_oracle_action(priv, ec_cnt, oracle_actions, oracle_actions_cnt, allowed_actions, allowed_actions_cnt, false);

    // if we're in LEARN mode and before learn_t, return the train action
    if ((priv.state == LEARN) && (t < priv.learn_t)) {
      assert(t < priv.train_trajectory.size());
      return priv.train_trajectory[t];
    }

    // replay a beam prefix, if one covers this step
    if (priv.beam && (t < priv.beam_actions.size()) && ((priv.state == INIT_TEST) || (priv.state == INIT_TRAIN)))
      return priv.beam_actions[t];

    // for LDF, # of valid actions is ec_cnt; otherwise it's either allowed_actions_cnt or A
    size_t valid_action_cnt = priv.is_ldf ? ec_cnt :
                              (allowed_actions_cnt > 0) ? allowed_actions_cnt : priv.A;

    // if we're in LEARN mode and _at_ learn_t, then:
    //   - choose the next action
    //   - decide if we're done
    //   - if we are, then copy/mark the example ref
    if ((priv.state == LEARN) && (t == priv.learn_t)) {
      action a = (action)priv.learn_a_idx;
      priv.loss_declared_cnt = 0;

      priv.learn_a_idx++;
      priv.learn_loss = 0.;  // don't include "past cost"

      // check to see if we're done with available actions
      if (priv.learn_a_idx >= valid_action_cnt) {
        priv.done_with_all_actions = true;
        priv.learn_learner_id = learner_id;

        // set reference or copy example(s)
        if (oracle_actions_cnt > 0) priv.learn_oracle_action = oracle_actions[0];
        priv.learn_ec_ref_cnt = ec_cnt;
        if (priv.examples_dont_change)
          priv.learn_ec_ref = ecs;
        else {
          // the task may modify its examples later, so keep private copies
          size_t label_size = priv.is_ldf ? sizeof(CS::label) : sizeof(MC::label_t);
          void (*label_copy_fn)(void*,void*) = priv.is_ldf ? CS::cs_label.copy_label : NULL;

          ensure_size(priv.learn_ec_copy, ec_cnt);
          for (size_t i=0; i<ec_cnt; i++)
            VW::copy_example_data(priv.all->audit, priv.learn_ec_copy.begin+i, ecs+i, label_size, label_copy_fn);

          priv.learn_ec_ref = priv.learn_ec_copy.begin;
        }

        // copy conditioning stuff and allowed actions
        if (priv.auto_condition_features) {
          ensure_size(priv.learn_condition_on,     condition_on_cnt);
          ensure_size(priv.learn_condition_on_act, condition_on_cnt);

          priv.learn_condition_on.end = priv.learn_condition_on.begin + condition_on_cnt;   // allow .size() to be used in lieu of _cnt

          memcpy(priv.learn_condition_on.begin, condition_on, condition_on_cnt * sizeof(ptag));

          // action previously recorded for each conditioning tag (0 if the tag is out of range)
          for (size_t i=0; i<condition_on_cnt; i++)
            push_at(priv.learn_condition_on_act, ((1 <= condition_on[i]) && (condition_on[i] < priv.ptag_to_action.size())) ? priv.ptag_to_action[condition_on[i]] : 0, i);

          if (condition_on_names == NULL) {
            ensure_size(priv.learn_condition_on_names, 1);
            priv.learn_condition_on_names[0] = 0;
          } else {
            ensure_size(priv.learn_condition_on_names, strlen(condition_on_names)+1);
            strcpy(priv.learn_condition_on_names.begin, condition_on_names);
          }
        }

        ensure_size(priv.learn_allowed_actions, allowed_actions_cnt);
        memcpy(priv.learn_allowed_actions.begin, allowed_actions, allowed_actions_cnt*sizeof(action));
        cdbg_print_array("in LEARN, learn_allowed_actions", priv.learn_allowed_actions);
      }

      // map the candidate index to an action id (non-LDF ids without allowed_actions are 1-based)
      assert((allowed_actions_cnt == 0) || (a < allowed_actions_cnt));
      return (allowed_actions_cnt > 0) ? allowed_actions[a] : priv.is_ldf ? a : (a+1);
    }

    // rollout step budget exhausted: return a cheap placeholder without consulting any policy
    if ((priv.state == LEARN) && (t > priv.learn_t) && (priv.rollout_num_steps > 0) && (priv.loss_declared_cnt >= priv.rollout_num_steps)) {
      cdbg << "... skipping" << endl;
      if (priv.is_ldf) return 0;
      else if (allowed_actions_cnt > 0) return allowed_actions[0];
      else return 1;
    }


    if ((priv.state == INIT_TRAIN) ||
        (priv.state == INIT_TEST) ||
        ((priv.state == LEARN) && (t > priv.learn_t))) {
      // we actually need to run the policy

      int policy = choose_policy(priv);
      action a;

      cdbg << "executing policy " << policy << endl;

      // "generate training example here": roll-in without rollouts and with known oracle actions
      bool gte_here = (priv.state == INIT_TRAIN) && (priv.rollout_method == NO_ROLLOUT) && (oracle_actions_cnt > 0);

      if (policy == -1)
        a = choose_oracle_action(priv, ec_cnt, oracle_actions, oracle_actions_cnt, allowed_actions, allowed_actions_cnt, priv.beam && (priv.state != INIT_TEST));

      if ((policy >= 0) || gte_here) {
        int learner = select_learner(priv, policy, learner_id);

        // actions previously recorded for each conditioning tag (0 if the tag is out of range)
        ensure_size(priv.condition_on_actions, condition_on_cnt);
        for (size_t i=0; i<condition_on_cnt; i++)
          priv.condition_on_actions[i] = ((1 <= condition_on[i]) && (condition_on[i] < priv.ptag_to_action.size())) ? priv.ptag_to_action[condition_on[i]] : 0;

        if (cached_action_store_or_find(priv, mytag, condition_on, condition_on_names, priv.condition_on_actions.begin, condition_on_cnt, policy, learner_id, a, false))
          // if this succeeded, 'a' has the right action
          priv.total_cache_hits++;
        else { // we need to predict, and then cache
          size_t start_K = (priv.is_ldf && LabelDict::ec_is_example_header(ecs[0])) ? 1 : 0;
          if (priv.auto_condition_features)
            for (size_t n=start_K; n<ec_cnt; n++)
              add_example_conditioning(priv, ecs[n], condition_on, condition_on_cnt, condition_on_names, priv.condition_on_actions.begin);

          if (policy >= 0)   // only make a prediction if we're going to use the output
            a = priv.is_ldf ? single_prediction_LDF(priv, ecs, ec_cnt, learner)
                            : single_prediction_notLDF(priv, *ecs, learner, allowed_actions, allowed_actions_cnt);

          if (gte_here) {
            cdbg << "INIT_TRAIN, NO_ROLLOUT, at least one oracle_actions" << endl;
            // we can generate a training example _NOW_ because we're not doing rollouts
            v_array<float> losses = v_init<float>(); // SPEEDUP: move this to data structure
            allowed_actions_to_losses(priv, ec_cnt, allowed_actions, allowed_actions_cnt, oracle_actions, oracle_actions_cnt, losses);
            cdbg_print_array("losses", losses);
            priv.learn_ec_ref = ecs;
            priv.learn_ec_ref_cnt = ec_cnt;
            ensure_size(priv.learn_allowed_actions, allowed_actions_cnt);
            memcpy(priv.learn_allowed_actions.begin, allowed_actions, allowed_actions_cnt * sizeof(action));
            generate_training_example(priv, losses, false);
            losses.delete_v();
          }

          if (priv.auto_condition_features)
            for (size_t n=start_K; n<ec_cnt; n++)
              del_example_conditioning(priv, ecs[n]);

          cached_action_store_or_find(priv, mytag, condition_on, condition_on_names, priv.condition_on_actions.begin, condition_on_cnt, policy, learner_id, a, true);
        }
      }

      if (priv.state == INIT_TRAIN)
        priv.train_trajectory.push_back(a); // note the action for future reference

      return a;
    }

    std::cerr << "error: predict called in unknown state" << endl;
    throw exception();
  }
1177 
cmp_size_t(const size_t a,const size_t b)1178   inline bool cmp_size_t(const size_t a, const size_t b) { return a < b; }
cmp_size_t_pair(const pair<size_t,size_t> & a,const pair<size_t,size_t> & b)1179   inline bool cmp_size_t_pair(const pair<size_t,size_t>& a, const pair<size_t,size_t>& b) { return ((a.first == b.first) && (a.second < b.second)) || (a.first < b.first); }
get_training_timesteps(search_private & priv,v_array<pair<size_t,size_t>> & timesteps)1180   void get_training_timesteps(search_private& priv, v_array< pair<size_t,size_t> >& timesteps) {  // timesteps are pairs of (beam elem, t) where beam elem == 0 means "default" for non-beam search
1181     timesteps.erase();
1182 
1183     // if there's no subsampling to do, just return [0,T)
1184     if (priv.subsample_timesteps <= 0)
1185       for (size_t t=0; t<priv.T; t++)
1186         timesteps.push_back(pair<size_t,size_t>(0,t));
1187 
1188     // if subsample in (0,1) then pick steps with that probability, but ensuring there's at least one!
1189     else if (priv.subsample_timesteps < 1) {
1190       for (size_t t=0; t<priv.T; t++)
1191         if (frand48() <= priv.subsample_timesteps)
1192           timesteps.push_back(pair<size_t,size_t>(0,t));
1193 
1194       if (timesteps.size() == 0) // ensure at least one
1195         timesteps.push_back(pair<size_t,size_t>(0,(size_t)(frand48() * priv.T)));
1196     }
1197 
1198     // finally, if subsample >= 1, then pick (int) that many uniformly at random without replacement; could use an LFSR but why? :P
1199     else {
1200       while ((timesteps.size() < (size_t)priv.subsample_timesteps) &&
1201              (timesteps.size() < priv.T)) {
1202         size_t t = (size_t)(frand48() * (float)priv.T);
1203         if (! v_array_contains(timesteps, pair<size_t,size_t>(0,t)))
1204           timesteps.push_back(pair<size_t,size_t>(0,t));
1205       }
1206       std::sort(timesteps.begin, timesteps.end, cmp_size_t_pair);
1207     }
1208   }
1209 
get_training_timesteps_beam(search_private & priv,Beam::beam<pair<action_prefix *,string>> & final_beam,v_array<pair<size_t,size_t>> & timesteps)1210   void get_training_timesteps_beam(search_private& priv, Beam::beam< pair<action_prefix*, string> >& final_beam, v_array< pair<size_t,size_t> >& timesteps) {
1211     timesteps.erase();
1212     if (priv.subsample_timesteps <= 0) {
1213       size_t id = 0;
1214       for (Beam::beam_element< pair<action_prefix*, string> >* item = final_beam.begin(); item != final_beam.end(); ++item) {
1215         //cerr << id << " " << item->active << " " << item->data->first->size() << endl;
1216         if (item->active)
1217           for (size_t t=0; t<item->data->first->size(); t++)
1218             timesteps.push_back( pair<size_t,size_t>( id, t ) );
1219         id ++;
1220       }
1221     } else {
1222       std::cerr << "error: cannot do subsampling of timesteps with beam search yet!" << endl;
1223       throw exception();
1224     }
1225   }
1226 
free_action_prefix(action_prefix * px)1227   void free_action_prefix(action_prefix* px) {
1228     px->delete_v();
1229     delete px;
1230   }
1231 
free_action_prefix_string_pair(pair<action_prefix *,string> * p)1232   void free_action_prefix_string_pair(pair<action_prefix*,string>* p) {
1233     p->first->delete_v();
1234     delete p->first;
1235     delete p;
1236   }
1237 
final_beam_insert(search_private & priv,Beam::beam<pair<action_prefix *,string>> & beam,float cost,SearchState state)1238   void final_beam_insert(search_private&priv, Beam::beam< pair<action_prefix*,string> >& beam, float cost, SearchState state) {
1239     action_prefix* final = new action_prefix;  // TODO: can we memcpy/push_many?
1240     *final = v_init<action>();
1241     //cerr << "final_beam_insert: cost=" << cost << ", len=" << ((state == INIT_TEST) ? priv.test_action_sequence.size() : priv.train_trajectory.size()) << endl;
1242     if (state == INIT_TEST)
1243       for (size_t i=0; i<priv.test_action_sequence.size(); i++) final->push_back(priv.test_action_sequence[i]);
1244     else if (state == INIT_TRAIN)
1245       for (size_t i=0; i<priv.train_trajectory.size(); i++) final->push_back(priv.train_trajectory[i]);
1246     //cerr << "  --> ["; for (size_t i=0; i<final->size(); i++) cerr << " " << final->get(i); cerr << " ]" << endl;
1247     pair<action_prefix*,string>* p = priv.should_produce_string ? new pair<action_prefix*,string>(final, priv.pred_string->str()) : new pair<action_prefix*,string>(final, "");
1248     uint32_t final_hash = uniform_hash(final->begin, sizeof(action)*final->size(), 3419);
1249     if (!beam.insert(p, cost, final_hash)) {
1250       final->delete_v();
1251       delete final;
1252       delete p;
1253     }
1254   }
1255 
beam_predict(search & sch,SearchState state)1256   Beam::beam< pair<action_prefix*, string> >* beam_predict(search& sch, SearchState state) {
1257     search_private& priv = *sch.priv;
1258     vw&all = *priv.all;
1259     bool old_no_caching = priv.no_caching;   // caching is incompatible with generating beam rollouts
1260     priv.no_caching = true;
1261 
1262     priv.beam->erase(free_action_prefix);
1263     clear_cache_hash_map(priv);
1264 
1265     reset_search_structure(priv);
1266     priv.beam_actions.erase();
1267     priv.state = state;
1268     priv.should_produce_string = (state == INIT_TEST) && (might_print_update(all) || (all.final_prediction_sink.size() > 0) || (all.raw_prediction > 0));
1269 
1270     // do the initial prediction
1271     priv.beam_initial_cost = 0.;
1272     priv.beam_total_cost = 0.;
1273     if      (state == INIT_TEST)  priv.test_action_sequence.clear();
1274     else if (state == INIT_TRAIN) priv.train_trajectory.erase();
1275     if (priv.should_produce_string) priv.pred_string->str("");
1276     priv.task->run(sch, priv.ec_seq);
1277     if (all.raw_prediction > 0) all.print_text(all.raw_prediction, "end of initial beam prediction", priv.ec_seq[0]->tag);
1278 
1279     size_t final_size = (state == INIT_TEST) ? max(1, priv.kbest) : max(1, priv.beam->get_beam_size()); // at training time, use beam size
1280     //cerr << "final_size = " << final_size << endl;
1281 
1282     Beam::beam< pair<action_prefix*, string> >* final_beam = new Beam::beam< pair<action_prefix*, string> >(final_size);
1283     final_beam_insert(priv, *final_beam, priv.beam_total_cost, state);
1284 
1285     for (size_t beam_run=1; beam_run<priv.beam->get_beam_size(); beam_run++) {
1286       //cerr << "beam_run=" << beam_run << endl;
1287       priv.beam->compact(free_action_prefix);
1288       Beam::beam_element<action_prefix>* item = priv.beam->pop_best_item();
1289       if (item != NULL) {
1290         reset_search_structure(priv);
1291         priv.beam_actions.erase();
1292         priv.state = state;
1293         priv.should_produce_string = (state == INIT_TEST) && (might_print_update(all) || (all.final_prediction_sink.size() > 0) || (all.raw_prediction > 0));
1294         if (priv.should_produce_string) priv.pred_string->str("");
1295         priv.beam_initial_cost = *((float*)(item->data->begin+item->data->size()-1));
1296         priv.beam_total_cost   = priv.beam_initial_cost;
1297         push_many(priv.beam_actions, item->data->begin, item->data->size() - 1);
1298         if      (state == INIT_TEST)  priv.test_action_sequence.clear();
1299         else if (state == INIT_TRAIN) priv.train_trajectory.erase();
1300         priv.task->run(sch, priv.ec_seq);
1301         if (all.raw_prediction > 0) all.print_text(all.raw_prediction, "end of next beam prediction", priv.ec_seq[0]->tag);
1302         final_beam_insert(priv, *final_beam, priv.beam_total_cost, state);
1303       }
1304     }
1305 
1306     final_beam->compact(free_action_prefix_string_pair);
1307     Beam::beam_element< pair<action_prefix*,string> >* best = final_beam->begin();
1308     while ((best != final_beam->end()) && !best->active) ++best;
1309     if (best != final_beam->end()) {
1310       // store in beam_actions the actions for this so that subsequent calls to ->_run() produce it!
1311       priv.beam_actions.erase();
1312       push_many(priv.beam_actions, best->data->first->begin, best->data->first->size());
1313     }
1314 
1315     // TODO: only if test?
1316     if (all.final_prediction_sink.begin != all.final_prediction_sink.end) {  // need to produce prediction output
1317       v_array<char> new_tag = v_init<char>();
1318       for (; best != final_beam->end(); ++best)
1319         if (best->active) {
1320           new_tag.erase();
1321           new_tag.resize(50, true);
1322           int len = sprintf(new_tag.begin, "%-10.6f\t", best->cost);
1323           new_tag.end = new_tag.begin + len;
1324           push_many(new_tag, priv.ec_seq[0]->tag.begin, priv.ec_seq[0]->tag.size());
1325           for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; ++sink)
1326             all.print_text((int)*sink, best->data->second, new_tag);
1327         }
1328       for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; ++sink)
1329         all.print_text((int)*sink, "", priv.ec_seq[0]->tag);
1330       new_tag.delete_v();
1331     }
1332 
1333     priv.no_caching = old_no_caching;
1334     return final_beam;
1335   }
1336 
  // Processes one example sequence in up to three phases:
  //  1. INIT TEST (when must_run_test says so): predict, accumulate loss for
  //     labeled sequences, and write predictions (beam search prints its own
  //     k-best output inside beam_predict).
  //  2. INIT TRAIN: roll in once, storing the trajectory in priv.train_trajectory
  //     (with beam search, beam_predict first seeds priv.beam_actions with its
  //     best sequence).
  //  3. LEARN: for each selected timestep, roll out every action to collect a
  //     loss per action, then build one training example from those losses.
  // Phases 2-3 are skipped when not learning, on test examples, or when the
  // learner is not training. Example-count statistics are updated even when
  // phase 1 is skipped.
  template <bool is_learn>
  void train_single_example(search& sch, bool is_test_ex) {
    search_private& priv = *sch.priv;
    vw&all = *priv.all;
    bool ran_test = false;  // we must keep track so that even if we skip test, we still update # of examples seen

    clear_cache_hash_map(priv);
    Beam::beam< pair<action_prefix*, string> >* final_beam = NULL;

    // do an initial test pass to compute output (and loss)
    if (must_run_test(all, priv.ec_seq, is_test_ex)) {
      cdbg << "======================================== INIT TEST (" << priv.current_policy << "," << priv.read_example_last_pass << ") ========================================" << endl;

      ran_test = true;

      if (priv.beam) {
        // beam_predict prints the k-best list and leaves the best action
        // sequence in priv.beam_actions; the final beam itself is not needed here
        final_beam = beam_predict(sch, INIT_TEST);
        final_beam->erase(free_action_prefix_string_pair);
        delete final_beam;
      }

      // SPEEDUP: in the case of beam, we could skip some of this...
      // do the prediction
      reset_search_structure(priv);
      priv.state = INIT_TEST;
      priv.should_produce_string = might_print_update(all) || (all.final_prediction_sink.size() > 0) || (all.raw_prediction > 0);
      priv.pred_string->str("");
      priv.test_action_sequence.clear();
      priv.task->run(sch, priv.ec_seq);

      // accumulate loss
      if (! is_test_ex)
	all.sd->update(priv.ec_seq[0]->test_only, priv.test_loss, 1.f, priv.num_features);

      // generate output
      if (!priv.beam)
        for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; ++sink)
          all.print_text((int)*sink, priv.pred_string->str(), priv.ec_seq[0]->tag);

      if (all.raw_prediction > 0)
        all.print_text(all.raw_prediction, "", priv.ec_seq[0]->tag);
    }

    // if we're not training, then we're done!
    if ((!is_learn) || is_test_ex || priv.ec_seq[0]->test_only || (!priv.all->training))
      return;

    // SPEEDUP: if the oracle was never called, we can skip this!

    // do a pass over the data allowing oracle
    cdbg << "======================================== INIT TRAIN (" << priv.current_policy << "," << priv.read_example_last_pass << ") ========================================" << endl;
    //cerr << "training" << endl;
    final_beam = NULL;
    if (priv.beam)
      final_beam = beam_predict(sch, INIT_TRAIN);

    reset_search_structure(priv);
    priv.state = INIT_TRAIN;
    priv.train_trajectory.erase();  // this is where we'll store the training sequence
    priv.task->run(sch, priv.ec_seq);

    // the test pass did not run, so account for this example's statistics here
    if (!ran_test) {  // was  && !priv.ec_seq[0]->test_only) { but we know it's not test_only
      all.sd->weighted_examples += 1.f;
      all.sd->total_features += priv.num_features;
      all.sd->sum_loss += priv.test_loss;
      all.sd->sum_loss_since_last_dump += priv.test_loss;
      all.sd->example_number++;
    }

    // if there's nothing to train on, we're done!
    if ((priv.loss_declared_cnt == 0) || (priv.t == 0) || (priv.rollout_method == NO_ROLLOUT)) {// TODO: make sure NO_ROLLOUT works with beam!
      if (priv.beam) {
        final_beam->erase(free_action_prefix_string_pair);
        delete final_beam;
      }
      return;
    }

    // otherwise, we have some learn'in to do!
    cdbg << "======================================== LEARN (" << priv.current_policy << "," << priv.read_example_last_pass << ") ========================================" << endl;
    priv.T = priv.t;
    if (priv.beam) get_training_timesteps_beam(priv, *final_beam, priv.timesteps);
    else           get_training_timesteps(priv, priv.timesteps);
    priv.learn_losses.erase();
    //cdbg_print_array("timesteps", priv.timesteps);
    // NOTE(review): last_beam_id is never updated inside the loop, so every
    // timestep whose beam id is non-zero re-copies that entry's trajectory, and
    // beam id 0 relies on the trajectory left by the INIT TRAIN run above.
    // Without beam search all ids are 0, so final_beam (NULL) is never touched.
    size_t last_beam_id = 0;
    for (size_t tid=0; tid<priv.timesteps.size(); tid++) {
      size_t bid = priv.timesteps[tid].first;
      //cerr << "timestep = " << priv.timesteps[tid].first << "." << priv.timesteps[tid].second << " [" << tid << "/" << priv.timesteps.size() << "]" << endl;
      if (bid != last_beam_id) {
        priv.train_trajectory.erase();
        push_many(priv.train_trajectory, final_beam->begin()[bid].data->first->begin, final_beam->begin()[bid].data->first->size());
      }

      priv.learn_a_idx = 0;
      priv.done_with_all_actions = false;
      // for each action, roll out to get a loss
      while (! priv.done_with_all_actions) {
        reset_search_structure(priv);
        priv.beam_actions.erase();
        priv.state = LEARN;
        priv.learn_t = priv.timesteps[tid].second;
        cdbg << "learn_t = " << priv.learn_t << ", learn_a_idx = " << priv.learn_a_idx << endl;
        priv.task->run(sch, priv.ec_seq);
        priv.learn_losses.push_back( priv.learn_loss );  // SPEEDUP: should we just put this in a CS structure from the get-go?
        cdbg_print_array("learn_losses", priv.learn_losses);
      }
      // now we can make a training example
      generate_training_example(priv, priv.learn_losses);
      // release the labels of the per-timestep example copies (only made when
      // examples may change between runs)
      if (! priv.examples_dont_change)
        for (size_t n=0; n<priv.learn_ec_copy.size(); n++) {
          if (sch.priv->is_ldf) CS::cs_label.delete_label(&priv.learn_ec_copy[n].l.cs);
          else                  MC::mc_label.delete_label(&priv.learn_ec_copy[n].l.multi);
        }
      priv.learn_losses.erase();
    }
    if (priv.beam) {
      final_beam->erase(free_action_prefix_string_pair);
      delete final_beam;
    }
  }
1458 
1459 
1460   template <bool is_learn>
do_actual_learning(vw & all,search & sch)1461   void do_actual_learning(vw&all, search& sch) {
1462     search_private& priv = *sch.priv;
1463 
1464     if (priv.ec_seq.size() == 0)
1465       return;  // nothing to do :)
1466 
1467     bool is_test_ex = false;
1468     for (size_t i=0; i<priv.ec_seq.size(); i++)
1469       if (priv.label_is_test(&priv.ec_seq[i]->l)) { is_test_ex = true; break; }
1470 
1471     if (priv.task->run_setup) priv.task->run_setup(sch, priv.ec_seq);
1472 
1473     // if we're going to have to print to the screen, generate the "truth" string
1474     cdbg << "======================================== GET TRUTH STRING (" << priv.current_policy << "," << priv.read_example_last_pass << ") ========================================" << endl;
1475     if (might_print_update(all)) {
1476       if (is_test_ex)
1477         priv.truth_string->str("**test**");
1478       else {
1479         reset_search_structure(*sch.priv);
1480         priv.beam_actions.erase();
1481         priv.state = GET_TRUTH_STRING;
1482         priv.should_produce_string = true;
1483         priv.truth_string->str("");
1484         priv.task->run(sch, priv.ec_seq);
1485       }
1486     }
1487 
1488     add_neighbor_features(priv);
1489     train_single_example<is_learn>(sch, is_test_ex);
1490     del_neighbor_features(priv);
1491 
1492     if (priv.task->run_takedown) priv.task->run_takedown(sch, priv.ec_seq);
1493   }
1494 
1495   template <bool is_learn>
search_predict_or_learn(search & sch,base_learner & base,example & ec)1496   void search_predict_or_learn(search& sch, base_learner& base, example& ec) {
1497     search_private& priv = *sch.priv;
1498     vw* all = priv.all;
1499     priv.base_learner = &base;
1500     bool is_real_example = true;
1501 
1502     if (example_is_newline(ec) || priv.ec_seq.size() >= all->p->ring_size - 2) {
1503       if (priv.ec_seq.size() >= all->p->ring_size - 2) // -2 to give some wiggle room
1504         std::cerr << "warning: length of sequence at " << ec.example_counter << " exceeds ring size; breaking apart" << std::endl;
1505 
1506       do_actual_learning<is_learn>(*all, sch);
1507 
1508       priv.hit_new_pass = false;
1509       priv.last_example_was_newline = true;
1510       is_real_example = false;
1511     } else {
1512       if (priv.last_example_was_newline)
1513         priv.ec_seq.clear();
1514       priv.ec_seq.push_back(&ec);
1515       priv.last_example_was_newline = false;
1516     }
1517 
1518     if (is_real_example)
1519       priv.read_example_last_id = ec.example_counter;
1520   }
1521 
end_pass(search & sch)1522   void end_pass(search& sch) {
1523     search_private& priv = *sch.priv;
1524     vw* all = priv.all;
1525     priv.hit_new_pass = true;
1526     priv.read_example_last_pass++;
1527     priv.passes_since_new_policy++;
1528 
1529     if (priv.passes_since_new_policy >= priv.passes_per_policy) {
1530       priv.passes_since_new_policy = 0;
1531       if(all->training)
1532         priv.current_policy++;
1533       if (priv.current_policy > priv.total_number_of_policies) {
1534         std::cerr << "internal error (bug): too many policies; not advancing" << std::endl;
1535         priv.current_policy = priv.total_number_of_policies;
1536       }
1537       //reset search_trained_nb_policies in options_from_file so it is saved to regressor file later
1538       std::stringstream ss;
1539       ss << priv.current_policy;
1540       VW::cmd_string_replace_value(all->file_options,"--search_trained_nb_policies", ss.str());
1541     }
1542   }
1543 
finish_example(vw & all,search & sch,example & ec)1544   void finish_example(vw& all, search& sch, example& ec) {
1545     if (ec.end_pass || example_is_newline(ec) || sch.priv->ec_seq.size() >= all.p->ring_size - 2) {
1546       print_update(*sch.priv);
1547       VW::finish_example(all, &ec);
1548       clear_seq(all, *sch.priv);
1549     }
1550   }
1551 
end_examples(search & sch)1552   void end_examples(search& sch) {
1553     search_private& priv = *sch.priv;
1554     vw* all    = priv.all;
1555 
1556     do_actual_learning<true>(*all, sch);
1557 
1558     if( all->training ) {
1559       std::stringstream ss1;
1560       std::stringstream ss2;
1561       ss1 << ((priv.passes_since_new_policy == 0) ? priv.current_policy : (priv.current_policy+1));
1562       //use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --search_trained_nb_policies
1563       VW::cmd_string_replace_value(all->file_options,"--search_trained_nb_policies", ss1.str());
1564       ss2 << priv.total_number_of_policies;
1565       //use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --search_total_nb_policies
1566       VW::cmd_string_replace_value(all->file_options,"--search_total_nb_policies", ss2.str());
1567     }
1568   }
1569 
mc_label_is_test(void * lab)1570   bool mc_label_is_test(void* lab) {
1571 	  if (MC::label_is_test((MC::label_t*)lab) > 0)
1572 		  return true;
1573 	  else
1574 		  return false;
1575   }
1576 
search_initialize(vw * all,search & sch)1577   void search_initialize(vw* all, search& sch) {
1578     search_private& priv = *sch.priv;
1579     priv.all = all;
1580 
1581     priv.auto_condition_features = false;
1582     priv.auto_hamming_loss = false;
1583     priv.examples_dont_change = false;
1584     priv.is_ldf = false;
1585 
1586     priv.label_is_test = mc_label_is_test;
1587 
1588     priv.A = 1;
1589     priv.num_learners = 1;
1590     priv.cb_learner = false;
1591     priv.state = INITIALIZE;
1592     priv.learn_learner_id = 0;
1593     priv.mix_per_roll_policy = -2;
1594 
1595     priv.t = 0;
1596     priv.T = 0;
1597     priv.learn_ec_ref = NULL;
1598     priv.learn_ec_ref_cnt = 0;
1599     //priv.allowed_actions_cache = NULL;
1600 
1601     priv.loss_declared_cnt = 0;
1602     priv.learn_t = 0;
1603     priv.learn_a_idx = 0;
1604     priv.done_with_all_actions = false;
1605 
1606     priv.test_loss = 0.;
1607     priv.learn_loss = 0.;
1608     priv.train_loss = 0.;
1609 
1610     priv.last_example_was_newline = false;
1611     priv.hit_new_pass = false;
1612 
1613     priv.printed_output_header = false;
1614 
1615     priv.should_produce_string = false;
1616     priv.pred_string  = new stringstream();
1617     priv.truth_string = new stringstream();
1618     priv.bad_string_stream = new stringstream();
1619     priv.bad_string_stream->clear(priv.bad_string_stream->badbit);
1620 
1621     priv.beta = 0.5;
1622     priv.alpha = 1e-10f;
1623 
1624     priv.rollout_method = MIX_PER_ROLL;
1625     priv.rollin_method  = MIX_PER_ROLL;
1626     priv.subsample_timesteps = 0.;
1627 
1628     priv.allow_current_policy = true;
1629     priv.adaptive_beta = true;
1630     priv.passes_per_policy = 1;     //this should be set to the same value as --passes for dagger
1631 
1632     priv.current_policy = 0;
1633 
1634     priv.num_features = 0;
1635     priv.total_number_of_policies = 1;
1636     priv.read_example_last_id = 0;
1637     priv.passes_per_policy = 0;
1638     priv.read_example_last_pass = 0;
1639     priv.total_examples_generated = 0;
1640     priv.total_predictions_made = 0;
1641     priv.total_cache_hits = 0;
1642 
1643     priv.history_length = 1;
1644     priv.acset.max_bias_ngram_length = 1;
1645     priv.acset.max_quad_ngram_length = 0;
1646     priv.acset.feature_value = 1.;
1647 
1648     priv.cache_hash_map.set_default_value((action)-1);
1649     priv.cache_hash_map.set_equivalent(cached_item_equivalent);
1650 
1651     priv.task = NULL;
1652     sch.task_data = NULL;
1653 
1654     priv.empty_example = alloc_examples(sizeof(CS::label), 1);
1655     CS::cs_label.default_label(&priv.empty_example->l.cs);
1656     priv.empty_example->in_use = true;
1657 
1658     priv.rawOutputStringStream = new stringstream(priv.rawOutputString);
1659   }
1660 
search_finish(search & sch)1661   void search_finish(search& sch) {
1662     search_private& priv = *sch.priv;
1663     cdbg << "search_finish" << endl;
1664 
1665     clear_cache_hash_map(priv);
1666 
1667     delete priv.truth_string;
1668     delete priv.pred_string;
1669     delete priv.bad_string_stream;
1670     priv.neighbor_features.delete_v();
1671     priv.timesteps.delete_v();
1672     priv.learn_losses.delete_v();
1673     priv.condition_on_actions.delete_v();
1674     priv.learn_allowed_actions.delete_v();
1675     priv.ldf_test_label.costs.delete_v();
1676 
1677     if (priv.beam) {
1678       priv.beam->erase(free_action_prefix);
1679       delete priv.beam;
1680     }
1681     priv.beam_actions.delete_v();
1682 
1683     if (priv.cb_learner)
1684       priv.allowed_actions_cache->cb.costs.delete_v();
1685     else
1686       priv.allowed_actions_cache->cs.costs.delete_v();
1687 
1688     priv.train_trajectory.delete_v();
1689     priv.current_trajectory.delete_v();
1690     priv.ptag_to_action.delete_v();
1691 
1692     dealloc_example(CS::cs_label.delete_label, *(priv.empty_example));
1693     free(priv.empty_example);
1694 
1695     priv.ec_seq.clear();
1696 
1697     // destroy copied examples if we needed them
1698     if (! priv.examples_dont_change) {
1699       void (*delete_label)(void*) = priv.is_ldf ? CS::cs_label.delete_label : MC::mc_label.delete_label;
1700       for(example*ec = priv.learn_ec_copy.begin; ec!=priv.learn_ec_copy.end; ++ec)
1701         dealloc_example(delete_label, *ec);
1702       priv.learn_ec_copy.delete_v();
1703     }
1704     priv.learn_condition_on_names.delete_v();
1705     priv.learn_condition_on.delete_v();
1706     priv.learn_condition_on_act.delete_v();
1707 
1708     if (priv.task->finish != NULL) {
1709       priv.task->finish(sch);
1710     }
1711 
1712     free(priv.allowed_actions_cache);
1713     delete priv.rawOutputStringStream;
1714     delete sch.priv;
1715   }
1716 
ensure_param(float & v,float lo,float hi,float def,const char * string)1717   void ensure_param(float &v, float lo, float hi, float def, const char* string) {
1718     if ((v < lo) || (v > hi)) {
1719       std::cerr << string << endl;
1720       v = def;
1721     }
1722   }
1723 
string_equal(string a,string b)1724   bool string_equal(string a, string b) { return a.compare(b) == 0; }
float_equal(float a,float b)1725   bool float_equal(float a, float b) { return fabs(a-b) < 1e-6; }
uint32_equal(uint32_t a,uint32_t b)1726   bool uint32_equal(uint32_t a, uint32_t b) { return a==b; }
size_equal(size_t a,size_t b)1727   bool size_equal(size_t a, size_t b) { return a==b; }
1728 
check_option(T & ret,vw & all,po::variables_map & vm,const char * opt_name,bool default_to_cmdline,bool (* equal)(T,T),const char * mismatch_error_string,const char * required_error_string)1729   template<class T> void check_option(T& ret, vw&all, po::variables_map& vm, const char* opt_name, bool default_to_cmdline, bool(*equal)(T,T), const char* mismatch_error_string, const char* required_error_string) {
1730     if (vm.count(opt_name)) {
1731       ret = vm[opt_name].as<T>();
1732       *all.file_options << " --" << opt_name << " " << ret;
1733     } else if (strlen(required_error_string)>0) {
1734       std::cerr << required_error_string << endl;
1735       if (! vm.count("help"))
1736         throw exception();
1737     }
1738   }
1739 
check_option(bool & ret,vw & all,po::variables_map & vm,const char * opt_name,bool default_to_cmdline,const char * mismatch_error_string)1740   void check_option(bool& ret, vw&all, po::variables_map& vm, const char* opt_name, bool default_to_cmdline, const char* mismatch_error_string) {
1741     if (vm.count(opt_name)) {
1742       ret = true;
1743       *all.file_options << " --" << opt_name;
1744     } else
1745       ret = false;
1746   }
1747 
handle_condition_options(vw & vw,auto_condition_settings & acset)1748   void handle_condition_options(vw& vw, auto_condition_settings& acset) {
1749     new_options(vw, "Search Auto-conditioning Options")
1750       ("search_max_bias_ngram_length",   po::value<size_t>(), "add a \"bias\" feature for each ngram up to and including this length. eg., if it's 1 (default), then you get a single feature for each conditional")
1751       ("search_max_quad_ngram_length",   po::value<size_t>(), "add bias *times* input features for each ngram up to and including this length (def: 0)")
1752       ("search_condition_feature_value", po::value<float> (), "how much weight should the conditional features get? (def: 1.)");
1753     add_options(vw);
1754 
1755     po::variables_map& vm = vw.vm;
1756 
1757     check_option<size_t>(acset.max_bias_ngram_length, vw, vm, "search_max_bias_ngram_length", false, size_equal,
1758                          "warning: you specified a different value for --search_max_bias_ngram_length than the one loaded from regressor. proceeding with loaded value: ", "");
1759 
1760     check_option<size_t>(acset.max_quad_ngram_length, vw, vm, "search_max_quad_ngram_length", false, size_equal,
1761                          "warning: you specified a different value for --search_max_quad_ngram_length than the one loaded from regressor. proceeding with loaded value: ", "");
1762 
1763     check_option<float> (acset.feature_value, vw, vm, "search_condition_feature_value", false, float_equal,
1764                          "warning: you specified a different value for --search_condition_feature_value than the one loaded from regressor. proceeding with loaded value: ", "");
1765   }
1766 
read_allowed_transitions(action A,const char * filename)1767   v_array<CS::label> read_allowed_transitions(action A, const char* filename) {
1768     FILE *f = fopen(filename, "r");
1769     if (f == NULL) {
1770       std::cerr << "error: could not read file " << filename << " (" << strerror(errno) << "); assuming all transitions are valid" << endl;
1771       throw exception();
1772     }
1773 
1774     bool* bg = (bool*)malloc((A+1)*(A+1) * sizeof(bool));
1775     int rd,from,to,count=0;
1776     while ((rd = fscanf(f, "%d:%d", &from, &to)) > 0) {
1777       if ((from < 0) || (from > (int)A)) { std::cerr << "warning: ignoring transition from " << from << " because it's out of the range [0," << A << "]" << endl; }
1778       if ((to   < 0) || (to   > (int)A)) { std::cerr << "warning: ignoring transition to "   << to   << " because it's out of the range [0," << A << "]" << endl; }
1779       bg[from * (A+1) + to] = true;
1780       count++;
1781     }
1782     fclose(f);
1783 
1784     v_array<CS::label> allowed = v_init<CS::label>();
1785 
1786     for (size_t from=0; from<A; from++) {
1787       v_array<CS::wclass> costs = v_init<CS::wclass>();
1788 
1789       for (size_t to=0; to<A; to++)
1790         if (bg[from * (A+1) + to]) {
1791           CS::wclass c = { FLT_MAX, (action)to, 0., 0. };
1792           costs.push_back(c);
1793         }
1794 
1795       CS::label ld = { costs };
1796       allowed.push_back(ld);
1797     }
1798     free(bg);
1799 
1800     std::cerr << "read " << count << " allowed transitions from " << filename << endl;
1801 
1802     return allowed;
1803   }
1804 
1805 
parse_neighbor_features(string & nf_string,search & sch)1806   void parse_neighbor_features(string& nf_string, search&sch) {
1807     search_private& priv = *sch.priv;
1808     priv.neighbor_features.erase();
1809     size_t len = nf_string.length();
1810     if (len == 0) return;
1811 
1812     char * cstr = new char [len+1];
1813     strcpy(cstr, nf_string.c_str());
1814 
1815     char * p = strtok(cstr, ",");
1816     v_array<substring> cmd = v_init<substring>();
1817     while (p != 0) {
1818       cmd.erase();
1819       substring me = { p, p+strlen(p) };
1820       tokenize(':', me, cmd, true);
1821 
1822       int32_t posn = 0;
1823       char ns = ' ';
1824       if (cmd.size() == 1) {
1825         posn = int_of_substring(cmd[0]);
1826         ns   = ' ';
1827       } else if (cmd.size() == 2) {
1828         posn = int_of_substring(cmd[0]);
1829         ns   = (cmd[1].end > cmd[1].begin) ? cmd[1].begin[0] : ' ';
1830       } else {
1831         std::cerr << "warning: ignoring malformed neighbor specification: '" << p << "'" << endl;
1832       }
1833       int32_t enc = (posn << 24) | (ns & 0xFF);
1834       priv.neighbor_features.push_back(enc);
1835 
1836       p = strtok(NULL, ",");
1837     }
1838     cmd.delete_v();
1839 
1840     delete[] cstr;
1841   }
1842 
setup(vw & all)1843   base_learner* setup(vw&all) {
1844     if (missing_option<size_t, false>(all, "search",
1845 				      "Use learning to search, argument=maximum action id or 0 for LDF"))
1846       return NULL;
1847     new_options(all, "Search Options")
1848       ("search_task",              po::value<string>(), "the search task (use \"--search_task list\" to get a list of available tasks)")
1849       ("search_interpolation",     po::value<string>(), "at what level should interpolation happen? [*data|policy]")
1850       ("search_rollout",           po::value<string>(), "how should rollouts be executed?           [policy|oracle|*mix_per_state|mix_per_roll|none]")
1851       ("search_rollin",            po::value<string>(), "how should past trajectories be generated? [policy|oracle|*mix_per_state|mix_per_roll]")
1852 
1853       ("search_passes_per_policy", po::value<size_t>(), "number of passes per policy (only valid for search_interpolation=policy)     [def=1]")
1854       ("search_beta",              po::value<float>(),  "interpolation rate for policies (only valid for search_interpolation=policy) [def=0.5]")
1855 
1856       ("search_alpha",             po::value<float>(),  "annealed beta = 1-(1-alpha)^t (only valid for search_interpolation=data)     [def=1e-10]")
1857 
1858       ("search_total_nb_policies", po::value<size_t>(), "if we are going to train the policies through multiple separate calls to vw, we need to specify this parameter and tell vw how many policies are eventually going to be trained")
1859 
1860       ("search_trained_nb_policies", po::value<size_t>(), "the number of trained policies in a file")
1861 
1862       ("search_allowed_transitions",po::value<string>(),"read file of allowed transitions [def: all transitions are allowed]")
1863       ("search_subsample_time",    po::value<float>(),  "instead of training at all timesteps, use a subset. if value in (0,1), train on a random v%. if v>=1, train on precisely v steps per example")
1864       ("search_neighbor_features", po::value<string>(), "copy features from neighboring lines. argument looks like: '-1:a,+2' meaning copy previous line namespace a and next next line from namespace _unnamed_, where ',' separates them")
1865       ("search_rollout_num_steps", po::value<size_t>(), "how many calls of \"loss\" before we stop really predicting on rollouts and switch to oracle (def: 0 means \"infinite\")")
1866       ("search_history_length",    po::value<size_t>(), "some tasks allow you to specify how much history their depend on; specify that here [def: 1]")
1867 
1868       ("search_no_caching",                             "turn off the built-in caching ability (makes things slower, but technically more safe)")
1869       ("search_beam",              po::value<size_t>(), "use beam search (arg = beam size, default 0 = no beam)")
1870       ("search_kbest",             po::value<size_t>(), "size of k-best list to produce (must be <= beam size)")
1871       ;
1872     add_options(all);
1873     po::variables_map& vm = all.vm;
1874 
1875     bool has_hook_task = false;
1876     for (size_t i=0; i<all.args.size()-1; i++)
1877       if (all.args[i] == "--search_task" && all.args[i+1] == "hook")
1878         has_hook_task = true;
1879     if (has_hook_task)
1880       for (int i = (int)all.args.size()-2; i >= 0; i--)
1881         if (all.args[i] == "--search_task" && all.args[i+1] != "hook")
1882           all.args.erase(all.args.begin() + i, all.args.begin() + i + 2);
1883 
1884     search& sch = calloc_or_die<search>();
1885     sch.priv = new search_private();
1886     search_initialize(&all, sch);
1887     search_private& priv = *sch.priv;
1888 
1889     std::string task_string;
1890     std::string interpolation_string = "data";
1891     std::string rollout_string = "mix_per_state";
1892     std::string rollin_string = "mix_per_state";
1893 
1894     check_option<string>(task_string, all, vm, "search_task", false, string_equal,
1895                          "warning: specified --search_task different than the one loaded from regressor. using loaded value of: ",
1896                          "error: you must specify a task using --search_task");
1897 
1898     check_option<string>(interpolation_string, all, vm, "search_interpolation", false, string_equal,
1899                          "warning: specified --search_interpolation different than the one loaded from regressor. using loaded value of: ", "");
1900 
1901     if (vm.count("search_passes_per_policy"))       priv.passes_per_policy    = vm["search_passes_per_policy"].as<size_t>();
1902 
1903     if (vm.count("search_alpha"))                   priv.alpha                = vm["search_alpha"            ].as<float>();
1904     if (vm.count("search_beta"))                    priv.beta                 = vm["search_beta"             ].as<float>();
1905 
1906     if (vm.count("search_subsample_time"))          priv.subsample_timesteps  = vm["search_subsample_time"].as<float>();
1907     if (vm.count("search_no_caching"))              priv.no_caching           = true;
1908     if (vm.count("search_rollout_num_steps"))       priv.rollout_num_steps    = vm["search_rollout_num_steps"].as<size_t>();
1909 
1910     if (vm.count("search_beam"))
1911       priv.beam = new Beam::beam<action_prefix>(vm["search_beam"].as<size_t>());  // TODO: pruning, kbest, equivalence testing
1912     else
1913       priv.beam = NULL;
1914 
1915     priv.kbest = 1;
1916     if (vm.count("search_kbest")) {
1917       priv.kbest = max(1, vm["search_kbest"].as<size_t>());
1918       if (priv.kbest > priv.beam->get_beam_size()) {
1919         std::cerr << "warning: kbest set greater than beam size; shrinking back to " << priv.beam->get_beam_size() << endl;
1920         priv.kbest = priv.beam->get_beam_size();
1921       }
1922     }
1923 
1924     priv.A = vm["search"].as<size_t>();
1925 
1926     string neighbor_features_string;
1927     check_option<string>(neighbor_features_string, all, vm, "search_neighbor_features", false, string_equal,
1928                          "warning: you specified a different feature structure with --search_neighbor_features than the one loaded from predictor. using loaded value of: ", "");
1929     parse_neighbor_features(neighbor_features_string, sch);
1930 
1931     if (interpolation_string.compare("data") == 0) { // run as dagger
1932       priv.adaptive_beta = true;
1933       priv.allow_current_policy = true;
1934       priv.passes_per_policy = all.numpasses;
1935       if (priv.current_policy > 1) priv.current_policy = 1;
1936     } else if (interpolation_string.compare("policy") == 0) {
1937     } else {
1938       std::cerr << "error: --search_interpolation must be 'data' or 'policy'" << endl;
1939       throw exception();
1940     }
1941 
1942     if (vm.count("search_rollout")) rollout_string = vm["search_rollout"].as<string>();
1943     if (vm.count("search_rollin" )) rollin_string  = vm["search_rollin" ].as<string>();
1944 
1945     if      (rollout_string.compare("policy") == 0)          priv.rollout_method = POLICY;
1946     else if (rollout_string.compare("oracle") == 0)          priv.rollout_method = ORACLE;
1947     else if (rollout_string.compare("mix_per_state") == 0)   priv.rollout_method = MIX_PER_STATE;
1948     else if (rollout_string.compare("mix_per_roll") == 0)    priv.rollout_method = MIX_PER_ROLL;
1949     else if (rollout_string.compare("none") == 0)          { priv.rollout_method = NO_ROLLOUT; priv.no_caching = true; std::cerr << "no rollout!" << endl; }
1950     else {
1951       std::cerr << "error: --search_rollout must be 'policy', 'oracle', 'mix_per_state', 'mix_per_roll' or 'none'" << endl;
1952       throw exception();
1953     }
1954 
1955     if      (rollin_string.compare("policy") == 0)         priv.rollin_method = POLICY;
1956     else if (rollin_string.compare("oracle") == 0)         priv.rollin_method = ORACLE;
1957     else if (rollin_string.compare("mix_per_state") == 0)  priv.rollin_method = MIX_PER_STATE;
1958     else if (rollin_string.compare("mix_per_roll") == 0)   priv.rollin_method = MIX_PER_ROLL;
1959     else {
1960       std::cerr << "error: --search_rollin must be 'policy', 'oracle', 'mix_per_state' or 'mix_per_roll'" << endl;
1961       throw exception();
1962     }
1963 
1964     check_option<size_t>(priv.A, all, vm, "search", false, size_equal,
1965                          "warning: you specified a different number of actions through --search than the one loaded from predictor. using loaded value of: ", "");
1966 
1967     check_option<size_t>(priv.history_length, all, vm, "search_history_length", false, size_equal,
1968                          "warning: you specified a different history length through --search_history_length than the one loaded from predictor. using loaded value of: ", "");
1969 
1970     //check if the base learner is contextual bandit, in which case, we dont rollout all actions.
1971     priv.allowed_actions_cache = &calloc_or_die<polylabel>();
1972     if (vm.count("cb")) {
1973       priv.cb_learner = true;
1974       CB::cb_label.default_label(priv.allowed_actions_cache);
1975     } else {
1976       priv.cb_learner = false;
1977       CS::cs_label.default_label(priv.allowed_actions_cache);
1978     }
1979 
1980     //if we loaded a regressor with -i option, --search_trained_nb_policies contains the number of trained policies in the file
1981     // and --search_total_nb_policies contains the total number of policies in the file
1982     if (vm.count("search_total_nb_policies"))
1983       priv.total_number_of_policies = (uint32_t)vm["search_total_nb_policies"].as<size_t>();
1984 
1985     ensure_param(priv.beta , 0.0, 1.0, 0.5, "warning: search_beta must be in (0,1); resetting to 0.5");
1986     ensure_param(priv.alpha, 0.0, 1.0, 1e-10f, "warning: search_alpha must be in (0,1); resetting to 1e-10");
1987 
1988     //compute total number of policies we will have at end of training
1989     // we add current_policy for cases where we start from an initial set of policies loaded through -i option
1990     uint32_t tmp_number_of_policies = priv.current_policy;
1991     if( all.training )
1992       tmp_number_of_policies += (int)ceil(((float)all.numpasses) / ((float)priv.passes_per_policy));
1993 
1994     //the user might have specified the number of policies that will eventually be trained through multiple vw calls,
1995     //so only set total_number_of_policies to computed value if it is larger
1996     cdbg << "current_policy=" << priv.current_policy << " tmp_number_of_policies=" << tmp_number_of_policies << " total_number_of_policies=" << priv.total_number_of_policies << endl;
1997     if( tmp_number_of_policies > priv.total_number_of_policies ) {
1998       priv.total_number_of_policies = tmp_number_of_policies;
1999       if( priv.current_policy > 0 ) //we loaded a file but total number of policies didn't match what is needed for training
2000         std::cerr << "warning: you're attempting to train more classifiers than was allocated initially. Likely to cause bad performance." << endl;
2001     }
2002 
2003     //current policy currently points to a new policy we would train
2004     //if we are not training and loaded a bunch of policies for testing, we need to subtract 1 from current policy
2005     //so that we only use those loaded when testing (as run_prediction is called with allow_current to true)
2006     if( !all.training && priv.current_policy > 0 )
2007       priv.current_policy--;
2008 
2009     std::stringstream ss1, ss2;
2010     ss1 << priv.current_policy;           VW::cmd_string_replace_value(all.file_options,"--search_trained_nb_policies", ss1.str());
2011     ss2 << priv.total_number_of_policies; VW::cmd_string_replace_value(all.file_options,"--search_total_nb_policies",   ss2.str());
2012 
2013     cdbg << "search current_policy = " << priv.current_policy << " total_number_of_policies = " << priv.total_number_of_policies << endl;
2014 
2015     if (task_string.compare("list") == 0) {
2016       std::cerr << endl << "available search tasks:" << endl;
2017       for (search_task** mytask = all_tasks; *mytask != NULL; mytask++)
2018         std::cerr << "  " << (*mytask)->task_name << endl;
2019       std::cerr << endl;
2020       exit(0);
2021     }
2022     for (search_task** mytask = all_tasks; *mytask != NULL; mytask++)
2023       if (task_string.compare((*mytask)->task_name) == 0) {
2024         priv.task = *mytask;
2025         sch.task_name = (*mytask)->task_name;
2026         break;
2027       }
2028     if (priv.task == NULL) {
2029       if (! vm.count("help")) {
2030         std::cerr << "fail: unknown task for --search_task '" << task_string << "'; use --search_task list to get a list" << endl;
2031         throw exception();
2032       }
2033     }
2034     all.p->emptylines_separate_examples = true;
2035 
2036     if (count(all.args.begin(), all.args.end(),"--csoaa") == 0
2037 	&& count(all.args.begin(), all.args.end(),"--csoaa_ldf") == 0
2038 	&& count(all.args.begin(), all.args.end(),"--wap_ldf") == 0
2039 	&&  count(all.args.begin(), all.args.end(),"--cb") == 0)
2040       {
2041 	all.args.push_back("--csoaa");
2042 	stringstream ss;
2043 	ss << vm["search"].as<size_t>();
2044 	all.args.push_back(ss.str());
2045       }
2046     base_learner* base = setup_base(all);
2047 
2048     // default to OAA labels unless the task wants to override this (which they can do in initialize)
2049     all.p->lp = MC::mc_label;
2050     if (priv.task)
2051       priv.task->initialize(sch, priv.A, vm);
2052 
2053     if (vm.count("search_allowed_transitions"))     read_allowed_transitions((action)priv.A, vm["search_allowed_transitions"].as<string>().c_str());
2054 
2055     // set up auto-history if they want it
2056     if (priv.auto_condition_features) {
2057       handle_condition_options(all, priv.acset);
2058 
2059       // turn off auto-condition if it's irrelevant
2060       if (((priv.acset.max_bias_ngram_length == 0) && (priv.acset.max_quad_ngram_length == 0)) ||
2061           (priv.acset.feature_value == 0.f)) {
2062         std::cerr << "warning: turning off AUTO_CONDITION_FEATURES because settings make it useless" << endl;
2063         priv.auto_condition_features = false;
2064       }
2065     }
2066 
2067     if (!priv.allow_current_policy) // if we're not dagger
2068       all.check_holdout_every_n_passes = priv.passes_per_policy;
2069 
2070     all.searchstr = &sch;
2071 
2072     priv.start_clock_time = clock();
2073 
2074     learner<search>& l = init_learner(&sch, base,
2075 				      search_predict_or_learn<true>,
2076 				      search_predict_or_learn<false>,
2077 				      priv.total_number_of_policies);
2078     l.set_finish_example(finish_example);
2079     l.set_end_examples(end_examples);
2080     l.set_finish(search_finish);
2081     l.set_end_pass(end_pass);
2082 
2083     return make_base(l);
2084   }
2085 
action_hamming_loss(action a,const action * A,size_t sz)2086   float action_hamming_loss(action a, const action* A, size_t sz) {
2087     if (sz == 0) return 0.;   // latent variables have zero loss
2088     for (size_t i=0; i<sz; i++)
2089       if (a == A[i]) return 0.;
2090     return 1.;
2091   }
2092 
2093   // the interface:
is_ldf()2094   bool search::is_ldf() { return this->priv->is_ldf; }
2095 
predict(example & ec,ptag mytag,const action * oracle_actions,size_t oracle_actions_cnt,const ptag * condition_on,const char * condition_on_names,const action * allowed_actions,size_t allowed_actions_cnt,size_t learner_id)2096   action search::predict(example& ec, ptag mytag, const action* oracle_actions, size_t oracle_actions_cnt, const ptag* condition_on, const char* condition_on_names, const action* allowed_actions, size_t allowed_actions_cnt, size_t learner_id) {
2097     action a = search_predict(*this->priv, &ec, 1, mytag, oracle_actions, oracle_actions_cnt, condition_on, condition_on_names, allowed_actions, allowed_actions_cnt, learner_id);
2098     if (priv->beam) priv->current_trajectory.push_back(a);
2099     if (priv->state == INIT_TEST) priv->test_action_sequence.push_back(a);
2100     if (mytag != 0) push_at(priv->ptag_to_action, a, mytag);
2101     if (this->priv->auto_hamming_loss)
2102       loss(action_hamming_loss(a, oracle_actions, oracle_actions_cnt));
2103     cdbg << "predict returning " << a << endl;
2104     return a;
2105   }
2106 
predictLDF(example * ecs,size_t ec_cnt,ptag mytag,const action * oracle_actions,size_t oracle_actions_cnt,const ptag * condition_on,const char * condition_on_names,size_t learner_id)2107   action search::predictLDF(example* ecs, size_t ec_cnt, ptag mytag, const action* oracle_actions, size_t oracle_actions_cnt, const ptag* condition_on, const char* condition_on_names, size_t learner_id) {
2108     action a = search_predict(*this->priv, ecs, ec_cnt, mytag, oracle_actions, oracle_actions_cnt, condition_on, condition_on_names, NULL, 0, learner_id);
2109     if (priv->beam) priv->current_trajectory.push_back(a);
2110     if (priv->state == INIT_TEST) priv->test_action_sequence.push_back(a);
2111     if ((mytag != 0) && ecs[a].l.cs.costs.size() > 0)
2112       push_at(priv->ptag_to_action, ecs[a].l.cs.costs[0].class_index, mytag);
2113     if (this->priv->auto_hamming_loss)
2114       loss(action_hamming_loss(a, oracle_actions, oracle_actions_cnt));
2115     cdbg << "predict returning " << a << endl;
2116     return a;
2117   }
2118 
loss(float loss)2119   void search::loss(float loss) { search_declare_loss(*this->priv, loss); }
2120 
predictNeedsExample()2121   bool search::predictNeedsExample() { return search_predictNeedsExample(*this->priv); }
2122 
output()2123   stringstream& search::output() {
2124     if      (!this->priv->should_produce_string    ) return *(this->priv->bad_string_stream);
2125     else if ( this->priv->state == GET_TRUTH_STRING) return *(this->priv->truth_string);
2126     else                                             return *(this->priv->pred_string);
2127   }
2128 
set_options(uint32_t opts)2129   void  search::set_options(uint32_t opts) {
2130     if (this->priv->state != INITIALIZE) {
2131       std::cerr << "error: task cannot set options except in initialize function!" << endl;
2132       throw exception();
2133     }
2134     if ((opts & AUTO_CONDITION_FEATURES) != 0) this->priv->auto_condition_features = true;
2135     if ((opts & AUTO_HAMMING_LOSS)       != 0) this->priv->auto_hamming_loss = true;
2136     if ((opts & EXAMPLES_DONT_CHANGE)    != 0) this->priv->examples_dont_change = true;
2137     if ((opts & IS_LDF)                  != 0) this->priv->is_ldf = true;
2138   }
2139 
set_label_parser(label_parser & lp,bool (* is_test)(void *))2140   void search::set_label_parser(label_parser&lp, bool (*is_test)(void*)) {
2141     if (this->priv->state != INITIALIZE) {
2142       std::cerr << "error: task cannot set label parser except in initialize function!" << endl;
2143       throw exception();
2144     }
2145     this->priv->all->p->lp = lp;
2146     this->priv->label_is_test = is_test;
2147   }
2148 
get_test_action_sequence(vector<action> & V)2149   void search::get_test_action_sequence(vector<action>& V) {
2150     V.clear();
2151     for (size_t i=0; i<this->priv->test_action_sequence.size(); i++)
2152       V.push_back(this->priv->test_action_sequence[i]);
2153   }
2154 
2155 
set_num_learners(size_t num_learners)2156   void search::set_num_learners(size_t num_learners) { this->priv->num_learners = num_learners; }
add_program_options(po::variables_map & vw,po::options_description & opts)2157   void search::add_program_options(po::variables_map& vw, po::options_description& opts) { add_options( *this->priv->all, opts ); }
2158 
get_mask()2159   size_t search::get_mask() { return this->priv->all->reg.weight_mask;}
get_stride_shift()2160   size_t search::get_stride_shift() { return this->priv->all->reg.stride_shift;}
get_history_length()2161   uint32_t search::get_history_length() { return (uint32_t)this->priv->history_length; }
2162 
get_vw_pointer_unsafe()2163   vw& search::get_vw_pointer_unsafe() { return *this->priv->all; }
2164 
2165   // predictor implementation
predictor(search & sch,ptag my_tag)2166   predictor::predictor(search& sch, ptag my_tag) : is_ldf(false), my_tag(my_tag), ec(NULL), ec_cnt(0), ec_alloced(false), oracle_is_pointer(false), allowed_is_pointer(false), learner_id(0), sch(sch) {
2167     oracle_actions = v_init<action>();
2168     condition_on_tags = v_init<ptag>();
2169     condition_on_names = v_init<char>();
2170     allowed_actions = v_init<action>();
2171   }
2172 
free_ec()2173   void predictor::free_ec() {
2174     if (ec_alloced) {
2175       if (is_ldf)
2176         for (size_t i=0; i<ec_cnt; i++)
2177           dealloc_example(CS::cs_label.delete_label, ec[i]);
2178       else
2179         dealloc_example(NULL, *ec);
2180       free(ec);
2181     }
2182   }
2183 
~predictor()2184   predictor::~predictor() {
2185     if (! oracle_is_pointer) oracle_actions.delete_v();
2186     if (! allowed_is_pointer) allowed_actions.delete_v();
2187     free_ec();
2188     condition_on_tags.delete_v();
2189     condition_on_names.delete_v();
2190   }
2191 
set_input(example & input_example)2192   predictor& predictor::set_input(example&input_example) {
2193     free_ec();
2194     is_ldf = false;
2195     ec = &input_example;
2196     ec_cnt = 1;
2197     ec_alloced = false;
2198     return *this;
2199   }
2200 
set_input(example * input_example,size_t input_length)2201   predictor& predictor::set_input(example*input_example, size_t input_length) {
2202     free_ec();
2203     is_ldf = true;
2204     ec = input_example;
2205     ec_cnt = input_length;
2206     ec_alloced = false;
2207     return *this;
2208   }
2209 
set_input_length(size_t input_length)2210   void predictor::set_input_length(size_t input_length) {
2211     is_ldf = true;
2212     if (ec_alloced) ec = (example*)realloc(ec, input_length * sizeof(example));
2213     else            ec = calloc_or_die<example>(input_length);
2214     ec_cnt = input_length;
2215     ec_alloced = true;
2216   }
set_input_at(size_t posn,example & ex)2217   void predictor::set_input_at(size_t posn, example&ex) {
2218     if (!ec_alloced) { std::cerr << "call to set_input_at without previous call to set_input_length" << endl; throw exception(); }
2219     if (posn >= ec_cnt) { std::cerr << "call to set_input_at with too large a position" << endl; throw exception(); }
2220     VW::copy_example_data(false, ec+posn, &ex, CS::cs_label.label_size, CS::cs_label.copy_label); // TODO: the false is "audit"
2221   }
2222 
make_new_pointer(v_array<action> & A,size_t new_size)2223   void predictor::make_new_pointer(v_array<action>& A, size_t new_size) {
2224     size_t old_size      = A.size();
2225     action* old_pointer  = A.begin;
2226     A.begin     = calloc_or_die<action>(new_size);
2227     A.end       = A.begin + new_size;
2228     A.end_array = A.end;
2229     memcpy(A.begin, old_pointer, old_size * sizeof(action));
2230   }
2231 
add_to(v_array<action> & A,bool & A_is_ptr,action a,bool clear_first)2232   predictor& predictor::add_to(v_array<action>& A, bool& A_is_ptr, action a, bool clear_first) {
2233     if (A_is_ptr) { // we need to make our own memory
2234       if (clear_first)
2235         A.end = A.begin;
2236       size_t new_size = clear_first ? 1 : (A.size() + 1);
2237       make_new_pointer(A, new_size);
2238       A_is_ptr = false;
2239       A[new_size-1] = a;
2240     } else { // we've already allocated our own memory
2241       if (clear_first) A.erase();
2242       A.push_back(a);
2243     }
2244     return *this;
2245   }
2246 
add_to(v_array<action> & A,bool & A_is_ptr,action * a,size_t action_count,bool clear_first)2247   predictor& predictor::add_to(v_array<action>&A, bool& A_is_ptr, action*a, size_t action_count, bool clear_first) {
2248     size_t old_size = A.size();
2249     if (old_size > 0) {
2250       if (A_is_ptr) { // we need to make our own memory
2251         if (clear_first) {
2252           A.end = A.begin;
2253           old_size = 0;
2254         }
2255         size_t new_size = old_size + action_count;
2256         make_new_pointer(A, new_size);
2257         A_is_ptr = false;
2258         memcpy(A.begin + old_size, a, action_count * sizeof(action));
2259       } else { // we already have our own memory
2260         if (clear_first) A.erase();
2261         push_many<action>(A, a, action_count);
2262       }
2263     } else { // old_size == 0, clear_first is irrelevant
2264       if (! A_is_ptr)
2265         A.delete_v(); // avoid memory leak
2266 
2267       A.begin = a;
2268       A.end   = a + action_count;
2269       A.end_array = A.end;
2270       A_is_ptr = true;
2271     }
2272     return *this;
2273   }
2274 
erase_oracles()2275   predictor& predictor::erase_oracles() { if (oracle_is_pointer) oracle_actions.end = oracle_actions.begin; else oracle_actions.erase(); return *this; }
add_oracle(action a)2276   predictor& predictor::add_oracle(action a) { return add_to(oracle_actions, oracle_is_pointer, a, false); }
add_oracle(action * a,size_t action_count)2277   predictor& predictor::add_oracle(action*a, size_t action_count) { return add_to(oracle_actions, oracle_is_pointer, a, action_count, false); }
add_oracle(v_array<action> & a)2278   predictor& predictor::add_oracle(v_array<action>& a) { return add_to(oracle_actions, oracle_is_pointer, a.begin, a.size(), false); }
2279 
set_oracle(action a)2280   predictor& predictor::set_oracle(action a) { return add_to(oracle_actions, oracle_is_pointer, a, true); }
set_oracle(action * a,size_t action_count)2281   predictor& predictor::set_oracle(action*a, size_t action_count) { return add_to(oracle_actions, oracle_is_pointer, a, action_count, true); }
set_oracle(v_array<action> & a)2282   predictor& predictor::set_oracle(v_array<action>& a) { return add_to(oracle_actions, oracle_is_pointer, a.begin, a.size(), true); }
2283 
erase_alloweds()2284   predictor& predictor::erase_alloweds() { if (allowed_is_pointer) allowed_actions.end = allowed_actions.begin; else allowed_actions.erase(); return *this; }
add_allowed(action a)2285   predictor& predictor::add_allowed(action a) { return add_to(allowed_actions, allowed_is_pointer, a, false); }
add_allowed(action * a,size_t action_count)2286   predictor& predictor::add_allowed(action*a, size_t action_count) { return add_to(allowed_actions, allowed_is_pointer, a, action_count, false); }
add_allowed(v_array<action> & a)2287   predictor& predictor::add_allowed(v_array<action>& a) { return add_to(allowed_actions, allowed_is_pointer, a.begin, a.size(), false); }
2288 
set_allowed(action a)2289   predictor& predictor::set_allowed(action a) { return add_to(allowed_actions, allowed_is_pointer, a, true); }
set_allowed(action * a,size_t action_count)2290   predictor& predictor::set_allowed(action*a, size_t action_count) { return add_to(allowed_actions, allowed_is_pointer, a, action_count, true); }
set_allowed(v_array<action> & a)2291   predictor& predictor::set_allowed(v_array<action>& a) { return add_to(allowed_actions, allowed_is_pointer, a.begin, a.size(), true); }
2292 
add_condition(ptag tag,char name)2293   predictor& predictor::add_condition(ptag tag, char name) { condition_on_tags.push_back(tag); condition_on_names.push_back(name); return *this; }
set_condition(ptag tag,char name)2294   predictor& predictor::set_condition(ptag tag, char name) { condition_on_tags.erase(); condition_on_names.erase(); return add_condition(tag, name); }
2295 
add_condition_range(ptag hi,ptag count,char name0)2296   predictor& predictor::add_condition_range(ptag hi, ptag count, char name0) {
2297     if (count == 0) return *this;
2298     for (ptag i=0; i<count; i++) {
2299       if (i > hi) break;
2300       char name = name0 + i;
2301       condition_on_tags.push_back(hi-i);
2302       condition_on_names.push_back(name);
2303     }
2304     return *this;
2305   }
set_condition_range(ptag hi,ptag count,char name0)2306   predictor& predictor::set_condition_range(ptag hi, ptag count, char name0) { condition_on_tags.erase(); condition_on_names.erase(); return add_condition_range(hi, count, name0); }
2307 
set_learner_id(size_t id)2308   predictor& predictor::set_learner_id(size_t id) { learner_id = id; return *this; }
2309 
set_tag(ptag tag)2310   predictor& predictor::set_tag(ptag tag) { my_tag = tag; return *this; }
2311 
predict()2312   action predictor::predict() {
2313     const action* orA = oracle_actions.size() == 0 ? NULL : oracle_actions.begin;
2314     const ptag*   cOn = condition_on_names.size() == 0 ? NULL : condition_on_tags.begin;
2315     const char*   cNa = NULL;
2316     if (condition_on_names.size() > 0) {
2317       condition_on_names.push_back((char)0);  // null terminate
2318       cNa = condition_on_names.begin;
2319     }
2320     const action* alA = (allowed_actions.size() == 0) ? NULL : allowed_actions.begin;
2321     action p = is_ldf ? sch.predictLDF(ec, ec_cnt, my_tag, orA, oracle_actions.size(), cOn, cNa, learner_id)
2322                       : sch.predict(*ec, my_tag, orA, oracle_actions.size(), cOn, cNa, alA, allowed_actions.size(), learner_id);
2323 
2324     if (condition_on_names.size() > 0)
2325       condition_on_names.pop();  // un-null-terminate
2326     return p;
2327   }
2328 }
2329 
2330 // ./vw --search 5 -k -c --search_task sequence -d test_seq --passes 10 -f test_seq.model --holdout_off
2331 // ./vw -i test_seq.model -t -d test_seq --search_beam 2 -p /dev/stdout -r /dev/stdout
2332 
2333 // ./vw --search 5 --csoaa_ldf m -k -c --search_task sequence_demoldf -d test_seq --passes 10 -f test_seq.model --holdout_off --search_history_length 0
2334 // ./vw -i test_seq.model -t -d test_seq -p /dev/stdout -r /dev/stdout
2335 
2336 // ./vw --search 5 -k -c --search_task sequence -d test_seq --passes 10 -f test_seq.model --holdout_off --search_beam 2
2337 
2338 
2339 
2340 // TODO: raw predictions in LDF mode
2341 
// TODO: there's a bug where, if holdout is on, the loss isn't computed properly
2343