1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved.  Released under a BSD
4 license as described in the file LICENSE.
5  */
6 #pragma once
7 #include <vector>
8 #include <map>
9 #include <stdint.h>
10 #include <cstdio>
11 #include <boost/program_options.hpp>
12 namespace po = boost::program_options;
13 
14 #include "v_array.h"
15 #include "parse_primitives.h"
16 #include "loss_functions.h"
17 #include "comp_io.h"
18 #include "example.h"
19 #include "config.h"
20 #include "learner.h"
21 #include "allreduce.h"
22 #include "v_hashmap.h"
23 
// A semantic version triple (major.minor.rev) that can be parsed from
// and formatted to "M.m.r" strings.  All comparison operators are
// provided, each with an extra overload taking a raw C string.
struct version_struct {
  int major;
  int minor;
  int rev;
  version_struct(int maj, int min, int rv)
  {
    major = maj;
    minor = min;
    rev = rv;
  }
  // Parse "M.m.r"; any component missing from the string is 0.
  version_struct(const char* v_str)
  {
    from_string(v_str);
  }
  void operator=(version_struct v){
    major = v.major;
    minor = v.minor;
    rev = v.rev;
  }
  void operator=(const char* v_str){
    from_string(v_str);
  }
  bool operator==(version_struct v){
    return (major == v.major && minor == v.minor && rev == v.rev);
  }
  bool operator==(const char* v_str){
    version_struct v_tmp(v_str);
    return (*this == v_tmp);
  }
  bool operator!=(version_struct v){
    return !(*this == v);
  }
  bool operator!=(const char* v_str){
    version_struct v_tmp(v_str);
    return (*this != v_tmp);
  }
  bool operator>=(version_struct v){
    if(major < v.major) return false;
    if(major > v.major) return true;
    if(minor < v.minor) return false;
    if(minor > v.minor) return true;
    if(rev >= v.rev ) return true;
    return false;
  }
  bool operator>=(const char* v_str){
    version_struct v_tmp(v_str);
    return (*this >= v_tmp);
  }
  bool operator>(version_struct v){
    if(major < v.major) return false;
    if(major > v.major) return true;
    if(minor < v.minor) return false;
    if(minor > v.minor) return true;
    if(rev > v.rev ) return true;
    return false;
  }
  bool operator>(const char* v_str){
    version_struct v_tmp(v_str);
    return (*this > v_tmp);
  }
  bool operator<=(version_struct v){
    // BUG FIX: was !(*this < v), which evaluates to ">=", not "<=".
    // "this <= v" is equivalent to "not (this > v)".
    return !(*this > v);
  }
  bool operator<=(const char* v_str){
    version_struct v_tmp(v_str);
    return (*this <= v_tmp);
  }
  bool operator<(version_struct v){
    return !(*this >= v);
  }
  bool operator<(const char* v_str){
    version_struct v_tmp(v_str);
    return (*this < v_tmp);
  }
  // Render as "major.minor.rev".
  std::string to_string() const
  {
    char v_str[128];
    // snprintf bounds the write explicitly (three ints cannot overflow
    // 128 bytes, but be defensive rather than rely on that arithmetic).
    std::snprintf(v_str, sizeof(v_str), "%d.%d.%d", major, minor, rev);
    std::string s = v_str;
    return s;
  }
  // Parse "M.m.r".  Components are zeroed first so a partial parse
  // (e.g. "1.2") leaves well-defined values instead of garbage.
  void from_string(const char* str)
  {
    major = minor = rev = 0;
    std::sscanf(str,"%d.%d.%d",&major,&minor,&rev);
  }
};
110 
// The version of this build, parsed from PACKAGE_VERSION (config.h).
const version_struct version(PACKAGE_VERSION);

