1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5 */
6 #include <stdio.h>
7 #include <float.h>
8 #include <sstream>
9 #include <fstream>
10
11 #include "parse_regressor.h"
12 #include "parser.h"
13 #include "vw.h"
14
15 #include "sender.h"
16 #include "nn.h"
17 #include "gd.h"
18 #include "cbify.h"
19 #include "oaa.h"
20 #include "rand48.h"
21 #include "bs.h"
22 #include "topk.h"
23 #include "ect.h"
24 #include "csoaa.h"
25 #include "cb_algs.h"
26 #include "scorer.h"
27 #include "search.h"
28 #include "bfgs.h"
29 #include "lda_core.h"
30 #include "noop.h"
31 #include "print.h"
32 #include "gd_mf.h"
33 #include "mf.h"
34 #include "ftrl_proximal.h"
35 #include "rand48.h"
36 #include "binary.h"
37 #include "lrq.h"
38 #include "autolink.h"
39 #include "log_multi.h"
40 #include "stagewise_poly.h"
41 #include "active.h"
42 #include "kernel_svm.h"
43 #include "parse_example.h"
44
45 using namespace std;
//
// Does string end with a certain substring?
//
// Returns true when `ending` is a suffix of `fullString`.  A string counts as
// ending with itself, and every string ends with the empty string.
//
bool ends_with(std::string const &fullString, std::string const &ending)
{
  // A suffix can never be longer than the string it is searched in.
  if (fullString.length() < ending.length())
    return false;

  // Compare the last ending.length() characters against the suffix.
  return fullString.compare(fullString.length() - ending.length(), ending.length(), ending) == 0;
}
57
// Tells whether c may be used as a namespace character: everything is
// accepted except '|' and ':'.
bool valid_ns(char c)
{
  return c != '|' && c != ':';
}
64
substring_equal(substring & a,substring & b)65 bool substring_equal(substring&a, substring&b) {
66 return (a.end - a.begin == b.end - b.begin) // same length
67 && (strncmp(a.begin, b.begin, a.end - a.begin) == 0);
68 }
69
parse_dictionary_argument(vw & all,string str)70 void parse_dictionary_argument(vw&all, string str) {
71 if (str.length() == 0) return;
72 // expecting 'namespace:file', for instance 'w:foo.txt'
73 // in the case of just 'foo.txt' it's applied to the default namespace
74
75 char ns = ' ';
76 const char*s = str.c_str();
77 if ((str.length() > 3) && (str[1] == ':')) {
78 ns = str[0];
79 s += 2;
80 }
81
82 // see if we've already read this dictionary
83 for (size_t id=0; id<all.read_dictionaries.size(); id++)
84 if (strcmp(all.read_dictionaries[id].name, s) == 0) {
85 all.namespace_dictionaries[(size_t)ns].push_back(all.read_dictionaries[id].dict);
86 return;
87 }
88
89 feature_dict* map = new feature_dict(1023, NULL, substring_equal);
90
91 // TODO: handle gzipped dictionaries
92 example *ec = alloc_examples(all.p->lp.label_size, 1);
93 ifstream infile(s);
94 size_t def = (size_t)' ';
95 for (string line; getline(infile, line);) {
96 char* c = (char*)line.c_str(); // we're throwing away const, which is dangerous...
97 while (*c == ' ' || *c == '\t') ++c; // skip initial whitespace
98 char* d = c;
99 while (*d != ' ' && *d != '\t' && *d != '\n' && *d != '\0') ++d; // gobble up initial word
100 if (d == c) continue; // no word
101 if (*d != ' ' && *d != '\t') continue; // reached end of line
102 char* word = calloc_or_die<char>(d-c);
103 memcpy(word, c, d-c);
104 substring ss = { word, word + (d - c) };
105 uint32_t hash = uniform_hash( ss.begin, ss.end-ss.begin, quadratic_constant);
106 if (map->get(ss, hash) != NULL) { // don't overwrite old values!
107 free(word);
108 continue;
109 }
110
111 d--;
112 *d = '|'; // set up for parser::read_line
113 read_line(all, ec, d);
114 // now we just need to grab stuff from the default namespace of ec!
115 if (ec->atomics[def].size() == 0) {
116 free(word);
117 continue;
118 }
119 v_array<feature>* arr = new v_array<feature>;
120 *arr = v_init<feature>();
121 push_many(*arr, ec->atomics[def].begin, ec->atomics[def].size());
122 map->put(ss, hash, arr);
123 }
124 dealloc_example(all.p->lp.delete_label, *ec);
125 free(ec);
126
127 cerr << "dictionary " << s << " contains " << map->size() << " item" << (map->size() == 1 ? "\n" : "s\n");
128 all.namespace_dictionaries[(size_t)ns].push_back(map);
129 dictionary_info info = { calloc_or_die<char>(strlen(s)+1), map };
130 strcpy(info.name, s);
131 all.read_dictionaries.push_back(info);
132 }
133
parse_affix_argument(vw & all,string str)134 void parse_affix_argument(vw&all, string str) {
135 if (str.length() == 0) return;
136 char* cstr = calloc_or_die<char>(str.length()+1);
137 strcpy(cstr, str.c_str());
138
139 char*p = strtok(cstr, ",");
140 while (p != 0) {
141 char*q = p;
142 uint16_t prefix = 1;
143 if (q[0] == '+') { q++; }
144 else if (q[0] == '-') { prefix = 0; q++; }
145 if ((q[0] < '1') || (q[0] > '7')) {
146 cerr << "malformed affix argument (length must be 1..7): " << p << endl;
147 throw exception();
148 }
149 uint16_t len = (uint16_t)(q[0] - '0');
150 uint16_t ns = (uint16_t)' '; // default namespace
151 if (q[1] != 0) {
152 if (valid_ns(q[1]))
153 ns = (uint16_t)q[1];
154 else {
155 cerr << "malformed affix argument (invalid namespace): " << p << endl;
156 throw exception();
157 }
158 if (q[2] != 0) {
159 cerr << "malformed affix argument (too long): " << p << endl;
160 throw exception();
161 }
162 }
163
164 uint16_t afx = (len << 1) | (prefix & 0x1);
165 all.affix_features[ns] <<= 4;
166 all.affix_features[ns] |= afx;
167
168 p = strtok(NULL, ",");
169 }
170
171 free(cstr);
172 }
173
parse_diagnostics(vw & all,int argc)174 void parse_diagnostics(vw& all, int argc)
175 {
176 new_options(all, "Diagnostic options")
177 ("version","Version information")
178 ("audit,a", "print weights of features")
179 ("progress,P", po::value< string >(), "Progress update frequency. int: additive, float: multiplicative")
180 ("quiet", "Don't output disgnostics and progress updates")
181 ("help,h","Look here: http://hunch.net/~vw/ and click on Tutorial.");
182 add_options(all);
183
184 po::variables_map& vm = all.vm;
185
186 if (vm.count("version")) {
187 /* upon direct query for version -- spit it out to stdout */
188 cout << version.to_string() << "\n";
189 exit(0);
190 }
191
192 if (vm.count("quiet")) {
193 all.quiet = true;
194 // --quiet wins over --progress
195 } else {
196 if (argc == 1)
197 cerr << "For more information use: vw --help" << endl;
198
199 all.quiet = false;
200
201 if (vm.count("progress")) {
202 string progress_str = vm["progress"].as<string>();
203 all.progress_arg = (float)::atof(progress_str.c_str());
204
205 // --progress interval is dual: either integer or floating-point
206 if (progress_str.find_first_of(".") == string::npos) {
207 // No "." in arg: assume integer -> additive
208 all.progress_add = true;
209 if (all.progress_arg < 1) {
210 cerr << "warning: additive --progress <int>"
211 << " can't be < 1: forcing to 1\n";
212 all.progress_arg = 1;
213
214 }
215 all.sd->dump_interval = all.progress_arg;
216
217 } else {
218 // A "." in arg: assume floating-point -> multiplicative
219 all.progress_add = false;
220
221 if (all.progress_arg <= 1.0) {
222 cerr << "warning: multiplicative --progress <float>: "
223 << vm["progress"].as<string>()
224 << " is <= 1.0: adding 1.0\n";
225 all.progress_arg += 1.0;
226
227 } else if (all.progress_arg > 9.0) {
228 cerr << "warning: multiplicative --progress <float>"
229 << " is > 9.0: you probably meant to use an integer\n";
230 }
231 all.sd->dump_interval = 1.0;
232 }
233 }
234 }
235
236 if (vm.count("audit")){
237 all.audit = true;
238 }
239 }
240
parse_source(vw & all)241 void parse_source(vw& all)
242 {
243 new_options(all, "Input options")
244 ("data,d", po::value< string >(), "Example Set")
245 ("daemon", "persistent daemon mode on port 26542")
246 ("port", po::value<size_t>(),"port to listen on; use 0 to pick unused port")
247 ("num_children", po::value<size_t>(&(all.num_children)), "number of children for persistent daemon mode")
248 ("pid_file", po::value< string >(), "Write pid file in persistent daemon mode")
249 ("port_file", po::value< string >(), "Write port used in persistent daemon mode")
250 ("cache,c", "Use a cache. The default is <data>.cache")
251 ("cache_file", po::value< vector<string> >(), "The location(s) of cache_file.")
252 ("kill_cache,k", "do not reuse existing cache: create a new one always")
253 ("compressed", "use gzip format whenever possible. If a cache file is being created, this option creates a compressed cache file. A mixture of raw-text & compressed inputs are supported with autodetection.")
254 ("no_stdin", "do not default to reading from stdin");
255 add_options(all);
256
257 // Be friendly: if -d was left out, treat positional param as data file
258 po::positional_options_description p;
259 p.add("data", -1);
260
261 po::parsed_options pos = po::command_line_parser(all.args).
262 style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
263 options(all.opts).positional(p).run();
264 all.vm = po::variables_map();
265 po::store(pos, all.vm);
266 po::variables_map& vm = all.vm;
267
268 //begin input source
269 if (vm.count("no_stdin"))
270 all.stdin_off = true;
271
272 if ( (vm.count("total") || vm.count("node") || vm.count("unique_id")) && !(vm.count("total") && vm.count("node") && vm.count("unique_id")) )
273 {
274 cout << "you must specificy unique_id, total, and node if you specify any" << endl;
275 throw exception();
276 }
277
278 if (vm.count("daemon") || vm.count("pid_file") || (vm.count("port") && !all.active) ) {
279 all.daemon = true;
280
281 // allow each child to process up to 1e5 connections
282 all.numpasses = (size_t) 1e5;
283 }
284
285 if (vm.count("compressed"))
286 set_compressed(all.p);
287
288 if (vm.count("data")) {
289 all.data_filename = vm["data"].as<string>();
290 if (ends_with(all.data_filename, ".gz"))
291 set_compressed(all.p);
292 } else
293 all.data_filename = "";
294
295 if ((vm.count("cache") || vm.count("cache_file")) && vm.count("invert_hash"))
296 {
297 cout << "invert_hash is incompatible with a cache file. Use it in single pass mode only." << endl;
298 throw exception();
299 }
300
301 if(!all.holdout_set_off && (vm.count("output_feature_regularizer_binary") || vm.count("output_feature_regularizer_text")))
302 {
303 all.holdout_set_off = true;
304 cerr<<"Making holdout_set_off=true since output regularizer specified\n";
305 }
306 }
307
parse_feature_tweaks(vw & all)308 void parse_feature_tweaks(vw& all)
309 {
310 new_options(all, "Feature options")
311 ("hash", po::value< string > (), "how to hash the features. Available options: strings, all")
312 ("ignore", po::value< vector<unsigned char> >(), "ignore namespaces beginning with character <arg>")
313 ("keep", po::value< vector<unsigned char> >(), "keep namespaces beginning with character <arg>")
314 ("bit_precision,b", po::value<size_t>(), "number of bits in the feature table")
315 ("noconstant", "Don't add a constant feature")
316 ("constant,C", po::value<float>(&(all.initial_constant)), "Set initial value of constant")
317 ("ngram", po::value< vector<string> >(), "Generate N grams. To generate N grams for a single namespace 'foo', arg should be fN.")
318 ("skips", po::value< vector<string> >(), "Generate skips in N grams. This in conjunction with the ngram tag can be used to generate generalized n-skip-k-gram. To generate n-skips for a single namespace 'foo', arg should be fN.")
319 ("feature_limit", po::value< vector<string> >(), "limit to N features. To apply to a single namespace 'foo', arg should be fN")
320 ("affix", po::value<string>(), "generate prefixes/suffixes of features; argument '+2a,-3b,+1' means generate 2-char prefixes for namespace a, 3-char suffixes for b and 1 char prefixes for default namespace")
321 ("spelling", po::value< vector<string> >(), "compute spelling features for a give namespace (use '_' for default namespace)")
322 ("dictionary", po::value< vector<string> >(), "read a dictionary for additional features (arg either 'x:file' or just 'file')")
323 ("quadratic,q", po::value< vector<string> > (), "Create and use quadratic features")
324 ("q:", po::value< string >(), ": corresponds to a wildcard for all printable characters")
325 ("cubic", po::value< vector<string> > (),
326 "Create and use cubic features");
327 add_options(all);
328
329 po::variables_map& vm = all.vm;
330
331 //feature manipulation
332 string hash_function("strings");
333 if(vm.count("hash"))
334 hash_function = vm["hash"].as<string>();
335 all.p->hasher = getHasher(hash_function);
336
337 if (vm.count("spelling")) {
338 vector<string> spelling_ns = vm["spelling"].as< vector<string> >();
339 for (size_t id=0; id<spelling_ns.size(); id++)
340 if (spelling_ns[id][0] == '_') all.spelling_features[(unsigned char)' '] = true;
341 else all.spelling_features[(size_t)spelling_ns[id][0]] = true;
342 }
343
344 if (vm.count("affix")) {
345 parse_affix_argument(all, vm["affix"].as<string>());
346 *all.file_options << " --affix " << vm["affix"].as<string>();
347 }
348
349 if(vm.count("ngram")){
350 if(vm.count("sort_features"))
351 {
352 cerr << "ngram is incompatible with sort_features. " << endl;
353 throw exception();
354 }
355
356 all.ngram_strings = vm["ngram"].as< vector<string> >();
357 compile_gram(all.ngram_strings, all.ngram, (char*)"grams", all.quiet);
358 }
359
360 if(vm.count("skips"))
361 {
362 if(!vm.count("ngram"))
363 {
364 cout << "You can not skip unless ngram is > 1" << endl;
365 throw exception();
366 }
367
368 all.skip_strings = vm["skips"].as<vector<string> >();
369 compile_gram(all.skip_strings, all.skips, (char*)"skips", all.quiet);
370 }
371
372 if(vm.count("feature_limit"))
373 {
374 all.limit_strings = vm["feature_limit"].as< vector<string> >();
375 compile_limits(all.limit_strings, all.limit, all.quiet);
376 }
377
378 if (vm.count("bit_precision"))
379 {
380 uint32_t new_bits = (uint32_t)vm["bit_precision"].as< size_t>();
381 if (all.default_bits == false && new_bits != all.num_bits)
382 {
383 cout << "Number of bits is set to " << new_bits << " and " << all.num_bits << " by argument and model. That does not work." << endl;
384 throw exception();
385 }
386 all.default_bits = false;
387 all.num_bits = new_bits;
388 if (all.num_bits > min(31, sizeof(size_t)*8 - 3))
389 {
390 cout << "Only " << min(31, sizeof(size_t)*8 - 3) << " or fewer bits allowed. If this is a serious limit, speak up." << endl;
391 throw exception();
392 }
393 }
394
395 if (vm.count("quadratic"))
396 {
397 all.pairs = vm["quadratic"].as< vector<string> >();
398 vector<string> newpairs;
399 //string tmp;
400 char printable_start = '!';
401 char printable_end = '~';
402 int valid_ns_size = printable_end - printable_start - 1; //will skip two characters
403
404 if(!all.quiet)
405 cerr<<"creating quadratic features for pairs: ";
406
407 for (vector<string>::iterator i = all.pairs.begin(); i != all.pairs.end();i++){
408 if(!all.quiet){
409 cerr << *i << " ";
410 if (i->length() > 2)
411 cerr << endl << "warning, ignoring characters after the 2nd.\n";
412 if (i->length() < 2) {
413 cerr << endl << "error, quadratic features must involve two sets.\n";
414 throw exception();
415 }
416 }
417 //-q x:
418 if((*i)[0]!=':'&&(*i)[1]==':'){
419 newpairs.reserve(newpairs.size() + valid_ns_size);
420 for (char j=printable_start; j<=printable_end; j++){
421 if(valid_ns(j))
422 newpairs.push_back(string(1,(*i)[0])+j);
423 }
424 }
425 //-q :x
426 else if((*i)[0]==':'&&(*i)[1]!=':'){
427 newpairs.reserve(newpairs.size() + valid_ns_size);
428 for (char j=printable_start; j<=printable_end; j++){
429 if(valid_ns(j)){
430 stringstream ss;
431 ss << j << (*i)[1];
432 newpairs.push_back(ss.str());
433 }
434 }
435 }
436 //-q ::
437 else if((*i)[0]==':'&&(*i)[1]==':'){
438 cout << "in pair creation" << endl;
439 newpairs.reserve(newpairs.size() + valid_ns_size*valid_ns_size);
440 stringstream ss;
441 ss << ' ' << ' ';
442 newpairs.push_back(ss.str());
443 for (char j=printable_start; j<=printable_end; j++){
444 if(valid_ns(j)){
445 for (char k=printable_start; k<=printable_end; k++){
446 if(valid_ns(k)){
447 stringstream ss;
448 ss << j << k;
449 newpairs.push_back(ss.str());
450 }
451 }
452 }
453 }
454 }
455 else{
456 newpairs.push_back(string(*i));
457 }
458 }
459 newpairs.swap(all.pairs);
460 if(!all.quiet)
461 cerr<<endl;
462 }
463
464 if (vm.count("cubic"))
465 {
466 all.triples = vm["cubic"].as< vector<string> >();
467 if (!all.quiet)
468 {
469 cerr << "creating cubic features for triples: ";
470 for (vector<string>::iterator i = all.triples.begin(); i != all.triples.end();i++) {
471 cerr << *i << " ";
472 if (i->length() > 3)
473 cerr << endl << "warning, ignoring characters after the 3rd.\n";
474 if (i->length() < 3) {
475 cerr << endl << "error, cubic features must involve three sets.\n";
476 throw exception();
477 }
478 }
479 cerr << endl;
480 }
481 }
482
483 for (size_t i = 0; i < 256; i++)
484 all.ignore[i] = false;
485 all.ignore_some = false;
486
487 if (vm.count("ignore"))
488 {
489 all.ignore_some = true;
490
491 vector<unsigned char> ignore = vm["ignore"].as< vector<unsigned char> >();
492 for (vector<unsigned char>::iterator i = ignore.begin(); i != ignore.end();i++)
493 {
494 all.ignore[*i] = true;
495 }
496 if (!all.quiet)
497 {
498 cerr << "ignoring namespaces beginning with: ";
499 for (vector<unsigned char>::iterator i = ignore.begin(); i != ignore.end();i++)
500 cerr << *i << " ";
501
502 cerr << endl;
503 }
504 }
505
506 if (vm.count("keep"))
507 {
508 for (size_t i = 0; i < 256; i++)
509 all.ignore[i] = true;
510
511 all.ignore_some = true;
512
513 vector<unsigned char> keep = vm["keep"].as< vector<unsigned char> >();
514 for (vector<unsigned char>::iterator i = keep.begin(); i != keep.end();i++)
515 {
516 all.ignore[*i] = false;
517 }
518 if (!all.quiet)
519 {
520 cerr << "using namespaces beginning with: ";
521 for (vector<unsigned char>::iterator i = keep.begin(); i != keep.end();i++)
522 cerr << *i << " ";
523
524 cerr << endl;
525 }
526 }
527
528 if (vm.count("dictionary")) {
529 vector<string> dictionary_ns = vm["dictionary"].as< vector<string> >();
530 for (size_t id=0; id<dictionary_ns.size(); id++)
531 parse_dictionary_argument(all, dictionary_ns[id]);
532 }
533
534 if (vm.count("noconstant"))
535 all.add_constant = false;
536 }
537
parse_example_tweaks(vw & all)538 void parse_example_tweaks(vw& all)
539 {
540 new_options(all, "Example options")
541 ("testonly,t", "Ignore label information and just test")
542 ("holdout_off", "no holdout data in multiple passes")
543 ("holdout_period", po::value<uint32_t>(&(all.holdout_period)), "holdout period for test only, default 10")
544 ("holdout_after", po::value<uint32_t>(&(all.holdout_after)), "holdout after n training examples, default off (disables holdout_period)")
545 ("early_terminate", po::value<size_t>(), "Specify the number of passes tolerated when holdout loss doesn't decrease before early termination, default is 3")
546 ("passes", po::value<size_t>(&(all.numpasses)),"Number of Training Passes")
547 ("initial_pass_length", po::value<size_t>(&(all.pass_length)), "initial number of examples per pass")
548 ("examples", po::value<size_t>(&(all.max_examples)), "number of examples to parse")
549 ("min_prediction", po::value<float>(&(all.sd->min_label)), "Smallest prediction to output")
550 ("max_prediction", po::value<float>(&(all.sd->max_label)), "Largest prediction to output")
551 ("sort_features", "turn this on to disregard order in which features have been defined. This will lead to smaller cache sizes")
552 ("loss_function", po::value<string>()->default_value("squared"), "Specify the loss function to be used, uses squared by default. Currently available ones are squared, classic, hinge, logistic and quantile.")
553 ("quantile_tau", po::value<float>()->default_value(0.5), "Parameter \\tau associated with Quantile loss. Defaults to 0.5")
554 ("l1", po::value<float>(&(all.l1_lambda)), "l_1 lambda")
555 ("l2", po::value<float>(&(all.l2_lambda)), "l_2 lambda");
556 add_options(all);
557
558 po::variables_map& vm = all.vm;
559 if (vm.count("testonly") || all.eta == 0.)
560 {
561 if (!all.quiet)
562 cerr << "only testing" << endl;
563 all.training = false;
564 if (all.lda > 0)
565 all.eta = 0;
566 }
567 else
568 all.training = true;
569
570 if(all.numpasses > 1)
571 all.holdout_set_off = false;
572
573 if(vm.count("holdout_off"))
574 all.holdout_set_off = true;
575
576 if(vm.count("sort_features"))
577 all.p->sort_features = true;
578
579 if (vm.count("min_prediction"))
580 all.sd->min_label = vm["min_prediction"].as<float>();
581 if (vm.count("max_prediction"))
582 all.sd->max_label = vm["max_prediction"].as<float>();
583 if (vm.count("min_prediction") || vm.count("max_prediction") || vm.count("testonly"))
584 all.set_minmax = noop_mm;
585
586 string loss_function = vm["loss_function"].as<string>();
587 float loss_parameter = 0.0;
588 if(vm.count("quantile_tau"))
589 loss_parameter = vm["quantile_tau"].as<float>();
590
591 all.loss = getLossFunction(all, loss_function, (float)loss_parameter);
592
593 if (all.l1_lambda < 0.) {
594 cerr << "l1_lambda should be nonnegative: resetting from " << all.l1_lambda << " to 0" << endl;
595 all.l1_lambda = 0.;
596 }
597 if (all.l2_lambda < 0.) {
598 cerr << "l2_lambda should be nonnegative: resetting from " << all.l2_lambda << " to 0" << endl;
599 all.l2_lambda = 0.;
600 }
601 all.reg_mode += (all.l1_lambda > 0.) ? 1 : 0;
602 all.reg_mode += (all.l2_lambda > 0.) ? 2 : 0;
603 if (!all.quiet)
604 {
605 if (all.reg_mode %2 && !vm.count("bfgs"))
606 cerr << "using l1 regularization = " << all.l1_lambda << endl;
607 if (all.reg_mode > 1)
608 cerr << "using l2 regularization = " << all.l2_lambda << endl;
609 }
610 }
611
parse_output_preds(vw & all)612 void parse_output_preds(vw& all)
613 {
614 new_options(all, "Output options")
615 ("predictions,p", po::value< string >(), "File to output predictions to")
616 ("raw_predictions,r", po::value< string >(), "File to output unnormalized predictions to");
617 add_options(all);
618
619 po::variables_map& vm = all.vm;
620 if (vm.count("predictions")) {
621 if (!all.quiet)
622 cerr << "predictions = " << vm["predictions"].as< string >() << endl;
623 if (strcmp(vm["predictions"].as< string >().c_str(), "stdout") == 0)
624 {
625 all.final_prediction_sink.push_back((size_t) 1);//stdout
626 }
627 else
628 {
629 const char* fstr = (vm["predictions"].as< string >().c_str());
630 int f;
631 #ifdef _WIN32
632 _sopen_s(&f, fstr, _O_CREAT|_O_WRONLY|_O_BINARY|_O_TRUNC, _SH_DENYWR, _S_IREAD|_S_IWRITE);
633 #else
634 f = open(fstr, O_CREAT|O_WRONLY|O_LARGEFILE|O_TRUNC,0666);
635 #endif
636 if (f < 0)
637 cerr << "Error opening the predictions file: " << fstr << endl;
638 all.final_prediction_sink.push_back((size_t) f);
639 }
640 }
641
642 if (vm.count("raw_predictions")) {
643 if (!all.quiet) {
644 cerr << "raw predictions = " << vm["raw_predictions"].as< string >() << endl;
645 if (vm.count("binary"))
646 cerr << "Warning: --raw has no defined value when --binary specified, expect no output" << endl;
647 }
648 if (strcmp(vm["raw_predictions"].as< string >().c_str(), "stdout") == 0)
649 all.raw_prediction = 1;//stdout
650 else
651 {
652 const char* t = vm["raw_predictions"].as< string >().c_str();
653 int f;
654 #ifdef _WIN32
655 _sopen_s(&f, t, _O_CREAT|_O_WRONLY|_O_BINARY|_O_TRUNC, _SH_DENYWR, _S_IREAD|_S_IWRITE);
656 #else
657 f = open(t, O_CREAT|O_WRONLY|O_LARGEFILE|O_TRUNC,0666);
658 #endif
659 all.raw_prediction = f;
660 }
661 }
662 }
663
parse_output_model(vw & all)664 void parse_output_model(vw& all)
665 {
666 new_options(all, "Output model")
667 ("final_regressor,f", po::value< string >(), "Final regressor")
668 ("readable_model", po::value< string >(), "Output human-readable final regressor with numeric features")
669 ("invert_hash", po::value< string >(), "Output human-readable final regressor with feature names. Computationally expensive.")
670 ("save_resume", "save extra state so learning can be resumed later with new data")
671 ("save_per_pass", "Save the model after every pass over data")
672 ("output_feature_regularizer_binary", po::value< string >(&(all.per_feature_regularizer_output)), "Per feature regularization output file")
673 ("output_feature_regularizer_text", po::value< string >(&(all.per_feature_regularizer_text)), "Per feature regularization output file, in text");
674 add_options(all);
675
676 po::variables_map& vm = all.vm;
677 if (vm.count("final_regressor")) {
678 all.final_regressor_name = vm["final_regressor"].as<string>();
679 if (!all.quiet)
680 cerr << "final_regressor = " << vm["final_regressor"].as<string>() << endl;
681 }
682 else
683 all.final_regressor_name = "";
684
685 if (vm.count("readable_model"))
686 all.text_regressor_name = vm["readable_model"].as<string>();
687
688 if (vm.count("invert_hash")){
689 all.inv_hash_regressor_name = vm["invert_hash"].as<string>();
690 all.hash_inv = true;
691 }
692
693 if (vm.count("save_per_pass"))
694 all.save_per_pass = true;
695
696 if (vm.count("save_resume"))
697 all.save_resume = true;
698 }
699
load_input_model(vw & all,io_buf & io_temp)700 void load_input_model(vw& all, io_buf& io_temp)
701 {
702 // Need to see if we have to load feature mask first or second.
703 // -i and -mask are from same file, load -i file first so mask can use it
704 if (all.vm.count("feature_mask") && all.vm.count("initial_regressor")
705 && all.vm["feature_mask"].as<string>() == all.vm["initial_regressor"].as< vector<string> >()[0]) {
706 // load rest of regressor
707 all.l->save_load(io_temp, true, false);
708 io_temp.close_file();
709
710 // set the mask, which will reuse -i file we just loaded
711 parse_mask_regressor_args(all);
712 }
713 else {
714 // load mask first
715 parse_mask_regressor_args(all);
716
717 // load rest of regressor
718 all.l->save_load(io_temp, true, false);
719 io_temp.close_file();
720 }
721 }
722
setup_base(vw & all)723 LEARNER::base_learner* setup_base(vw& all)
724 {
725 LEARNER::base_learner* ret = all.reduction_stack.pop()(all);
726 if (ret == NULL)
727 return setup_base(all);
728 else
729 return ret;
730 }
731
parse_reductions(vw & all)732 void parse_reductions(vw& all)
733 {
734 new_options(all, "Reduction options, use [option] --help for more info");
735 add_options(all);
736 //Base algorithms
737 all.reduction_stack.push_back(GD::setup);
738 all.reduction_stack.push_back(kernel_svm_setup);
739 all.reduction_stack.push_back(ftrl_setup);
740 all.reduction_stack.push_back(sender_setup);
741 all.reduction_stack.push_back(gd_mf_setup);
742 all.reduction_stack.push_back(print_setup);
743 all.reduction_stack.push_back(noop_setup);
744 all.reduction_stack.push_back(lda_setup);
745 all.reduction_stack.push_back(bfgs_setup);
746
747 //Score Users
748 all.reduction_stack.push_back(active_setup);
749 all.reduction_stack.push_back(nn_setup);
750 all.reduction_stack.push_back(mf_setup);
751 all.reduction_stack.push_back(autolink_setup);
752 all.reduction_stack.push_back(lrq_setup);
753 all.reduction_stack.push_back(stagewise_poly_setup);
754 all.reduction_stack.push_back(scorer_setup);
755
756 //Reductions
757 all.reduction_stack.push_back(binary_setup);
758 all.reduction_stack.push_back(topk_setup);
759 all.reduction_stack.push_back(oaa_setup);
760 all.reduction_stack.push_back(ect_setup);
761 all.reduction_stack.push_back(log_multi_setup);
762 all.reduction_stack.push_back(csoaa_setup);
763 all.reduction_stack.push_back(csldf_setup);
764 all.reduction_stack.push_back(cb_algs_setup);
765 all.reduction_stack.push_back(cbify_setup);
766 all.reduction_stack.push_back(Search::setup);
767 all.reduction_stack.push_back(bs_setup);
768
769 all.l = setup_base(all);
770 }
771
add_to_args(vw & all,int argc,char * argv[])772 void add_to_args(vw& all, int argc, char* argv[])
773 {
774 for (int i = 1; i < argc; i++)
775 all.args.push_back(string(argv[i]));
776 }
777
// Builds a new vw object from a command line.
//
// argv[0] is taken as the program name; the remaining arguments are stored
// in all.args.  The function registers the general option groups, reads the
// header of the input regressor, appends the options recorded in
// all.file_options to the argument list, re-parses everything and then runs
// the individual parse_* stages below in a fixed order.
//
// Returns a reference to a vw allocated with new; VW::finish deletes it when
// called with delete_all set.  Exits the process after printing the option
// summary when --help was given.
vw& parse_args(int argc, char *argv[])
{
  vw& all = *(new vw());

  add_to_args(all, argc, argv);

  size_t random_seed = 0;
  all.program_name = argv[0];

  new_options(all, "VW options")
    ("random_seed", po::value<size_t>(&random_seed), "seed random number generator")
    ("ring_size", po::value<size_t>(&(all.p->ring_size)), "size of example ring");
  add_options(all);

  new_options(all, "Update options")
    ("learning_rate,l", po::value<float>(&(all.eta)), "Set learning rate")
    ("power_t", po::value<float>(&(all.power_t)), "t power value")
    ("decay_learning_rate",    po::value<float>(&(all.eta_decay_rate)),
     "Set Decay factor for learning_rate between passes")
    ("initial_t", po::value<double>(&((all.sd->t))), "initial t value")
    ("feature_mask", po::value< string >(), "Use existing regressor to determine which parameters may be updated.  If no initial_regressor given, also used for initial weights.");
  add_options(all);

  new_options(all, "Weight options")
    ("initial_regressor,i", po::value< vector<string> >(), "Initial regressor(s)")
    ("initial_weight", po::value<float>(&(all.initial_weight)), "Set all weights to an initial value of arg.")
    ("random_weights", po::value<bool>(&(all.random_weights)), "make initial weights random")
    ("input_feature_regularizer", po::value< string >(&(all.per_feature_regularizer_input)), "Per feature regularization input file");
  add_options(all);

  new_options(all, "Parallelization options")
    ("span_server", po::value<string>(&(all.span_server)), "Location of server for setting up spanning tree")
    ("unique_id", po::value<size_t>(&(all.unique_id)),"unique id used for cluster parallel jobs")
    ("total", po::value<size_t>(&(all.total)),"total number of nodes used in cluster parallel job")
    ("node", po::value<size_t>(&(all.node)),"node number in cluster parallel job");
  add_options(all);

  po::variables_map& vm = all.vm;
  // NOTE(review): presumably add_options has already stored --random_seed
  // into random_seed at this point -- confirm against add_options.
  msrand48(random_seed);
  parse_diagnostics(all, argc);

  all.sd->weighted_unlabeled_examples = all.sd->t;
  all.initial_t = (float)all.sd->t;

  //Input regressor header
  io_buf io_temp;
  parse_regressor_args(all, vm, io_temp);

  // Turn the options collected in all.file_options into an argv and append
  // them to all.args; the temporary argv is released right away.
  int temp_argc = 0;
  char** temp_argv = VW::get_argv_from_string(all.file_options->str(), temp_argc);
  add_to_args(all, temp_argc, temp_argv);
  for (int i = 0; i < temp_argc; i++)
    free(temp_argv[i]);
  free(temp_argv);

  // Re-parse the extended argument list.  Options that are not registered
  // yet are tolerated here (allow_unregistered).
  po::parsed_options pos = po::command_line_parser(all.args).
    style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
    options(all.opts).allow_unregistered().run();

  vm = po::variables_map();

  po::store(pos, vm);
  po::notify(vm);
  // Start with an empty record of file options for the stages below.
  all.file_options->str("");

  parse_feature_tweaks(all); //feature tweaks

  parse_example_tweaks(all); //example manipulation

  parse_output_model(all);

  parse_output_preds(all);

  parse_reductions(all);

  if (!all.quiet)
    {
      cerr << "Num weight bits = " << all.num_bits << endl;
      cerr << "learning rate = " << all.eta << endl;
      cerr << "initial_t = " << all.sd->t << endl;
      cerr << "power_t = " << all.power_t << endl;
      if (all.numpasses > 1)
        cerr << "decay_learning_rate = " << all.eta_decay_rate << endl;
    }

  load_input_model(all, io_temp);

  parse_source(all);

  enable_sources(all, all.quiet, all.numpasses);

  // force wpp to be a power of 2 to avoid 32-bit overflow
  // (i ends up as the smallest exponent with 2^i >= all.l->increment)
  uint32_t i = 0;
  size_t params_per_problem = all.l->increment;
  while (params_per_problem > (uint32_t)(1 << i))
    i++;
  all.wpp = (1 << i) >> all.reg.stride_shift;

  // --help is handled last, after every stage above has added its options
  // to all.opts.
  if (vm.count("help")) {
    /* upon direct query for help -- spit it out to stdout */
    cout << "\n" << all.opts << "\n";
    exit(0);
  }

  return all;
}
884
885 namespace VW {
// Sets the value of a flag inside the command string held by *ss.
//
// ss               stream holding a space separated command string
// flag_to_replace  the flag, e.g. "--foo"
// new_value        the value that should follow the flag
//
// If the flag is not present, " <flag> <value>" is appended.  Otherwise the
// text between the flag and the next space (or the end of the string) is
// replaced by new_value.  In both cases the stream is left positioned at its
// end, so later insertions append to the command string.
void cmd_string_replace_value( std::stringstream*& ss, std::string flag_to_replace, std::string new_value )
{
  flag_to_replace.append(" "); //add a space to make sure we obtain the right flag in case 2 flags start with the same set of characters
  std::string cmd = ss->str();
  size_t pos = cmd.find(flag_to_replace);
  if( pos == std::string::npos ) {
    //flag currently not present in command string, so just append it to command string
    *ss << " " << flag_to_replace << new_value;
    return;
  }

  //flag is present, need to replace old value with new value

  //compute position after flag_to_replace: this is where the value starts
  pos += flag_to_replace.size();

  //find position of next space
  size_t pos_after_value = cmd.find(" ",pos);
  if(pos_after_value == std::string::npos)
    //we reach the end of the string, so replace the all characters after pos by new_value
    cmd.replace(pos,cmd.size()-pos,new_value);
  else
    //replace characters between pos and pos_after_value by new_value
    cmd.replace(pos,pos_after_value-pos,new_value);

  ss->str(cmd);
  // str() puts the write position back at the start of the buffer; without
  // this seek the next "<<" would overwrite the beginning of the command
  // string instead of appending to it.
  ss->seekp(0, std::ios_base::end);
}
912
get_argv_from_string(string s,int & argc)913 char** get_argv_from_string(string s, int& argc)
914 {
915 char* c = calloc_or_die<char>(s.length()+3);
916 c[0] = 'b';
917 c[1] = ' ';
918 strcpy(c+2, s.c_str());
919 substring ss = {c, c+s.length()+2};
920 v_array<substring> foo = v_init<substring>();
921 tokenize(' ', ss, foo);
922
923 char** argv = calloc_or_die<char*>(foo.size());
924 for (size_t i = 0; i < foo.size(); i++)
925 {
926 *(foo[i].end) = '\0';
927 argv[i] = calloc_or_die<char>(foo[i].end-foo[i].begin+1);
928 sprintf(argv[i],"%s",foo[i].begin);
929 }
930
931 argc = (int)foo.size();
932 free(c);
933 foo.delete_v();
934 return argv;
935 }
936
initialize(string s)937 vw* initialize(string s)
938 {
939 int argc = 0;
940 s += " --no_stdin";
941 char** argv = get_argv_from_string(s,argc);
942
943 vw& all = parse_args(argc, argv);
944
945 initialize_parser_datastructures(all);
946
947 for(int i = 0; i < argc; i++)
948 free(argv[i]);
949 free(argv);
950
951 return &all;
952 }
953
delete_dictionary_entry(substring ss,v_array<feature> * A)954 void delete_dictionary_entry(substring ss, v_array<feature>*A) {
955 free(ss.begin);
956 A->delete_v();
957 delete A;
958 }
959
// Finalizes the regressor and releases the resources held by `all`.
//
// all         the vw object to shut down
// delete_all  when true, `all` itself is deleted at the end
void finish(vw& all, bool delete_all)
{
  finalize_regressor(all, all.final_regressor_name);
  // shut the learner down before releasing it
  all.l->finish();
  free_it(all.l);
  if (all.reg.weight_vector != NULL)
    free(all.reg.weight_vector);
  free_parser(all);
  finalize_source(all.p);
  all.p->parse_name.erase();
  all.p->parse_name.delete_v();
  free(all.p);
  free(all.sd);
  all.reduction_stack.delete_v();
  delete all.file_options;
  // close every prediction sink except descriptor 1 (stdout)
  for (size_t i = 0; i < all.final_prediction_sink.size(); i++)
    if (all.final_prediction_sink[i] != 1)
      io_buf::close_file_or_socket(all.final_prediction_sink[i]);
  all.final_prediction_sink.delete_v();
  // each dictionary that was read: its name, all of its entries, its table,
  // and finally the dictionary object itself
  for (size_t i=0; i<all.read_dictionaries.size(); i++) {
    free(all.read_dictionaries[i].name);
    all.read_dictionaries[i].dict->iter(delete_dictionary_entry);
    all.read_dictionaries[i].dict->delete_v();
    delete all.read_dictionaries[i].dict;
  }
  delete all.loss;
  if (delete_all) delete &all;
}
988 }
989