1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD
4 license as described in the file LICENSE.
5 */
6 #pragma once
7 #include <vector>
8 #include <map>
9 #include <stdint.h>
10 #include <cstdio>
11 #include <boost/program_options.hpp>
12 namespace po = boost::program_options;
13
14 #include "v_array.h"
15 #include "parse_primitives.h"
16 #include "loss_functions.h"
17 #include "comp_io.h"
18 #include "example.h"
19 #include "config.h"
20 #include "learner.h"
21 #include "allreduce.h"
22 #include "v_hashmap.h"
23
// Semantic version triple ("major.minor.rev"), used to stamp binaries and
// model files and to compare them on load.
struct version_struct {
  int major;
  int minor;
  int rev;
  // Construct from explicit components.
  version_struct(int maj, int min, int rv)
  {
    major = maj;
    minor = min;
    rev = rv;
  }
  // Construct by parsing a "major.minor.rev" string (see from_string).
  version_struct(const char* v_str)
  {
    from_string(v_str);
  }
  void operator=(const version_struct& v){
    major = v.major;
    minor = v.minor;
    rev = v.rev;
  }
  void operator=(const char* v_str){
    from_string(v_str);
  }
  bool operator==(const version_struct& v) const {
    return (major == v.major && minor == v.minor && rev == v.rev);
  }
  bool operator==(const char* v_str) const {
    version_struct v_tmp(v_str);
    return (*this == v_tmp);
  }
  bool operator!=(const version_struct& v) const {
    return !(*this == v);
  }
  bool operator!=(const char* v_str) const {
    version_struct v_tmp(v_str);
    return (*this != v_tmp);
  }
  // Lexicographic comparison: major, then minor, then rev.
  bool operator>=(const version_struct& v) const {
    if(major < v.major) return false;
    if(major > v.major) return true;
    if(minor < v.minor) return false;
    if(minor > v.minor) return true;
    if(rev >= v.rev ) return true;
    return false;
  }
  bool operator>=(const char* v_str) const {
    version_struct v_tmp(v_str);
    return (*this >= v_tmp);
  }
  bool operator>(const version_struct& v) const {
    if(major < v.major) return false;
    if(major > v.major) return true;
    if(minor < v.minor) return false;
    if(minor > v.minor) return true;
    if(rev > v.rev ) return true;
    return false;
  }
  bool operator>(const char* v_str) const {
    version_struct v_tmp(v_str);
    return (*this > v_tmp);
  }
  bool operator<=(const version_struct& v) const {
    // BUG FIX: was !(*this < v), which simplifies to >= and made e.g.
    // "1.0.0 <= 2.0.0" evaluate to false.  a <= b is the negation of a > b.
    return !(*this > v);
  }
  bool operator<=(const char* v_str) const {
    version_struct v_tmp(v_str);
    return (*this <= v_tmp);
  }
  bool operator<(const version_struct& v) const {
    return !(*this >= v);
  }
  bool operator<(const char* v_str) const {
    version_struct v_tmp(v_str);
    return (*this < v_tmp);
  }
  // Render as "major.minor.rev".
  std::string to_string() const
  {
    char v_str[128];
    // snprintf (bounded) instead of sprintf; "%d.%d.%d" of three ints cannot
    // come close to 128 bytes, but don't rely on that.
    std::snprintf(v_str,sizeof(v_str),"%d.%d.%d",major,minor,rev);
    std::string s = v_str;
    return s;
  }
  // Parse "major.minor.rev".  Missing trailing components are left at 0;
  // previously a malformed string left the fields uninitialized.
  void from_string(const char* str)
  {
    major = minor = rev = 0;
    std::sscanf(str,"%d.%d.%d",&major,&minor,&rev);
  }
};
110
// Version of this build, parsed from the PACKAGE_VERSION macro (config.h).
// const at namespace scope has internal linkage, so each TU gets its own copy.
const version_struct version(PACKAGE_VERSION);
112
// A single model parameter; the model is a flat array of these.
typedef float weight;

// Flat storage for the learned model weights.
struct regressor {
  weight* weight_vector;  // array of stride*(1 << num_bits) weights
  size_t weight_mask; // (stride*(1 << num_bits) -1)
  uint32_t stride_shift;  // log_2 of the per-feature stride — TODO confirm against parse_regressor.cc
};
120
// Maps a token (substring) to the array of features it expands to.
typedef v_hashmap< substring, v_array<feature>* > feature_dict;