// Model weights are stored as single-precision floats.
typedef float weight;
114 
// The dense weight vector plus the mask/shift used to index into it.
struct regressor {
  weight* weight_vector;        // array of weights; indexed modulo weight_mask
  size_t weight_mask; // (stride*(1 << num_bits) -1)
  uint32_t stride_shift;        // log2 of the per-feature stride
};
120 
// A feature dictionary maps a token (substring) to a list of features.
typedef v_hashmap< substring, v_array<feature>* > feature_dict;
// A loaded dictionary together with the name it was registered under.
struct dictionary_info {
  char* name;          // dictionary identifier (owned elsewhere — TODO confirm ownership)
  feature_dict* dict;  // the token -> features map
};
126 
// Statistics shared across the learner stack: progressive-validation
// loss, example/feature counters, the observed label range, and the
// holdout-set bookkeeping used for early termination and best-model
// selection.  Field order is preserved; callers aggregate-initialize.
struct shared_data {
  size_t queries;

  uint64_t example_number;
  uint64_t total_features;

  double t;
  double weighted_examples;
  double weighted_unlabeled_examples;
  double old_weighted_examples;
  double weighted_labels;
  double sum_loss;
  double sum_loss_since_last_dump;
  float dump_interval;// when should I update for the user.
  double gravity;
  double contraction;
  float min_label;//minimum label encountered
  float max_label;//maximum label encountered

  //for holdout
  double weighted_holdout_examples;
  double weighted_holdout_examples_since_last_dump;
  double holdout_sum_loss_since_last_dump;
  double holdout_sum_loss;
  //for best model selection
  double holdout_best_loss;
  double weighted_holdout_examples_since_last_pass;//reserved for best predictor selection
  double holdout_sum_loss_since_last_pass;
  size_t holdout_best_pass;

  // Fold one example's loss and weight into either the holdout
  // counters (test_example == true) or the training counters.
  // Only training examples advance example_number/total_features.
  void update(bool test_example, float loss, float weight, size_t num_features)
  {
    if(test_example)
      {
        weighted_holdout_examples += weight;//test weight seen
        weighted_holdout_examples_since_last_dump += weight;
        weighted_holdout_examples_since_last_pass += weight;
        holdout_sum_loss += loss;
        holdout_sum_loss_since_last_dump += loss;
        holdout_sum_loss_since_last_pass += loss;//since last pass
      }
    else
      {
        weighted_examples += weight;
        sum_loss += loss;
        sum_loss_since_last_dump += loss;
        total_features += num_features;
        example_number++;
      }
  }

  // Reset the "since last dump" window and schedule the next progress
  // dump, either additively (+= arg) or multiplicatively (*= arg).
  inline void update_dump_interval(bool progress_add, float progress_arg) {
    sum_loss_since_last_dump = 0.0;
    old_weighted_examples = weighted_examples;
    if (progress_add)
      dump_interval = (float)weighted_examples + progress_arg;
    else
      dump_interval = (float)weighted_examples * progress_arg;
  }

  // Print one progress line to stderr.  When a holdout set is active
  // (and at least one full pass is done) the holdout losses are shown
  // with a trailing " h"; otherwise progressive training loss is shown.
  void print_update(bool holdout_set_off, size_t current_pass, char* label_buf, char* pred_buf,
                    size_t num_features, bool progress_add, float progress_arg)
  {
    if(!holdout_set_off && current_pass >= 1)
      {
        if(holdout_sum_loss == 0. && weighted_holdout_examples == 0.)
          fprintf(stderr, " unknown   ");
        else
          fprintf(stderr, "%-10.6f " , holdout_sum_loss/weighted_holdout_examples);

        if(holdout_sum_loss_since_last_dump == 0. && weighted_holdout_examples_since_last_dump == 0.)
          fprintf(stderr, " unknown   ");
        else
          fprintf(stderr, "%-10.6f " , holdout_sum_loss_since_last_dump/weighted_holdout_examples_since_last_dump);

        weighted_holdout_examples_since_last_dump = 0;
        holdout_sum_loss_since_last_dump = 0.0;

        // BUG FIX: cast num_features for %lu, matching the branch below.
        // Passing a raw size_t is undefined where size_t != unsigned long
        // (e.g. 64-bit Windows).
        fprintf(stderr, "%8ld %8.1f   %s %s %8lu h\n",
                (long int)example_number,
                weighted_examples,
                label_buf,
                pred_buf,
                (long unsigned int)num_features);
      }
    else
      fprintf(stderr, "%-10.6f %-10.6f %8ld %8.1f   %s %s %8lu\n",
              sum_loss/weighted_examples,
              sum_loss_since_last_dump / (weighted_examples - old_weighted_examples),
              (long int)example_number,
              weighted_examples,
              label_buf,
              pred_buf,
              (long unsigned int)num_features);
    fflush(stderr);
    update_dump_interval(progress_add, progress_arg);
  }
};
225 
// All mutable global state for one vowpal wabbit session: the learner
// stack, the parser thread, option-parsing machinery, feature-transform
// configuration, and run-time accounting.  Wired up during option parsing.
struct vw {
  shared_data* sd;

