1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved.  Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
6 #include <stdio.h>
7 #include <float.h>
8 #include <iostream>
9 #include <sstream>
10 #include <math.h>
11 #include <assert.h>
12 
13 #include "global_data.h"
14 #include "gd.h"
15 
16 using namespace std;
17 
// One prediction record as exchanged over a socket by send_prediction() and
// get_prediction(). The struct's raw bytes are written and read as-is (no
// serialization or byte-order conversion), so field order and types form the
// wire format and must not change.
struct global_prediction {
  float p;       // the prediction value (res in the print/get helpers)
  float weight;  // the weight passed along with the prediction
};
22 
// Read exactly `count` bytes from `sock` into `in`, looping over short reads.
// Returns `count` on success, or 0 if the stream reaches end-of-file before all
// bytes arrived (any bytes already received are left in `in`).
// On a read error the failure is logged to cerr and std::exception is thrown.
size_t really_read(int sock, void* in, size_t count)
{
  char* buf = static_cast<char*>(in);
  size_t done = 0;
  // recv() takes an int length, and the old code cast the remainder to
  // unsigned int. Clamp each request so it always fits: the old cast could wrap
  // a huge remainder to 0, which read() answers with 0, indistinguishable from EOF.
  const size_t max_chunk = static_cast<size_t>(INT_MAX);
  while (done < count)
    {
      size_t chunk = count - done;
      if (chunk > max_chunk)
        chunk = max_chunk;
#ifdef _WIN32
      int r = recv(sock, buf, static_cast<int>(chunk), 0);
#else
      ssize_t r = read(sock, buf, chunk);
#endif
      if (r == 0)
        return 0;
      if (r < 0)
        {
#ifndef _WIN32
          // Interrupted by a signal before any data was transferred: retry.
          if (errno == EINTR)
            continue;
#endif
          std::cerr << "read(" << sock << "," << count << "-" << done << "): " << strerror(errno) << std::endl;
          throw std::exception();
        }
      done += static_cast<size_t>(r);
      buf += r;
    }
  return done;
}
52 
get_prediction(int sock,float & res,float & weight)53 void get_prediction(int sock, float& res, float& weight)
54 {
55   global_prediction p;
56   really_read(sock, &p, sizeof(p));
57   res = p.p;
58   weight = p.weight;
59 }
60 
// Write the raw bytes of prediction record `p` to `sock` in a single call.
// A failed or short write is logged to cerr and reported by throwing
// std::exception.
void send_prediction(int sock, global_prediction p)
{
#ifdef _WIN32
  int written = send(sock, reinterpret_cast<const char*>(&p), sizeof(p), 0);
#else
  ssize_t written = write(sock, &p, sizeof(p));
#endif
  const bool short_write = written < (int)sizeof(p);
  if (!short_write)
    return;

  std::cerr << "send_prediction write(" << sock << "): " << strerror(errno) << std::endl;
  throw std::exception();
}
75 
binary_print_result(int f,float res,float weight,v_array<char> tag)76 void binary_print_result(int f, float res, float weight, v_array<char> tag)
77 {
78   if (f >= 0)
79     {
80       global_prediction ps = {res, weight};
81       send_prediction(f, ps);
82     }
83 }
84 
print_tag(std::stringstream & ss,v_array<char> tag)85 int print_tag(std::stringstream& ss, v_array<char> tag)
86 {
87   if (tag.begin != tag.end){
88     ss << ' ';
89     ss.write(tag.begin, sizeof(char)*tag.size());
90   }
91   return tag.begin != tag.end;
92 }
93 
print_result(int f,float res,float weight,v_array<char> tag)94 void print_result(int f, float res, float weight, v_array<char> tag)
95 {
96   if (f >= 0)
97     {
98       char temp[30];
99       sprintf(temp, "%f", res);
100       std::stringstream ss;
101       ss << temp;
102       print_tag(ss, tag);
103       ss << '\n';
104       ssize_t len = ss.str().size();
105       ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
106       if (t != len)
107         {
108           cerr << "write error" << endl;
109         }
110     }
111 }
112 
print_raw_text(int f,string s,v_array<char> tag)113 void print_raw_text(int f, string s, v_array<char> tag)
114 {
115   if (f < 0)
116     return;
117 
118   std::stringstream ss;
119   ss << s;
120   print_tag (ss, tag);
121   ss << '\n';
122   ssize_t len = ss.str().size();
123   ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
124   if (t != len)
125     {
126       cerr << "write error" << endl;
127     }
128 }
129 
print_lda_result(vw & all,int f,float * res,float weight,v_array<char> tag)130 void print_lda_result(vw& all, int f, float* res, float weight, v_array<char> tag)
131 {
132   if (f >= 0)
133     {
134       std::stringstream ss;
135       char temp[30];
136       for (size_t k = 0; k < all.lda; k++)
137 	{
138 	  sprintf(temp, "%f ", res[k]);
139           ss << temp;
140 	}
141       print_tag(ss, tag);
142       ss << '\n';
143       ssize_t len = ss.str().size();
144       ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
145 
146       if (t != len)
147 	cerr << "write error" << endl;
148     }
149 }
150 
set_mm(shared_data * sd,float label)151 void set_mm(shared_data* sd, float label)
152 {
153   sd->min_label = min(sd->min_label, label);
154   if (label != FLT_MAX)
155     sd->max_label = max(sd->max_label, label);
156 }
157 
noop_mm(shared_data * sd,float label)158 void noop_mm(shared_data* sd, float label)
159 {}
160 
learn(example * ec)161 void vw::learn(example* ec)
162 {
163   this->l->learn(*ec);
164 }
165 
// Parse n-gram/skip specifications into the 256-entry per-namespace table `dest`.
// Each entry of `grams` is either
//   "<n>"      -> set dest[ns] = n for every namespace byte ns, or
//   "<ns><n>"  -> set dest[ns] = n for the namespace whose first byte is ns.
// A spec with no number after the namespace (or an empty spec) is reported on
// cerr and skipped. `descriptor` is only used in log messages; `quiet`
// suppresses the informational ones.
void compile_gram(std::vector<std::string> grams, uint32_t* dest, char* descriptor, bool quiet)
{
  for (size_t i = 0; i < grams.size(); i++)
    {
      const std::string& ngram = grams[i];
      // Go through unsigned char: isdigit() is undefined for negative values,
      // and a signed char >= 0x80 cast straight to uint32_t becomes a huge
      // index that writes far outside the 256-entry table.
      if (!ngram.empty() && isdigit(static_cast<unsigned char>(ngram[0])))
        {
          int n = atoi(ngram.c_str());
          if (!quiet)
            std::cerr << "Generating " << n << "-" << descriptor << " for all namespaces." << std::endl;
          for (size_t j = 0; j < 256; j++)
            dest[j] = n;
        }
      else if (ngram.size() <= 1)
        // Also catches "", which previously fell through to c_str()+1 (past the terminator).
        // Diagnostic goes to cerr like every other message, keeping stdout clean.
        std::cerr << "You must specify the namespace index before the n" << std::endl;
      else
        {
          const unsigned char ns = static_cast<unsigned char>(ngram[0]);
          int n = atoi(ngram.c_str() + 1);
          dest[ns] = n;
          if (!quiet)
            std::cerr << "Generating " << n << "-" << descriptor << " for " << ngram[0] << " namespaces." << std::endl;
        }
    }
}
189 
// Parse per-namespace feature limits into the 256-entry table `dest`.
// Each entry of `limits` is either
//   "<n>"      -> set dest[ns] = n for every namespace byte ns, or
//   "<ns><n>"  -> set dest[ns] = n for the namespace whose first byte is ns.
// A spec with no number after the namespace (or an empty spec) is reported on
// cerr and skipped. `quiet` suppresses the informational messages.
void compile_limits(std::vector<std::string> limits, uint32_t* dest, bool quiet)
{
  for (size_t i = 0; i < limits.size(); i++)
    {
      const std::string& limit = limits[i];
      // Go through unsigned char: isdigit() is undefined for negative values,
      // and a signed char >= 0x80 cast straight to uint32_t becomes a huge
      // index that writes far outside the 256-entry table.
      if (!limit.empty() && isdigit(static_cast<unsigned char>(limit[0])))
        {
          int n = atoi(limit.c_str());
          if (!quiet)
            std::cerr << "limiting to " << n << " features for each namespace." << std::endl;
          for (size_t j = 0; j < 256; j++)
            dest[j] = n;
        }
      else if (limit.size() <= 1)
        // Also catches "", which previously fell through to c_str()+1 (past the terminator).
        // Diagnostic goes to cerr like every other message, keeping stdout clean.
        std::cerr << "You must specify the namespace index before the n" << std::endl;
      else
        {
          const unsigned char ns = static_cast<unsigned char>(limit[0]);
          int n = atoi(limit.c_str() + 1);
          dest[ns] = n;
          if (!quiet)
            std::cerr << "limiting to " << n << " for namespaces " << limit[0] << std::endl;
        }
    }
}
213 
add_options(vw & all,po::options_description & opts)214 void add_options(vw& all, po::options_description& opts)
215 {
216   all.opts.add(opts);
217   po::variables_map new_vm;
218   //parse local opts once for notifications.
219   po::parsed_options parsed = po::command_line_parser(all.args).
220     style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
221     options(opts).allow_unregistered().run();
222   po::store(parsed, new_vm);
223   po::notify(new_vm);
224 
225   for (po::variables_map::iterator it=new_vm.begin(); it!=new_vm.end(); ++it)
226     all.vm.insert(*it);
227 }
228 
add_options(vw & all)229 void add_options(vw& all)
230 {
231   add_options(all, *all.new_opts);
232   delete all.new_opts;
233 }
234 
no_new_options(vw & all)235 bool no_new_options(vw& all)
236 {
237   //parse local opts once for notifications.
238   po::parsed_options parsed = po::command_line_parser(all.args).
239     style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
240     options(*all.new_opts).allow_unregistered().run();
241   po::variables_map new_vm;
242   po::store(parsed, new_vm);
243   all.opts.add(*all.new_opts);
244   delete all.new_opts;
245   for (po::variables_map::iterator it=new_vm.begin(); it!=new_vm.end(); ++it)
246     all.vm.insert(*it);
247 
248   if (new_vm.size() == 0) // required are missing;
249     return true;
250   else
251     return false;
252 }
253 
missing_option(vw & all,bool keep,const char * name,const char * description)254 bool missing_option(vw& all, bool keep, const char* name, const char* description)
255 {
256   new_options(all)(name,description);
257   if (no_new_options(all))
258     return true;
259   if (keep)
260     *all.file_options << " --" << name;
261   return false;
262 }
263 
// Construct a vw instance with its built-in defaults. The shared_data block
// and the parser are allocated here; the remaining fields are plain default
// assignments. Many of these are presumably overridden later during argument
// parsing (the eta comment below refers to parse_args.cc); confirm per field.
vw::vw()
{
  // Shared statistics block; calloc_or_die presumably returns zero-filled memory (verify).
  sd = &calloc_or_die<shared_data>();
  sd->dump_interval = 1.;   // next update progress dump
  sd->contraction = 1.;
  sd->max_label = 1.;
  sd->min_label = 0.;

  // Input parser, defaulting to the simple label type.
  p = new_parser();
  p->emptylines_separate_examples = false;
  p->lp = simple_label;

  reg_mode = 0;
  current_pass = 0;
  reduction_stack=v_init<LEARNER::base_learner* (*)(vw&)>();

  data_filename = "";

  // Heap-allocated stream; missing_option() appends " --<name>" to it.
  file_options = new std::stringstream;

  bfgs = false;
  hessian_on = false;
  active = false;
  reg.stride_shift = 0;
  num_bits = 18;
  default_bits = true;
  daemon = false;
  num_children = 10;
  span_server = "";
  save_resume = false;

  random_positive_weights = false;

  // Label-range hook; set_mm() widens sd->min_label / sd->max_label.
  set_minmax = set_mm;

  power_t = 0.5;
  eta = 0.5; //default learning rate for normalized adaptive updates, this is switched to 10 by default for the other updates (see parse_args.cc)
  numpasses = 1;

  // Output sinks: prediction list starts empty, raw predictions disabled (-1),
  // text output goes through print_result() / print_raw_text().
  final_prediction_sink.begin = final_prediction_sink.end=final_prediction_sink.end_array = NULL;
  raw_prediction = -1;
  print = print_result;
  print_text = print_raw_text;
  lda = 0;
  random_weights = false;
  per_feature_regularizer_input = "";
  per_feature_regularizer_output = "";
  per_feature_regularizer_text = "";

  // File descriptor of stdout (MSVC spells fileno as _fileno).
  #ifdef _WIN32
  stdout_fileno = _fileno(stdout);
  #else
  stdout_fileno = fileno(stdout);
  #endif

  searchstr = NULL;

  nonormalize = false;
  l1_lambda = 0.0;
  l2_lambda = 0.0;

  eta_decay_rate = 1.0;
  initial_weight = 0.0;
  initial_constant = 0.0;

  unique_id = 0;
  total = 1;
  node = 0;

  // Per-namespace tables indexed by namespace byte: no n-grams, skips, affixes
  // or spelling features by default, and no per-namespace feature limit.
  for (size_t i = 0; i < 256; i++)
    {
      ngram[i] = 0;
      skips[i] = 0;
      limit[i] = INT_MAX;
      affix_features[i] = 0;
      spelling_features[i] = 0;
    }

  //by default use invariant normalized adaptive updates
  adaptive = true;
  normalized_updates = true;
  invariant_updates = true;

  normalized_idx = 2;

  add_constant = true;
  audit = false;
  reg.weight_vector = NULL;
  pass_length = (size_t)-1;   // (size_t)-1: largest size_t value
  passes_complete = 0;

  save_per_pass = false;

  stdin_off = false;
  do_reset_source = false;
  holdout_set_off = true;
  holdout_period = 10;
  holdout_after = 0;
  check_holdout_every_n_passes = 1;
  early_terminate = false;

  max_examples = (size_t)-1;   // (size_t)-1: largest size_t value

  hash_inv = false;
  print_invert = false;

  // Set by the '--progress <arg>' option and affect sd->dump_interval
  progress_add = false;   // default is multiplicative progress dumps
  progress_arg = 2.0;     // next update progress dump multiplier
}
374 
375