1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved.  Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
6 #include <stdio.h>
7 #include <float.h>
8 #include <sstream>
9 #include <fstream>
10 
11 #include "parse_regressor.h"
12 #include "parser.h"
13 #include "vw.h"
14 
15 #include "sender.h"
16 #include "nn.h"
17 #include "gd.h"
18 #include "cbify.h"
19 #include "oaa.h"
20 #include "rand48.h"
21 #include "bs.h"
22 #include "topk.h"
23 #include "ect.h"
24 #include "csoaa.h"
25 #include "cb_algs.h"
26 #include "scorer.h"
27 #include "search.h"
28 #include "bfgs.h"
29 #include "lda_core.h"
30 #include "noop.h"
31 #include "print.h"
32 #include "gd_mf.h"
33 #include "mf.h"
34 #include "ftrl_proximal.h"
35 #include "rand48.h"
36 #include "binary.h"
37 #include "lrq.h"
38 #include "autolink.h"
39 #include "log_multi.h"
40 #include "stagewise_poly.h"
41 #include "active.h"
42 #include "kernel_svm.h"
43 #include "parse_example.h"
44 
45 using namespace std;
//
// Does string end with a certain substring?
//
// Returns true when `ending` is a suffix of `fullString`.  A string that is
// exactly equal to `ending` counts as ending with it (so a file literally
// named ".gz" is recognised as having the ".gz" suffix).
//
bool ends_with(std::string const &fullString, std::string const &ending)
{
    // A suffix can never be longer than the string it belongs to; checking
    // this first also keeps the start offset below from underflowing.
    if (fullString.length() < ending.length()) {
        return false;
    }
    return fullString.compare(fullString.length() - ending.length(), ending.length(), ending) == 0;
}
57 
// Tells whether `c` may be used as a namespace character: everything is
// accepted except the two separator characters '|' and ':'.
bool valid_ns(char c)
{
    return c != '|' && c != ':';
}
64 
substring_equal(substring & a,substring & b)65 bool substring_equal(substring&a, substring&b) {
66   return (a.end - a.begin == b.end - b.begin) // same length
67       && (strncmp(a.begin, b.begin, a.end - a.begin) == 0);
68 }
69 
parse_dictionary_argument(vw & all,string str)70 void parse_dictionary_argument(vw&all, string str) {
71   if (str.length() == 0) return;
72   // expecting 'namespace:file', for instance 'w:foo.txt'
73   // in the case of just 'foo.txt' it's applied to the default namespace
74 
75   char ns = ' ';
76   const char*s  = str.c_str();
77   if ((str.length() > 3) && (str[1] == ':')) {
78     ns = str[0];
79     s  += 2;
80   }
81 
82   // see if we've already read this dictionary
83   for (size_t id=0; id<all.read_dictionaries.size(); id++)
84     if (strcmp(all.read_dictionaries[id].name, s) == 0) {
85       all.namespace_dictionaries[(size_t)ns].push_back(all.read_dictionaries[id].dict);
86       return;
87     }
88 
89   feature_dict* map = new feature_dict(1023, NULL, substring_equal);
90 
91   // TODO: handle gzipped dictionaries
92   example *ec = alloc_examples(all.p->lp.label_size, 1);
93   ifstream infile(s);
94   size_t def = (size_t)' ';
95   for (string line; getline(infile, line);) {
96     char* c = (char*)line.c_str(); // we're throwing away const, which is dangerous...
97     while (*c == ' ' || *c == '\t') ++c; // skip initial whitespace
98     char* d = c;
99     while (*d != ' ' && *d != '\t' && *d != '\n' && *d != '\0') ++d; // gobble up initial word
100     if (d == c) continue; // no word
101     if (*d != ' ' && *d != '\t') continue; // reached end of line
102     char* word = calloc_or_die<char>(d-c);
103     memcpy(word, c, d-c);
104     substring ss = { word, word + (d - c) };
105     uint32_t hash = uniform_hash( ss.begin, ss.end-ss.begin, quadratic_constant);
106     if (map->get(ss, hash) != NULL) { // don't overwrite old values!
107       free(word);
108       continue;
109     }
110 
111     d--;
112     *d = '|';  // set up for parser::read_line
113     read_line(all, ec, d);
114     // now we just need to grab stuff from the default namespace of ec!
115     if (ec->atomics[def].size() == 0) {
116       free(word);
117       continue;
118     }
119     v_array<feature>* arr = new v_array<feature>;
120     *arr = v_init<feature>();
121     push_many(*arr, ec->atomics[def].begin, ec->atomics[def].size());
122     map->put(ss, hash, arr);
123   }
124   dealloc_example(all.p->lp.delete_label, *ec);
125   free(ec);
126 
127   cerr << "dictionary " << s << " contains " << map->size() << " item" << (map->size() == 1 ? "\n" : "s\n");
128   all.namespace_dictionaries[(size_t)ns].push_back(map);
129   dictionary_info info = { calloc_or_die<char>(strlen(s)+1), map };
130   strcpy(info.name, s);
131   all.read_dictionaries.push_back(info);
132 }
133 
parse_affix_argument(vw & all,string str)134 void parse_affix_argument(vw&all, string str) {
135   if (str.length() == 0) return;
136   char* cstr = calloc_or_die<char>(str.length()+1);
137   strcpy(cstr, str.c_str());
138 
139   char*p = strtok(cstr, ",");
140   while (p != 0) {
141     char*q = p;
142     uint16_t prefix = 1;
143     if (q[0] == '+') { q++; }
144     else if (q[0] == '-') { prefix = 0; q++; }
145     if ((q[0] < '1') || (q[0] > '7')) {
146       cerr << "malformed affix argument (length must be 1..7): " << p << endl;
147       throw exception();
148     }
149     uint16_t len = (uint16_t)(q[0] - '0');
150     uint16_t ns = (uint16_t)' ';  // default namespace
151     if (q[1] != 0) {
152       if (valid_ns(q[1]))
153         ns = (uint16_t)q[1];
154       else {
155         cerr << "malformed affix argument (invalid namespace): " << p << endl;
156         throw exception();
157       }
158       if (q[2] != 0) {
159         cerr << "malformed affix argument (too long): " << p << endl;
160         throw exception();
161       }
162     }
163 
164     uint16_t afx = (len << 1) | (prefix & 0x1);
165     all.affix_features[ns] <<= 4;
166     all.affix_features[ns] |=  afx;
167 
168     p = strtok(NULL, ",");
169   }
170 
171   free(cstr);
172 }
173 
parse_diagnostics(vw & all,int argc)174 void parse_diagnostics(vw& all, int argc)
175 {
176   new_options(all, "Diagnostic options")
177     ("version","Version information")
178     ("audit,a", "print weights of features")
179     ("progress,P", po::value< string >(), "Progress update frequency. int: additive, float: multiplicative")
180     ("quiet", "Don't output disgnostics and progress updates")
181     ("help,h","Look here: http://hunch.net/~vw/ and click on Tutorial.");
182   add_options(all);
183 
184   po::variables_map& vm = all.vm;
185 
186   if (vm.count("version")) {
187     /* upon direct query for version -- spit it out to stdout */
188     cout << version.to_string() << "\n";
189     exit(0);
190   }
191 
192   if (vm.count("quiet")) {
193     all.quiet = true;
194     // --quiet wins over --progress
195   } else {
196     if (argc == 1)
197       cerr << "For more information use: vw --help" << endl;
198 
199     all.quiet = false;
200 
201     if (vm.count("progress")) {
202       string progress_str = vm["progress"].as<string>();
203       all.progress_arg = (float)::atof(progress_str.c_str());
204 
205       // --progress interval is dual: either integer or floating-point
206       if (progress_str.find_first_of(".") == string::npos) {
207         // No "." in arg: assume integer -> additive
208         all.progress_add = true;
209         if (all.progress_arg < 1) {
210           cerr    << "warning: additive --progress <int>"
211                   << " can't be < 1: forcing to 1\n";
212           all.progress_arg = 1;
213 
214         }
215         all.sd->dump_interval = all.progress_arg;
216 
217       } else {
218         // A "." in arg: assume floating-point -> multiplicative
219         all.progress_add = false;
220 
221         if (all.progress_arg <= 1.0) {
222           cerr    << "warning: multiplicative --progress <float>: "
223                   << vm["progress"].as<string>()
224                   << " is <= 1.0: adding 1.0\n";
225           all.progress_arg += 1.0;
226 
227         } else if (all.progress_arg > 9.0) {
228           cerr    << "warning: multiplicative --progress <float>"
229                   << " is > 9.0: you probably meant to use an integer\n";
230         }
231         all.sd->dump_interval = 1.0;
232       }
233     }
234   }
235 
236   if (vm.count("audit")){
237     all.audit = true;
238   }
239 }
240 
parse_source(vw & all)241 void parse_source(vw& all)
242 {
243   new_options(all, "Input options")
244     ("data,d", po::value< string >(), "Example Set")
245     ("daemon", "persistent daemon mode on port 26542")
246     ("port", po::value<size_t>(),"port to listen on; use 0 to pick unused port")
247     ("num_children", po::value<size_t>(&(all.num_children)), "number of children for persistent daemon mode")
248     ("pid_file", po::value< string >(), "Write pid file in persistent daemon mode")
249     ("port_file", po::value< string >(), "Write port used in persistent daemon mode")
250     ("cache,c", "Use a cache.  The default is <data>.cache")
251     ("cache_file", po::value< vector<string> >(), "The location(s) of cache_file.")
252     ("kill_cache,k", "do not reuse existing cache: create a new one always")
253     ("compressed", "use gzip format whenever possible. If a cache file is being created, this option creates a compressed cache file. A mixture of raw-text & compressed inputs are supported with autodetection.")
254     ("no_stdin", "do not default to reading from stdin");
255   add_options(all);
256 
257   // Be friendly: if -d was left out, treat positional param as data file
258   po::positional_options_description p;
259   p.add("data", -1);
260 
261   po::parsed_options pos = po::command_line_parser(all.args).
262     style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
263     options(all.opts).positional(p).run();
264   all.vm = po::variables_map();
265   po::store(pos, all.vm);
266   po::variables_map& vm = all.vm;
267 
268   //begin input source
269   if (vm.count("no_stdin"))
270     all.stdin_off = true;
271 
272   if ( (vm.count("total") || vm.count("node") || vm.count("unique_id")) && !(vm.count("total") && vm.count("node") && vm.count("unique_id")) )
273     {
274       cout << "you must specificy unique_id, total, and node if you specify any" << endl;
275       throw exception();
276     }
277 
278   if (vm.count("daemon") || vm.count("pid_file") || (vm.count("port") && !all.active) ) {
279     all.daemon = true;
280 
281     // allow each child to process up to 1e5 connections
282     all.numpasses = (size_t) 1e5;
283   }
284 
285   if (vm.count("compressed"))
286       set_compressed(all.p);
287 
288   if (vm.count("data")) {
289     all.data_filename = vm["data"].as<string>();
290     if (ends_with(all.data_filename, ".gz"))
291       set_compressed(all.p);
292   } else
293     all.data_filename = "";
294 
295   if ((vm.count("cache") || vm.count("cache_file")) && vm.count("invert_hash"))
296     {
297       cout << "invert_hash is incompatible with a cache file.  Use it in single pass mode only." << endl;
298       throw exception();
299     }
300 
301   if(!all.holdout_set_off && (vm.count("output_feature_regularizer_binary") || vm.count("output_feature_regularizer_text")))
302     {
303       all.holdout_set_off = true;
304       cerr<<"Making holdout_set_off=true since output regularizer specified\n";
305     }
306 }
307 
parse_feature_tweaks(vw & all)308 void parse_feature_tweaks(vw& all)
309 {
310   new_options(all, "Feature options")
311     ("hash", po::value< string > (), "how to hash the features. Available options: strings, all")
312     ("ignore", po::value< vector<unsigned char> >(), "ignore namespaces beginning with character <arg>")
313     ("keep", po::value< vector<unsigned char> >(), "keep namespaces beginning with character <arg>")
314     ("bit_precision,b", po::value<size_t>(), "number of bits in the feature table")
315     ("noconstant", "Don't add a constant feature")
316     ("constant,C", po::value<float>(&(all.initial_constant)), "Set initial value of constant")
317     ("ngram", po::value< vector<string> >(), "Generate N grams. To generate N grams for a single namespace 'foo', arg should be fN.")
318     ("skips", po::value< vector<string> >(), "Generate skips in N grams. This in conjunction with the ngram tag can be used to generate generalized n-skip-k-gram. To generate n-skips for a single namespace 'foo', arg should be fN.")
319     ("feature_limit", po::value< vector<string> >(), "limit to N features. To apply to a single namespace 'foo', arg should be fN")
320     ("affix", po::value<string>(), "generate prefixes/suffixes of features; argument '+2a,-3b,+1' means generate 2-char prefixes for namespace a, 3-char suffixes for b and 1 char prefixes for default namespace")
321     ("spelling", po::value< vector<string> >(), "compute spelling features for a give namespace (use '_' for default namespace)")
322     ("dictionary", po::value< vector<string> >(), "read a dictionary for additional features (arg either 'x:file' or just 'file')")
323     ("quadratic,q", po::value< vector<string> > (), "Create and use quadratic features")
324     ("q:", po::value< string >(), ": corresponds to a wildcard for all printable characters")
325     ("cubic", po::value< vector<string> > (),
326      "Create and use cubic features");
327   add_options(all);
328 
329   po::variables_map& vm = all.vm;
330 
331   //feature manipulation
332   string hash_function("strings");
333   if(vm.count("hash"))
334     hash_function = vm["hash"].as<string>();
335   all.p->hasher = getHasher(hash_function);
336 
337   if (vm.count("spelling")) {
338     vector<string> spelling_ns = vm["spelling"].as< vector<string> >();
339     for (size_t id=0; id<spelling_ns.size(); id++)
340       if (spelling_ns[id][0] == '_') all.spelling_features[(unsigned char)' '] = true;
341       else all.spelling_features[(size_t)spelling_ns[id][0]] = true;
342   }
343 
344   if (vm.count("affix")) {
345     parse_affix_argument(all, vm["affix"].as<string>());
346     *all.file_options << " --affix " << vm["affix"].as<string>();
347   }
348 
349   if(vm.count("ngram")){
350     if(vm.count("sort_features"))
351       {
352 	cerr << "ngram is incompatible with sort_features.  " << endl;
353 	throw exception();
354       }
355 
356     all.ngram_strings = vm["ngram"].as< vector<string> >();
357     compile_gram(all.ngram_strings, all.ngram, (char*)"grams", all.quiet);
358   }
359 
360   if(vm.count("skips"))
361     {
362       if(!vm.count("ngram"))
363 	{
364 	  cout << "You can not skip unless ngram is > 1" << endl;
365 	  throw exception();
366 	}
367 
368       all.skip_strings = vm["skips"].as<vector<string> >();
369       compile_gram(all.skip_strings, all.skips, (char*)"skips", all.quiet);
370     }
371 
372   if(vm.count("feature_limit"))
373     {
374       all.limit_strings = vm["feature_limit"].as< vector<string> >();
375       compile_limits(all.limit_strings, all.limit, all.quiet);
376     }
377 
378   if (vm.count("bit_precision"))
379     {
380       uint32_t new_bits = (uint32_t)vm["bit_precision"].as< size_t>();
381       if (all.default_bits == false && new_bits != all.num_bits)
382 	{
383 	  cout << "Number of bits is set to " << new_bits << " and " << all.num_bits << " by argument and model.  That does not work." << endl;
384 	  throw exception();
385 	}
386       all.default_bits = false;
387       all.num_bits = new_bits;
388       if (all.num_bits > min(31, sizeof(size_t)*8 - 3))
389 	{
390 	  cout << "Only " << min(31, sizeof(size_t)*8 - 3) << " or fewer bits allowed.  If this is a serious limit, speak up." << endl;
391 	  throw exception();
392 	}
393     }
394 
395   if (vm.count("quadratic"))
396     {
397       all.pairs = vm["quadratic"].as< vector<string> >();
398       vector<string> newpairs;
399       //string tmp;
400       char printable_start = '!';
401       char printable_end = '~';
402       int valid_ns_size = printable_end - printable_start - 1; //will skip two characters
403 
404       if(!all.quiet)
405         cerr<<"creating quadratic features for pairs: ";
406 
407       for (vector<string>::iterator i = all.pairs.begin(); i != all.pairs.end();i++){
408         if(!all.quiet){
409           cerr << *i << " ";
410           if (i->length() > 2)
411             cerr << endl << "warning, ignoring characters after the 2nd.\n";
412           if (i->length() < 2) {
413             cerr << endl << "error, quadratic features must involve two sets.\n";
414             throw exception();
415           }
416         }
417         //-q x:
418         if((*i)[0]!=':'&&(*i)[1]==':'){
419           newpairs.reserve(newpairs.size() + valid_ns_size);
420           for (char j=printable_start; j<=printable_end; j++){
421             if(valid_ns(j))
422               newpairs.push_back(string(1,(*i)[0])+j);
423           }
424         }
425         //-q :x
426         else if((*i)[0]==':'&&(*i)[1]!=':'){
427           newpairs.reserve(newpairs.size() + valid_ns_size);
428           for (char j=printable_start; j<=printable_end; j++){
429             if(valid_ns(j)){
430 	      stringstream ss;
431 	      ss << j << (*i)[1];
432 	      newpairs.push_back(ss.str());
433 	    }
434           }
435         }
436         //-q ::
437         else if((*i)[0]==':'&&(*i)[1]==':'){
438 	  cout << "in pair creation" << endl;
439           newpairs.reserve(newpairs.size() + valid_ns_size*valid_ns_size);
440 	  stringstream ss;
441 	  ss << ' ' << ' ';
442 	  newpairs.push_back(ss.str());
443           for (char j=printable_start; j<=printable_end; j++){
444             if(valid_ns(j)){
445               for (char k=printable_start; k<=printable_end; k++){
446                 if(valid_ns(k)){
447 		  stringstream ss;
448                   ss << j << k;
449                   newpairs.push_back(ss.str());
450 		}
451               }
452             }
453           }
454         }
455         else{
456           newpairs.push_back(string(*i));
457         }
458       }
459       newpairs.swap(all.pairs);
460       if(!all.quiet)
461         cerr<<endl;
462     }
463 
464   if (vm.count("cubic"))
465     {
466       all.triples = vm["cubic"].as< vector<string> >();
467       if (!all.quiet)
468 	{
469 	  cerr << "creating cubic features for triples: ";
470 	  for (vector<string>::iterator i = all.triples.begin(); i != all.triples.end();i++) {
471 	    cerr << *i << " ";
472 	    if (i->length() > 3)
473 	      cerr << endl << "warning, ignoring characters after the 3rd.\n";
474 	    if (i->length() < 3) {
475 	      cerr << endl << "error, cubic features must involve three sets.\n";
476 	      throw exception();
477 	    }
478 	  }
479 	  cerr << endl;
480 	}
481     }
482 
483   for (size_t i = 0; i < 256; i++)
484     all.ignore[i] = false;
485   all.ignore_some = false;
486 
487   if (vm.count("ignore"))
488     {
489       all.ignore_some = true;
490 
491       vector<unsigned char> ignore = vm["ignore"].as< vector<unsigned char> >();
492       for (vector<unsigned char>::iterator i = ignore.begin(); i != ignore.end();i++)
493 	{
494 	  all.ignore[*i] = true;
495 	}
496       if (!all.quiet)
497 	{
498 	  cerr << "ignoring namespaces beginning with: ";
499 	  for (vector<unsigned char>::iterator i = ignore.begin(); i != ignore.end();i++)
500 	    cerr << *i << " ";
501 
502 	  cerr << endl;
503 	}
504     }
505 
506   if (vm.count("keep"))
507     {
508       for (size_t i = 0; i < 256; i++)
509         all.ignore[i] = true;
510 
511       all.ignore_some = true;
512 
513       vector<unsigned char> keep = vm["keep"].as< vector<unsigned char> >();
514       for (vector<unsigned char>::iterator i = keep.begin(); i != keep.end();i++)
515 	{
516 	  all.ignore[*i] = false;
517 	}
518       if (!all.quiet)
519 	{
520 	  cerr << "using namespaces beginning with: ";
521 	  for (vector<unsigned char>::iterator i = keep.begin(); i != keep.end();i++)
522 	    cerr << *i << " ";
523 
524 	  cerr << endl;
525 	}
526     }
527 
528   if (vm.count("dictionary")) {
529     vector<string> dictionary_ns = vm["dictionary"].as< vector<string> >();
530     for (size_t id=0; id<dictionary_ns.size(); id++)
531       parse_dictionary_argument(all, dictionary_ns[id]);
532   }
533 
534   if (vm.count("noconstant"))
535     all.add_constant = false;
536 }
537 
parse_example_tweaks(vw & all)538 void parse_example_tweaks(vw& all)
539 {
540   new_options(all, "Example options")
541     ("testonly,t", "Ignore label information and just test")
542     ("holdout_off", "no holdout data in multiple passes")
543     ("holdout_period", po::value<uint32_t>(&(all.holdout_period)), "holdout period for test only, default 10")
544     ("holdout_after", po::value<uint32_t>(&(all.holdout_after)), "holdout after n training examples, default off (disables holdout_period)")
545     ("early_terminate", po::value<size_t>(), "Specify the number of passes tolerated when holdout loss doesn't decrease before early termination, default is 3")
546     ("passes", po::value<size_t>(&(all.numpasses)),"Number of Training Passes")
547     ("initial_pass_length", po::value<size_t>(&(all.pass_length)), "initial number of examples per pass")
548     ("examples", po::value<size_t>(&(all.max_examples)), "number of examples to parse")
549     ("min_prediction", po::value<float>(&(all.sd->min_label)), "Smallest prediction to output")
550     ("max_prediction", po::value<float>(&(all.sd->max_label)), "Largest prediction to output")
551     ("sort_features", "turn this on to disregard order in which features have been defined. This will lead to smaller cache sizes")
552     ("loss_function", po::value<string>()->default_value("squared"), "Specify the loss function to be used, uses squared by default. Currently available ones are squared, classic, hinge, logistic and quantile.")
553     ("quantile_tau", po::value<float>()->default_value(0.5), "Parameter \\tau associated with Quantile loss. Defaults to 0.5")
554     ("l1", po::value<float>(&(all.l1_lambda)), "l_1 lambda")
555     ("l2", po::value<float>(&(all.l2_lambda)), "l_2 lambda");
556   add_options(all);
557 
558   po::variables_map& vm = all.vm;
559   if (vm.count("testonly") || all.eta == 0.)
560     {
561       if (!all.quiet)
562 	cerr << "only testing" << endl;
563       all.training = false;
564       if (all.lda > 0)
565         all.eta = 0;
566     }
567   else
568     all.training = true;
569 
570   if(all.numpasses > 1)
571       all.holdout_set_off = false;
572 
573   if(vm.count("holdout_off"))
574       all.holdout_set_off = true;
575 
576   if(vm.count("sort_features"))
577     all.p->sort_features = true;
578 
579   if (vm.count("min_prediction"))
580     all.sd->min_label = vm["min_prediction"].as<float>();
581   if (vm.count("max_prediction"))
582     all.sd->max_label = vm["max_prediction"].as<float>();
583   if (vm.count("min_prediction") || vm.count("max_prediction") || vm.count("testonly"))
584     all.set_minmax = noop_mm;
585 
586   string loss_function = vm["loss_function"].as<string>();
587   float loss_parameter = 0.0;
588   if(vm.count("quantile_tau"))
589     loss_parameter = vm["quantile_tau"].as<float>();
590 
591   all.loss = getLossFunction(all, loss_function, (float)loss_parameter);
592 
593   if (all.l1_lambda < 0.) {
594     cerr << "l1_lambda should be nonnegative: resetting from " << all.l1_lambda << " to 0" << endl;
595     all.l1_lambda = 0.;
596   }
597   if (all.l2_lambda < 0.) {
598     cerr << "l2_lambda should be nonnegative: resetting from " << all.l2_lambda << " to 0" << endl;
599     all.l2_lambda = 0.;
600   }
601   all.reg_mode += (all.l1_lambda > 0.) ? 1 : 0;
602   all.reg_mode += (all.l2_lambda > 0.) ? 2 : 0;
603   if (!all.quiet)
604     {
605       if (all.reg_mode %2 && !vm.count("bfgs"))
606 	cerr << "using l1 regularization = " << all.l1_lambda << endl;
607       if (all.reg_mode > 1)
608 	cerr << "using l2 regularization = " << all.l2_lambda << endl;
609     }
610 }
611 
parse_output_preds(vw & all)612 void parse_output_preds(vw& all)
613 {
614   new_options(all, "Output options")
615     ("predictions,p", po::value< string >(), "File to output predictions to")
616     ("raw_predictions,r", po::value< string >(), "File to output unnormalized predictions to");
617   add_options(all);
618 
619   po::variables_map& vm = all.vm;
620   if (vm.count("predictions")) {
621     if (!all.quiet)
622       cerr << "predictions = " <<  vm["predictions"].as< string >() << endl;
623     if (strcmp(vm["predictions"].as< string >().c_str(), "stdout") == 0)
624       {
625 	all.final_prediction_sink.push_back((size_t) 1);//stdout
626       }
627     else
628       {
629 	const char* fstr = (vm["predictions"].as< string >().c_str());
630 	int f;
631 #ifdef _WIN32
632 	_sopen_s(&f, fstr, _O_CREAT|_O_WRONLY|_O_BINARY|_O_TRUNC, _SH_DENYWR, _S_IREAD|_S_IWRITE);
633 #else
634 	f = open(fstr, O_CREAT|O_WRONLY|O_LARGEFILE|O_TRUNC,0666);
635 #endif
636 	if (f < 0)
637 	  cerr << "Error opening the predictions file: " << fstr << endl;
638 	all.final_prediction_sink.push_back((size_t) f);
639       }
640   }
641 
642   if (vm.count("raw_predictions")) {
643     if (!all.quiet) {
644       cerr << "raw predictions = " <<  vm["raw_predictions"].as< string >() << endl;
645       if (vm.count("binary"))
646         cerr << "Warning: --raw has no defined value when --binary specified, expect no output" << endl;
647     }
648     if (strcmp(vm["raw_predictions"].as< string >().c_str(), "stdout") == 0)
649       all.raw_prediction = 1;//stdout
650     else
651 	{
652 	  const char* t = vm["raw_predictions"].as< string >().c_str();
653 	  int f;
654 #ifdef _WIN32
655 	  _sopen_s(&f, t, _O_CREAT|_O_WRONLY|_O_BINARY|_O_TRUNC, _SH_DENYWR, _S_IREAD|_S_IWRITE);
656 #else
657 	  f = open(t, O_CREAT|O_WRONLY|O_LARGEFILE|O_TRUNC,0666);
658 #endif
659 	  all.raw_prediction = f;
660 	}
661   }
662 }
663 
parse_output_model(vw & all)664 void parse_output_model(vw& all)
665 {
666   new_options(all, "Output model")
667     ("final_regressor,f", po::value< string >(), "Final regressor")
668     ("readable_model", po::value< string >(), "Output human-readable final regressor with numeric features")
669     ("invert_hash", po::value< string >(), "Output human-readable final regressor with feature names.  Computationally expensive.")
670     ("save_resume", "save extra state so learning can be resumed later with new data")
671     ("save_per_pass", "Save the model after every pass over data")
672     ("output_feature_regularizer_binary", po::value< string >(&(all.per_feature_regularizer_output)), "Per feature regularization output file")
673     ("output_feature_regularizer_text", po::value< string >(&(all.per_feature_regularizer_text)), "Per feature regularization output file, in text");
674   add_options(all);
675 
676   po::variables_map& vm = all.vm;
677   if (vm.count("final_regressor")) {
678     all.final_regressor_name = vm["final_regressor"].as<string>();
679     if (!all.quiet)
680       cerr << "final_regressor = " << vm["final_regressor"].as<string>() << endl;
681   }
682   else
683     all.final_regressor_name = "";
684 
685   if (vm.count("readable_model"))
686     all.text_regressor_name = vm["readable_model"].as<string>();
687 
688   if (vm.count("invert_hash")){
689     all.inv_hash_regressor_name = vm["invert_hash"].as<string>();
690     all.hash_inv = true;
691   }
692 
693   if (vm.count("save_per_pass"))
694     all.save_per_pass = true;
695 
696   if (vm.count("save_resume"))
697     all.save_resume = true;
698 }
699 
load_input_model(vw & all,io_buf & io_temp)700 void load_input_model(vw& all, io_buf& io_temp)
701 {
702   // Need to see if we have to load feature mask first or second.
703   // -i and -mask are from same file, load -i file first so mask can use it
704   if (all.vm.count("feature_mask") && all.vm.count("initial_regressor")
705       && all.vm["feature_mask"].as<string>() == all.vm["initial_regressor"].as< vector<string> >()[0]) {
706     // load rest of regressor
707     all.l->save_load(io_temp, true, false);
708     io_temp.close_file();
709 
710     // set the mask, which will reuse -i file we just loaded
711     parse_mask_regressor_args(all);
712   }
713   else {
714     // load mask first
715     parse_mask_regressor_args(all);
716 
717     // load rest of regressor
718     all.l->save_load(io_temp, true, false);
719     io_temp.close_file();
720   }
721 }
722 
setup_base(vw & all)723 LEARNER::base_learner* setup_base(vw& all)
724 {
725   LEARNER::base_learner* ret = all.reduction_stack.pop()(all);
726   if (ret == NULL)
727     return setup_base(all);
728   else
729     return ret;
730 }
731 
// Register the "Reduction options" group, queue every learner/reduction
// setup function on all.reduction_stack, and build the learner chain by
// assigning the result of setup_base(all) to all.l.
//
// setup_base() pops entries from the stack and calls them until one returns
// non-NULL, so the push order below determines the order in which the setup
// functions are tried (presumably last pushed is tried first -- confirm
// against the pop() implementation of the stack type).
void parse_reductions(vw& all)
{
  new_options(all, "Reduction options, use [option] --help for more info");
  add_options(all);
  //Base algorithms
  all.reduction_stack.push_back(GD::setup);
  all.reduction_stack.push_back(kernel_svm_setup);
  all.reduction_stack.push_back(ftrl_setup);
  all.reduction_stack.push_back(sender_setup);
  all.reduction_stack.push_back(gd_mf_setup);
  all.reduction_stack.push_back(print_setup);
  all.reduction_stack.push_back(noop_setup);
  all.reduction_stack.push_back(lda_setup);
  all.reduction_stack.push_back(bfgs_setup);

  //Score Users
  all.reduction_stack.push_back(active_setup);
  all.reduction_stack.push_back(nn_setup);
  all.reduction_stack.push_back(mf_setup);
  all.reduction_stack.push_back(autolink_setup);
  all.reduction_stack.push_back(lrq_setup);
  all.reduction_stack.push_back(stagewise_poly_setup);
  all.reduction_stack.push_back(scorer_setup);

  //Reductions
  all.reduction_stack.push_back(binary_setup);
  all.reduction_stack.push_back(topk_setup);
  all.reduction_stack.push_back(oaa_setup);
  all.reduction_stack.push_back(ect_setup);
  all.reduction_stack.push_back(log_multi_setup);
  all.reduction_stack.push_back(csoaa_setup);
  all.reduction_stack.push_back(csldf_setup);
  all.reduction_stack.push_back(cb_algs_setup);
  all.reduction_stack.push_back(cbify_setup);
  all.reduction_stack.push_back(Search::setup);
  all.reduction_stack.push_back(bs_setup);

  // Start building the chain from the stack assembled above.
  all.l = setup_base(all);
}
771 
add_to_args(vw & all,int argc,char * argv[])772 void add_to_args(vw& all, int argc, char* argv[])
773 {
774   for (int i = 1; i < argc; i++)
775     all.args.push_back(string(argv[i]));
776 }
777 
parse_args(int argc,char * argv[])778 vw& parse_args(int argc, char *argv[])
779 {
780   vw& all = *(new vw());
781 
782   add_to_args(all, argc, argv);
783 
784   size_t random_seed = 0;
785   all.program_name = argv[0];
786 
787   new_options(all, "VW options")
788     ("random_seed", po::value<size_t>(&random_seed), "seed random number generator")
789     ("ring_size", po::value<size_t>(&(all.p->ring_size)), "size of example ring");
790   add_options(all);
791 
792   new_options(all, "Update options")
793     ("learning_rate,l", po::value<float>(&(all.eta)), "Set learning rate")
794     ("power_t", po::value<float>(&(all.power_t)), "t power value")
795     ("decay_learning_rate",    po::value<float>(&(all.eta_decay_rate)),
796      "Set Decay factor for learning_rate between passes")
797     ("initial_t", po::value<double>(&((all.sd->t))), "initial t value")
798     ("feature_mask", po::value< string >(), "Use existing regressor to determine which parameters may be updated.  If no initial_regressor given, also used for initial weights.");
799   add_options(all);
800 
801   new_options(all, "Weight options")
802     ("initial_regressor,i", po::value< vector<string> >(), "Initial regressor(s)")
803     ("initial_weight", po::value<float>(&(all.initial_weight)), "Set all weights to an initial value of arg.")
804     ("random_weights", po::value<bool>(&(all.random_weights)), "make initial weights random")
805     ("input_feature_regularizer", po::value< string >(&(all.per_feature_regularizer_input)), "Per feature regularization input file");
806   add_options(all);
807 
808   new_options(all, "Parallelization options")
809     ("span_server", po::value<string>(&(all.span_server)), "Location of server for setting up spanning tree")
810     ("unique_id", po::value<size_t>(&(all.unique_id)),"unique id used for cluster parallel jobs")
811     ("total", po::value<size_t>(&(all.total)),"total number of nodes used in cluster parallel job")
812     ("node", po::value<size_t>(&(all.node)),"node number in cluster parallel job");
813   add_options(all);
814 
815   po::variables_map& vm = all.vm;
816   msrand48(random_seed);
817   parse_diagnostics(all, argc);
818 
819   all.sd->weighted_unlabeled_examples = all.sd->t;
820   all.initial_t = (float)all.sd->t;
821 
822   //Input regressor header
823   io_buf io_temp;
824   parse_regressor_args(all, vm, io_temp);
825 
826   int temp_argc = 0;
827   char** temp_argv = VW::get_argv_from_string(all.file_options->str(), temp_argc);
828   add_to_args(all, temp_argc, temp_argv);
829   for (int i = 0; i < temp_argc; i++)
830     free(temp_argv[i]);
831   free(temp_argv);
832 
833   po::parsed_options pos = po::command_line_parser(all.args).
834     style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
835     options(all.opts).allow_unregistered().run();
836 
837   vm = po::variables_map();
838 
839   po::store(pos, vm);
840   po::notify(vm);
841   all.file_options->str("");
842 
843   parse_feature_tweaks(all); //feature tweaks
844 
845   parse_example_tweaks(all); //example manipulation
846 
847   parse_output_model(all);
848 
849   parse_output_preds(all);
850 
851   parse_reductions(all);
852 
853   if (!all.quiet)
854     {
855       cerr << "Num weight bits = " << all.num_bits << endl;
856       cerr << "learning rate = " << all.eta << endl;
857       cerr << "initial_t = " << all.sd->t << endl;
858       cerr << "power_t = " << all.power_t << endl;
859       if (all.numpasses > 1)
860 	cerr << "decay_learning_rate = " << all.eta_decay_rate << endl;
861     }
862 
863   load_input_model(all, io_temp);
864 
865   parse_source(all);
866 
867   enable_sources(all, all.quiet, all.numpasses);
868 
869   // force wpp to be a power of 2 to avoid 32-bit overflow
870   uint32_t i = 0;
871   size_t params_per_problem = all.l->increment;
872   while (params_per_problem > (uint32_t)(1 << i))
873     i++;
874   all.wpp = (1 << i) >> all.reg.stride_shift;
875 
876   if (vm.count("help")) {
877     /* upon direct query for help -- spit it out to stdout */
878     cout << "\n" << all.opts << "\n";
879     exit(0);
880   }
881 
882   return all;
883 }
884 
885 namespace VW {
cmd_string_replace_value(std::stringstream * & ss,string flag_to_replace,string new_value)886   void cmd_string_replace_value( std::stringstream*& ss, string flag_to_replace, string new_value )
887   {
888     flag_to_replace.append(" "); //add a space to make sure we obtain the right flag in case 2 flags start with the same set of characters
889     string cmd = ss->str();
890     size_t pos = cmd.find(flag_to_replace);
891     if( pos == string::npos )
892       //flag currently not present in command string, so just append it to command string
893       *ss << " " << flag_to_replace << new_value;
894     else {
895       //flag is present, need to replace old value with new value
896 
897       //compute position after flag_to_replace
898       pos += flag_to_replace.size();
899 
900       //now pos is position where value starts
901       //find position of next space
902       size_t pos_after_value = cmd.find(" ",pos);
903       if(pos_after_value == string::npos)
904         //we reach the end of the string, so replace the all characters after pos by new_value
905         cmd.replace(pos,cmd.size()-pos,new_value);
906       else
907         //replace characters between pos and pos_after_value by new_value
908         cmd.replace(pos,pos_after_value-pos,new_value);
909       ss->str(cmd);
910     }
911   }
912 
get_argv_from_string(string s,int & argc)913   char** get_argv_from_string(string s, int& argc)
914   {
915     char* c = calloc_or_die<char>(s.length()+3);
916     c[0] = 'b';
917     c[1] = ' ';
918     strcpy(c+2, s.c_str());
919     substring ss = {c, c+s.length()+2};
920     v_array<substring> foo = v_init<substring>();
921     tokenize(' ', ss, foo);
922 
923     char** argv = calloc_or_die<char*>(foo.size());
924     for (size_t i = 0; i < foo.size(); i++)
925       {
926 	*(foo[i].end) = '\0';
927 	argv[i] = calloc_or_die<char>(foo[i].end-foo[i].begin+1);
928         sprintf(argv[i],"%s",foo[i].begin);
929       }
930 
931     argc = (int)foo.size();
932     free(c);
933     foo.delete_v();
934     return argv;
935   }
936 
initialize(string s)937   vw* initialize(string s)
938   {
939     int argc = 0;
940     s += " --no_stdin";
941     char** argv = get_argv_from_string(s,argc);
942 
943     vw& all = parse_args(argc, argv);
944 
945     initialize_parser_datastructures(all);
946 
947     for(int i = 0; i < argc; i++)
948       free(argv[i]);
949     free(argv);
950 
951     return &all;
952   }
953 
delete_dictionary_entry(substring ss,v_array<feature> * A)954   void delete_dictionary_entry(substring ss, v_array<feature>*A) {
955     free(ss.begin);
956     A->delete_v();
957     delete A;
958   }
959 
finish(vw & all,bool delete_all)960   void finish(vw& all, bool delete_all)
961   {
962     finalize_regressor(all, all.final_regressor_name);
963     all.l->finish();
964     free_it(all.l);
965     if (all.reg.weight_vector != NULL)
966       free(all.reg.weight_vector);
967     free_parser(all);
968     finalize_source(all.p);
969     all.p->parse_name.erase();
970     all.p->parse_name.delete_v();
971     free(all.p);
972     free(all.sd);
973     all.reduction_stack.delete_v();
974     delete all.file_options;
975     for (size_t i = 0; i < all.final_prediction_sink.size(); i++)
976       if (all.final_prediction_sink[i] != 1)
977 	io_buf::close_file_or_socket(all.final_prediction_sink[i]);
978     all.final_prediction_sink.delete_v();
979     for (size_t i=0; i<all.read_dictionaries.size(); i++) {
980       free(all.read_dictionaries[i].name);
981       all.read_dictionaries[i].dict->iter(delete_dictionary_entry);
982       all.read_dictionaries[i].dict->delete_v();
983       delete all.read_dictionaries[i].dict;
984     }
985     delete all.loss;
986     if (delete_all) delete &all;
987   }
988 }
989