  parser* p;               // example parser; runs on parse_thread
#ifndef _WIN32
  pthread_t parse_thread;
#else
  HANDLE parse_thread;
#endif

  node_socks socks;        // network sockets — presumably for allreduce (see allreduce.h); confirm

  LEARNER::base_learner* l;//the top level learner
  LEARNER::base_learner* scorer;//a scoring function
  LEARNER::base_learner* cost_sensitive;//a cost sensitive learning algorithm.

  void learn(example*);    // run one example through the learner stack (defined elsewhere)

  void (*set_minmax)(shared_data* sd, float label); // label-range tracking hook

  size_t current_pass;

  uint32_t num_bits; // log_2 of the number of features.
  bool default_bits; // true if num_bits was left at its default

  string data_filename; // was vm["data"]

  bool daemon;             // run as a network daemon
  size_t num_children;

  bool save_per_pass;      // save the model after every pass
  float initial_weight;
  float initial_constant;

  bool bfgs;
  bool hessian_on;

  bool save_resume;        // save extra state so training can resume
  double normalized_sum_norm_x;

  // boost::program_options state accumulated while parsing arguments.
  po::options_description opts;
  po::options_description* new_opts;  // group being built by new_options()
  po::variables_map vm;
  std::stringstream* file_options;    // options echoed into the saved model
  vector<std::string> args;

  void* /*Search::search*/ searchstr;

  uint32_t wpp;            // NOTE(review): presumably "weights per problem" — confirm

  int stdout_fileno;

  std::string per_feature_regularizer_input;
  std::string per_feature_regularizer_output;
  std::string per_feature_regularizer_text;

  float l1_lambda; //the level of l_1 regularization to impose.
  float l2_lambda; //the level of l_2 regularization to impose.
  float power_t;//the power on learning rate decay.
  int reg_mode;

  size_t pass_length;      // examples per pass
  size_t numpasses;
  size_t passes_complete;
  size_t parse_mask; // 1 << num_bits -1
  std::vector<std::string> pairs; // pairs of features to cross.
  std::vector<std::string> triples; // triples of features to cross.
  bool ignore_some;
  bool ignore[256];//a set of namespaces to ignore

  std::vector<std::string> ngram_strings; // pairs of features to cross.
  std::vector<std::string> skip_strings; // triples of features to cross.
  uint32_t ngram[256];//ngrams to generate.
  uint32_t skips[256];//skips in ngrams.
  std::vector<std::string> limit_strings; // descriptor of feature limits
  uint32_t limit[256];//count to limit features by
  uint32_t affix_features[256]; // affixes to generate (up to 8 per namespace)
  bool     spelling_features[256]; // generate spelling features for which namespace
  vector<feature_dict*> namespace_dictionaries[256]; // each namespace has a list of dictionaries attached to it
  vector<dictionary_info> read_dictionaries; // which dictionaries have we read?