// A dictionary that has been read in, identified by name (see
// vw::read_dictionaries / vw::namespace_dictionaries below).
struct dictionary_info {
  char* name;         // identifier the dictionary was registered under
  feature_dict* dict; // the token -> features mapping itself
};
126
// Running statistics shared across the learner stack: loss accumulators,
// example/feature counters, label range, and the progress-printing state.
struct shared_data {
  size_t queries;  // number of label queries issued (active learning)

  uint64_t example_number; // count of training (non-holdout) examples seen
  uint64_t total_features; // running sum of feature counts over training examples

  double t;
  double weighted_examples;            // importance-weighted training example count
  double weighted_unlabeled_examples;
  double old_weighted_examples;        // weighted_examples at the last progress dump
  double weighted_labels;
  double sum_loss;                     // cumulative training loss
  double sum_loss_since_last_dump;     // training loss since the last progress dump
  float dump_interval;// when should I update for the user.
  double gravity;
  double contraction;
  float min_label;//minimum label encountered
  float max_label;//maximum label encountered

  //for holdout
  double weighted_holdout_examples;
  double weighted_holdout_examples_since_last_dump;
  double holdout_sum_loss_since_last_dump;
  double holdout_sum_loss;
  //for best model selection
  double holdout_best_loss;
  double weighted_holdout_examples_since_last_pass;//reserved for best predictor selection
  double holdout_sum_loss_since_last_pass;
  size_t holdout_best_pass;

  // Fold one example's loss into the running statistics.  Holdout (test)
  // examples touch only the holdout accumulators; training examples update
  // the training accumulators plus the example and feature counters.
  void update(bool test_example, float loss, float weight, size_t num_features)
  {
    if(test_example)
    {
      weighted_holdout_examples += weight;//test weight seen
      weighted_holdout_examples_since_last_dump += weight;
      weighted_holdout_examples_since_last_pass += weight;
      holdout_sum_loss += loss;
      holdout_sum_loss_since_last_dump += loss;
      holdout_sum_loss_since_last_pass += loss;//since last pass
    }
    else
    {
      weighted_examples += weight;
      sum_loss += loss;
      sum_loss_since_last_dump += loss;
      total_features += num_features;
      example_number++;
    }
  }

  // Reset the per-dump loss accumulator and schedule the next progress dump:
  // either additively (current + arg) or multiplicatively (current * arg),
  // matching the --progress argument.
  inline void update_dump_interval(bool progress_add, float progress_arg) {
    sum_loss_since_last_dump = 0.0;
    old_weighted_examples = weighted_examples;
    if (progress_add)
      dump_interval = (float)weighted_examples + progress_arg;
    else
      dump_interval = (float)weighted_examples * progress_arg;
  }

  // Print one progress line to stderr.  After the first pass with a holdout
  // set, holdout losses are reported (suffixed "h"); otherwise training
  // losses are shown.  Resets the per-dump accumulators when done.
  void print_update(bool holdout_set_off, size_t current_pass, char* label_buf, char* pred_buf,
                    size_t num_features, bool progress_add, float progress_arg)
  {
    if(!holdout_set_off && current_pass >= 1)
    {
      if(holdout_sum_loss == 0. && weighted_holdout_examples == 0.)
        fprintf(stderr, " unknown   ");
      else
        fprintf(stderr, "%-10.6f " , holdout_sum_loss/weighted_holdout_examples);

      if(holdout_sum_loss_since_last_dump == 0. && weighted_holdout_examples_since_last_dump == 0.)
        fprintf(stderr, " unknown   ");
      else
        fprintf(stderr, "%-10.6f " , holdout_sum_loss_since_last_dump/weighted_holdout_examples_since_last_dump);

      weighted_holdout_examples_since_last_dump = 0;
      holdout_sum_loss_since_last_dump = 0.0;

      // BUG FIX: num_features (size_t) must be cast for %lu, exactly as the
      // else-branch below already does — size_t is not unsigned long on all
      // platforms (e.g. 64-bit Windows), making the uncast call UB.
      fprintf(stderr, "%8ld %8.1f   %s %s %8lu h\n",
              (long int)example_number,
              weighted_examples,
              label_buf,
              pred_buf,
              (long unsigned int)num_features);
    }
    else
      fprintf(stderr, "%-10.6f %-10.6f %8ld %8.1f   %s %s %8lu\n",
              sum_loss/weighted_examples,
              sum_loss_since_last_dump / (weighted_examples - old_weighted_examples),
              (long int)example_number,
              weighted_examples,
              label_buf,
              pred_buf,
              (long unsigned int)num_features);
    fflush(stderr);
    update_dump_interval(progress_add, progress_arg);
  }
};
225
// The top-level driver state: one vw object holds everything about a single
// learning session — parser, learner stack, options, model, and statistics.
struct vw {
  shared_data* sd;  // global statistics shared across the learner stack