  bool audit;//should I print lots of debugging information?
  bool quiet;//Should I suppress progress-printing of updates?
  bool training;//Should I train if label data is available?
  bool active;
  bool adaptive;//Should I use adaptive individual learning rates?
  bool normalized_updates; //Should every feature be normalized
  bool invariant_updates; //Should we use importance aware/safe updates
  bool random_weights;
  bool random_positive_weights; // for initialize_regressor w/ new_mf
  bool add_constant;
  bool nonormalize;
  bool do_reset_source;
  bool holdout_set_off;
  bool early_terminate;
  uint32_t holdout_period;
  uint32_t holdout_after;
  size_t check_holdout_every_n_passes;  // default: 1, but search might want to set it higher if you spend multiple passes learning a single policy

  size_t normalized_idx; //offset idx where the norm is stored (1 or 2 depending on whether adaptive is true)

  uint32_t lda;

  std::string text_regressor_name;
  std::string inv_hash_regressor_name;
  std::string span_server;

  // Number of distinct feature indices: 2^num_bits.
  size_t length () { return ((size_t)1) << num_bits; };

  v_array<LEARNER::base_learner* (*)(vw&)> reduction_stack;

  //Prediction output
  v_array<int> final_prediction_sink; // set to send global predictions to.
  int raw_prediction; // file descriptors for text output.
  size_t unique_id; //unique id for each node in the network, id == 0 means extra io.
  size_t total; //total number of nodes
  size_t node; //node id number

  void (*print)(int,float,float,v_array<char>);        // prediction printer
  void (*print_text)(int, string, v_array<char>);      // text-output printer
  loss_function* loss;

  char* program_name;

  bool stdin_off;

  //runtime accounting variables.
  float initial_t;
  float eta;//learning rate control.
  float eta_decay_rate;

  std::string final_regressor_name;
  regressor reg;

  size_t max_examples; // for TLC

  bool hash_inv;
  bool print_invert;

  // Set by --progress <arg>
  bool  progress_add;   // additive (rather than multiplicative) progress dumps
  float progress_arg;   // next update progress dump multiplier

  std::map< std::string, size_t> name_index_map;

  vw();
};
373 
// Free functions defined elsewhere; signatures only.  Descriptions are
// inferred from the names — verify against the implementations.
void print_result(int f, float res, float weight, v_array<char> tag);          // write a float prediction to fd f
void binary_print_result(int f, float res, float weight, v_array<char> tag);   // write a binary (+/-1) prediction to fd f
void noop_mm(shared_data*, float label);                                       // no-op min/max label tracker (see vw::set_minmax)
void print_lda_result(vw& all, int f, float* res, float weight, v_array<char> tag);
void get_prediction(int sock, float& res, float& weight);                      // read a prediction back from a socket
void compile_gram(vector<string> grams, uint32_t* dest, char* descriptor, bool quiet); // parse --ngram/--skips strings into per-namespace counts
void compile_limits(vector<string> limits, uint32_t* dest, bool quiet);        // parse --feature_limit strings
int print_tag(std::stringstream& ss, v_array<char> tag);
void add_options(vw& all, po::options_description& opts);
383 inline po::options_description_easy_init new_options(vw& all, std::string name = "\0")
384 {
385   all.new_opts = new po::options_description(name);
386   return all.new_opts->add_options();
387 }
// Returns true if the most recent new_options() group added nothing new
// to the parsed command line (defined elsewhere).
bool no_new_options(vw& all);
// Non-template flag variant (defined elsewhere).
bool missing_option(vw& all, bool keep, const char* name, const char* description);
// Declare option `name` of type T, then report whether it was absent
// from the command line (true == missing).
template <class T> bool missing_option(vw& all, const char* name, const char* description)
{
  new_options(all)(name, po::value<T>(), description);
  return no_new_options(all);
}
missing_option(vw & all,const char * name,const char * description)395 template <class T, bool keep> bool missing_option(vw& all, const char* name,
396 						  const char* description)
397 {
398   if (missing_option<T>(all, name, description))
399     return true;
400   if (keep)
401     *all.file_options << " --" << name << " " << all.vm[name].as<T>();
402   return false;
403 }
404 void add_options(vw& all);
405