  parser* p;        // example parser state
#ifndef _WIN32
  pthread_t parse_thread;
#else
  HANDLE parse_thread;
#endif

  node_socks socks; // sockets used by allreduce (see allreduce.h)

  LEARNER::base_learner* l;//the top level learner
  LEARNER::base_learner* scorer;//a scoring function
  LEARNER::base_learner* cost_sensitive;//a cost sensitive learning algorithm.

  void learn(example*);  // run one example through the learner stack

  // Hook that updates sd->min_label / sd->max_label from an observed label
  // (noop_mm when range tracking is disabled).
  void (*set_minmax)(shared_data* sd, float label);

  size_t current_pass;   // index of the pass over the data now in progress

  uint32_t num_bits; // log_2 of the number of features.
  bool default_bits; // true if num_bits was left at its built-in default

  string data_filename; // was vm["data"]

  bool daemon;           // run as a network daemon serving predictions?
  size_t num_children;   // number of child processes in daemon mode

  bool save_per_pass;    // write the regressor out after every pass?
  float initial_weight;  // value every weight starts at
  float initial_constant;// initial value for the constant feature's weight

  bool bfgs;             // use L-BFGS optimization instead of SGD?
  bool hessian_on;

  bool save_resume;      // save extra state so training can resume later
  double normalized_sum_norm_x;

  po::options_description opts;       // all registered command-line options
  po::options_description* new_opts;  // options added by the latest new_options() call
  po::variables_map vm;               // parsed command-line values
  std::stringstream* file_options;    // option text to embed in the saved model
  vector<std::string> args;           // the raw command-line arguments

  void* /*Search::search*/ searchstr;

  uint32_t wpp;  // weights per problem — TODO confirm exact meaning vs. stride

  int stdout_fileno;  // file descriptor used in place of stdout

  std::string per_feature_regularizer_input;
  std::string per_feature_regularizer_output;
  std::string per_feature_regularizer_text;

  float l1_lambda; //the level of l_1 regularization to impose.
  float l2_lambda; //the level of l_2 regularization to impose.
  float power_t;//the power on learning rate decay.
  int reg_mode;

  size_t pass_length;       // examples per pass (may be less than the dataset)
  size_t numpasses;         // total number of passes requested
  size_t passes_complete;   // passes finished so far
  size_t parse_mask; // 1 << num_bits -1
  std::vector<std::string> pairs; // pairs of features to cross.
  std::vector<std::string> triples; // triples of features to cross.
  bool ignore_some;       // is at least one namespace being ignored?
  bool ignore[256];//a set of namespaces to ignore

  std::vector<std::string> ngram_strings; // textual per-namespace ngram descriptors
  std::vector<std::string> skip_strings;  // textual per-namespace skip descriptors
  uint32_t ngram[256];//ngrams to generate.
  uint32_t skips[256];//skips in ngrams.
  std::vector<std::string> limit_strings; // descriptor of feature limits
  uint32_t limit[256];//count to limit features by
  uint32_t affix_features[256]; // affixes to generate (up to 8 per namespace)
  bool spelling_features[256]; // generate spelling features for which namespace
  vector<feature_dict*> namespace_dictionaries[256]; // each namespace has a list of dictionaries attached to it
  vector<dictionary_info> read_dictionaries; // which dictionaries have we read?

  bool audit;//should I print lots of debugging information?
  bool quiet;//Should I suppress progress-printing of updates?
  bool training;//Should I train if label data is available?
  bool active;
  bool adaptive;//Should I use adaptive individual learning rates?
  bool normalized_updates; //Should every feature be normalized
  bool invariant_updates; //Should we use importance aware/safe updates
  bool random_weights;
  bool random_positive_weights; // for initialize_regressor w/ new_mf
  bool add_constant;   // add the implicit constant feature to every example?
  bool nonormalize;
  bool do_reset_source;
  bool holdout_set_off;    // true disables the holdout set entirely
  bool early_terminate;    // stop when holdout loss stops improving
  uint32_t holdout_period; // every holdout_period-th example goes to holdout
  uint32_t holdout_after;  // examples after this index go to holdout
  size_t check_holdout_every_n_passes;  // default: 1, but search might want to set it higher if you spend multiple passes learning a single policy

  size_t normalized_idx; //offset idx where the norm is stored (1 or 2 depending on whether adaptive is true)

  uint32_t lda;  // LDA topic count — presumably 0 means LDA is off; confirm

  std::string text_regressor_name;     // --readable_model output path
  std::string inv_hash_regressor_name; // --invert_hash output path
  std::string span_server;             // spanning-tree server for allreduce

  // Number of weight slots implied by num_bits (before applying the stride).
  size_t length () { return ((size_t)1) << num_bits; };

  // Factory functions that build the reduction stack, outermost first.
  v_array<LEARNER::base_learner* (*)(vw&)> reduction_stack;

  //Prediction output
  v_array<int> final_prediction_sink; // set to send global predictions to.
  int raw_prediction; // file descriptors for text output.
  size_t unique_id; //unique id for each node in the network, id == 0 means extra io.
  size_t total; //total number of nodes
  size_t node; //node id number

  void (*print)(int,float,float,v_array<char>);      // scalar prediction printer
  void (*print_text)(int, string, v_array<char>);    // text prediction printer
  loss_function* loss;  // the loss function being optimized

  char* program_name;

  bool stdin_off;  // true once stdin input has been disabled

  //runtime accounting variables.
  float initial_t;
  float eta;//learning rate control.
  float eta_decay_rate;

  std::string final_regressor_name;  // where to save the final model
  regressor reg;                     // the model weights themselves

  size_t max_examples; // for TLC

  bool hash_inv;       // keep the inverse hash (feature name) mapping?
  bool print_invert;

  // Set by --progress <arg>
  bool progress_add; // additive (rather than multiplicative) progress dumps
  float progress_arg; // next update progress dump multiplier

  // Feature-name -> index map — presumably used for inverted hashing; confirm.
  std::map< std::string, size_t> name_index_map;

  vw();  // sets all of the above to their defaults (defined in global_data.cc)
};
373
// Write a scalar prediction (plus the example's tag) to file descriptor f.
void print_result(int f, float res, float weight, v_array<char> tag);
// Same output channel, but for binary (thresholded) predictions —
// NOTE(review): inferred from the name; confirm in the implementation.
void binary_print_result(int f, float res, float weight, v_array<char> tag);
// Do-nothing min/max label updater; plugs into vw::set_minmax.
void noop_mm(shared_data*, float label);
// Write an LDA prediction (one value per topic in res) to file descriptor f.
void print_lda_result(vw& all, int f, float* res, float weight, v_array<char> tag);
// Read a prediction and its weight back from a network socket.
void get_prediction(int sock, float& res, float& weight);
// Parse textual n-gram/skip descriptors into the per-namespace table `dest`
// (vw::ngram / vw::skips); `descriptor` names the option for error messages.
void compile_gram(vector<string> grams, uint32_t* dest, char* descriptor, bool quiet);
// Parse textual feature-limit descriptors into the per-namespace table `dest`.
void compile_limits(vector<string> limits, uint32_t* dest, bool quiet);
// Append the example tag to ss; return value — TODO confirm semantics
// (likely the number of characters written) in the implementation.
int print_tag(std::stringstream& ss, v_array<char> tag);
// Merge a fully-built options description into all.opts.
void add_options(vw& all, po::options_description& opts);
383 inline po::options_description_easy_init new_options(vw& all, std::string name = "\0")
384 {
385 all.new_opts = new po::options_description(name);
386 return all.new_opts->add_options();
387 }
// Returns true when the options registered by the most recent new_options()
// call were not supplied on the command line (see missing_option below).
bool no_new_options(vw& all);
// Registers option `name` and reports whether it is absent from the command
// line; `keep` asks for the option to be persisted into the model file.
bool missing_option(vw& all, bool keep, const char* name, const char* description);
missing_option(vw & all,const char * name,const char * description)390 template <class T> bool missing_option(vw& all, const char* name, const char* description)
391 {
392 new_options(all)(name, po::value<T>(), description);
393 return no_new_options(all);
394 }
missing_option(vw & all,const char * name,const char * description)395 template <class T, bool keep> bool missing_option(vw& all, const char* name,
396 const char* description)
397 {
398 if (missing_option<T>(all, name, description))
399 return true;
400 if (keep)
401 *all.file_options << " --" << name << " " << all.vm[name].as<T>();
402 return false;
403 }
// Merge the most recently created option group (all.new_opts) into all.opts —
// confirm exact behavior in the implementation file.
void add_options(vw& all);
405