1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5 */
6 #include <stdio.h>
7 #include <float.h>
8 #include <iostream>
9 #include <sstream>
10 #include <math.h>
11 #include <assert.h>
12
13 #include "global_data.h"
14 #include "gd.h"
15
16 using namespace std;
17
// A prediction paired with a weight. get_prediction() fills one of these
// straight from socket bytes and send_prediction() writes one out as raw
// bytes, so the member order and types form the wire layout.
struct global_prediction
{
  float p;       // predicted value
  float weight;  // weight travelling with the prediction
};
22
// Read exactly `count` bytes from `sock` into the buffer `in`, looping over
// short reads.
//
// Returns `count` once the buffer has been filled, or 0 as soon as the peer
// reports end-of-stream (bytes already copied into `in` are not reported in
// that case). When the underlying read fails, a message is written to stderr
// and std::exception is thrown. On POSIX a read interrupted by a signal
// (EINTR) is retried rather than treated as a failure.
size_t really_read(int sock, void* in, size_t count)
{
  // Upper bound for one request: keeps the length representable in the
  // unsigned int the system call is given and the result representable in int.
  static const size_t max_chunk = (size_t)1 << 30;

  char* buf = (char*)in;
  size_t done = 0;
  while (done < count)
  {
    size_t want = count - done;
    if (want > max_chunk)
      want = max_chunk;

#ifdef _WIN32
    int r = recv(sock, buf, (unsigned int)want, 0);
#else
    int r = (int)read(sock, buf, (unsigned int)want);
#endif

    if (r == 0)  // end-of-stream
      return 0;

    if (r < 0)
    {
#ifndef _WIN32
      // Nothing was transferred; a signal merely woke us up. Try again.
      if (errno == EINTR)
        continue;
#endif
      std::cerr << "read(" << sock << "," << count << "-" << done << "): " << strerror(errno) << std::endl;
      throw std::exception();
    }

    done += r;
    buf += r;
  }
  return done;
}
52
get_prediction(int sock,float & res,float & weight)53 void get_prediction(int sock, float& res, float& weight)
54 {
55 global_prediction p;
56 really_read(sock, &p, sizeof(p));
57 res = p.p;
58 weight = p.weight;
59 }
60
// Write the raw bytes of `p` to `sock` with a single send/write call.
// If fewer than sizeof(p) bytes are accepted (this includes the call failing),
// the error is logged to stderr and std::exception is thrown.
void send_prediction(int sock, global_prediction p)
{
#ifdef _WIN32
  const int sent = send(sock, reinterpret_cast<const char*>(&p), sizeof(p), 0);
#else
  const ssize_t sent = write(sock, &p, sizeof(p));
#endif

  if (sent >= (int)sizeof(p))
    return;

  std::cerr << "send_prediction write(" << sock << "): " << strerror(errno) << std::endl;
  throw std::exception();
}
75
binary_print_result(int f,float res,float weight,v_array<char> tag)76 void binary_print_result(int f, float res, float weight, v_array<char> tag)
77 {
78 if (f >= 0)
79 {
80 global_prediction ps = {res, weight};
81 send_prediction(f, ps);
82 }
83 }
84
print_tag(std::stringstream & ss,v_array<char> tag)85 int print_tag(std::stringstream& ss, v_array<char> tag)
86 {
87 if (tag.begin != tag.end){
88 ss << ' ';
89 ss.write(tag.begin, sizeof(char)*tag.size());
90 }
91 return tag.begin != tag.end;
92 }
93
print_result(int f,float res,float weight,v_array<char> tag)94 void print_result(int f, float res, float weight, v_array<char> tag)
95 {
96 if (f >= 0)
97 {
98 char temp[30];
99 sprintf(temp, "%f", res);
100 std::stringstream ss;
101 ss << temp;
102 print_tag(ss, tag);
103 ss << '\n';
104 ssize_t len = ss.str().size();
105 ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
106 if (t != len)
107 {
108 cerr << "write error" << endl;
109 }
110 }
111 }
112
print_raw_text(int f,string s,v_array<char> tag)113 void print_raw_text(int f, string s, v_array<char> tag)
114 {
115 if (f < 0)
116 return;
117
118 std::stringstream ss;
119 ss << s;
120 print_tag (ss, tag);
121 ss << '\n';
122 ssize_t len = ss.str().size();
123 ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
124 if (t != len)
125 {
126 cerr << "write error" << endl;
127 }
128 }
129
print_lda_result(vw & all,int f,float * res,float weight,v_array<char> tag)130 void print_lda_result(vw& all, int f, float* res, float weight, v_array<char> tag)
131 {
132 if (f >= 0)
133 {
134 std::stringstream ss;
135 char temp[30];
136 for (size_t k = 0; k < all.lda; k++)
137 {
138 sprintf(temp, "%f ", res[k]);
139 ss << temp;
140 }
141 print_tag(ss, tag);
142 ss << '\n';
143 ssize_t len = ss.str().size();
144 ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
145
146 if (t != len)
147 cerr << "write error" << endl;
148 }
149 }
150
set_mm(shared_data * sd,float label)151 void set_mm(shared_data* sd, float label)
152 {
153 sd->min_label = min(sd->min_label, label);
154 if (label != FLT_MAX)
155 sd->max_label = max(sd->max_label, label);
156 }
157
noop_mm(shared_data * sd,float label)158 void noop_mm(shared_data* sd, float label)
159 {}
160
learn(example * ec)161 void vw::learn(example* ec)
162 {
163 this->l->learn(*ec);
164 }
165
// Translate n-gram/skip specifications into the per-namespace table `dest`.
//
// grams      : specifications, each either "<n>" (apply n to every namespace)
//              or "<c><n>" (apply n to the namespace whose first byte is c).
// dest       : table with at least 256 entries, indexed by namespace byte.
// descriptor : label inserted into the progress messages.
// quiet      : suppresses the progress messages on stderr.
//
// A specification that is empty or consists of one non-digit character is
// rejected with a message on stdout and leaves `dest` untouched.
void compile_gram(std::vector<std::string> grams, uint32_t* dest, char* descriptor, bool quiet)
{
  for (size_t i = 0; i < grams.size(); i++)
  {
    const std::string& ngram = grams[i];

    // Go through unsigned char: a plain (possibly signed) char above 0x7f is
    // not a valid isdigit() argument and, converted to uint32_t, would index
    // far outside the 256-entry table.
    const unsigned char first = ngram.empty() ? 0 : (unsigned char)ngram[0];

    if (isdigit(first))
    {
      int n = atoi(ngram.c_str());
      if (!quiet)
        std::cerr << "Generating " << n << "-" << descriptor << " for all namespaces." << std::endl;
      for (size_t j = 0; j < 256; j++)
        dest[j] = n;
    }
    else if (ngram.size() <= 1)  // nothing follows the namespace character (or nothing at all)
      std::cout << "You must specify the namespace index before the n" << std::endl;
    else
    {
      int n = atoi(ngram.c_str() + 1);
      dest[first] = n;
      if (!quiet)
        std::cerr << "Generating " << n << "-" << descriptor << " for " << ngram[0] << " namespaces." << std::endl;
    }
  }
}
189
// Translate feature-limit specifications into the per-namespace table `dest`.
//
// limits : specifications, each either "<n>" (apply n to every namespace) or
//          "<c><n>" (apply n to the namespace whose first byte is c).
// dest   : table with at least 256 entries, indexed by namespace byte.
// quiet  : suppresses the progress messages on stderr.
//
// A specification that is empty or consists of one non-digit character is
// rejected with a message on stdout and leaves `dest` untouched.
void compile_limits(std::vector<std::string> limits, uint32_t* dest, bool quiet)
{
  for (size_t i = 0; i < limits.size(); i++)
  {
    const std::string& limit = limits[i];

    // Go through unsigned char: a plain (possibly signed) char above 0x7f is
    // not a valid isdigit() argument and, converted to uint32_t, would index
    // far outside the 256-entry table.
    const unsigned char first = limit.empty() ? 0 : (unsigned char)limit[0];

    if (isdigit(first))
    {
      int n = atoi(limit.c_str());
      if (!quiet)
        std::cerr << "limiting to " << n << " features for each namespace." << std::endl;
      for (size_t j = 0; j < 256; j++)
        dest[j] = n;
    }
    else if (limit.size() <= 1)  // nothing follows the namespace character (or nothing at all)
      std::cout << "You must specify the namespace index before the n" << std::endl;
    else
    {
      int n = atoi(limit.c_str() + 1);
      dest[first] = n;
      if (!quiet)
        std::cerr << "limiting to " << n << " for namespaces " << limit[0] << std::endl;
    }
  }
}
213
add_options(vw & all,po::options_description & opts)214 void add_options(vw& all, po::options_description& opts)
215 {
216 all.opts.add(opts);
217 po::variables_map new_vm;
218 //parse local opts once for notifications.
219 po::parsed_options parsed = po::command_line_parser(all.args).
220 style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
221 options(opts).allow_unregistered().run();
222 po::store(parsed, new_vm);
223 po::notify(new_vm);
224
225 for (po::variables_map::iterator it=new_vm.begin(); it!=new_vm.end(); ++it)
226 all.vm.insert(*it);
227 }
228
add_options(vw & all)229 void add_options(vw& all)
230 {
231 add_options(all, *all.new_opts);
232 delete all.new_opts;
233 }
234
no_new_options(vw & all)235 bool no_new_options(vw& all)
236 {
237 //parse local opts once for notifications.
238 po::parsed_options parsed = po::command_line_parser(all.args).
239 style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
240 options(*all.new_opts).allow_unregistered().run();
241 po::variables_map new_vm;
242 po::store(parsed, new_vm);
243 all.opts.add(*all.new_opts);
244 delete all.new_opts;
245 for (po::variables_map::iterator it=new_vm.begin(); it!=new_vm.end(); ++it)
246 all.vm.insert(*it);
247
248 if (new_vm.size() == 0) // required are missing;
249 return true;
250 else
251 return false;
252 }
253
missing_option(vw & all,bool keep,const char * name,const char * description)254 bool missing_option(vw& all, bool keep, const char* name, const char* description)
255 {
256 new_options(all)(name,description);
257 if (no_new_options(all))
258 return true;
259 if (keep)
260 *all.file_options << " --" << name;
261 return false;
262 }
263
vw()264 vw::vw()
265 {
266 sd = &calloc_or_die<shared_data>();
267 sd->dump_interval = 1.; // next update progress dump
268 sd->contraction = 1.;
269 sd->max_label = 1.;
270 sd->min_label = 0.;
271
272 p = new_parser();
273 p->emptylines_separate_examples = false;
274 p->lp = simple_label;
275
276 reg_mode = 0;
277 current_pass = 0;
278 reduction_stack=v_init<LEARNER::base_learner* (*)(vw&)>();
279
280 data_filename = "";
281
282 file_options = new std::stringstream;
283
284 bfgs = false;
285 hessian_on = false;
286 active = false;
287 reg.stride_shift = 0;
288 num_bits = 18;
289 default_bits = true;
290 daemon = false;
291 num_children = 10;
292 span_server = "";
293 save_resume = false;
294
295 random_positive_weights = false;
296
297 set_minmax = set_mm;
298
299 power_t = 0.5;
300 eta = 0.5; //default learning rate for normalized adaptive updates, this is switched to 10 by default for the other updates (see parse_args.cc)
301 numpasses = 1;
302
303 final_prediction_sink.begin = final_prediction_sink.end=final_prediction_sink.end_array = NULL;
304 raw_prediction = -1;
305 print = print_result;
306 print_text = print_raw_text;
307 lda = 0;
308 random_weights = false;
309 per_feature_regularizer_input = "";
310 per_feature_regularizer_output = "";
311 per_feature_regularizer_text = "";
312
313 #ifdef _WIN32
314 stdout_fileno = _fileno(stdout);
315 #else
316 stdout_fileno = fileno(stdout);
317 #endif
318
319 searchstr = NULL;
320
321 nonormalize = false;
322 l1_lambda = 0.0;
323 l2_lambda = 0.0;
324
325 eta_decay_rate = 1.0;
326 initial_weight = 0.0;
327 initial_constant = 0.0;
328
329 unique_id = 0;
330 total = 1;
331 node = 0;
332
333 for (size_t i = 0; i < 256; i++)
334 {
335 ngram[i] = 0;
336 skips[i] = 0;
337 limit[i] = INT_MAX;
338 affix_features[i] = 0;
339 spelling_features[i] = 0;
340 }
341
342 //by default use invariant normalized adaptive updates
343 adaptive = true;
344 normalized_updates = true;
345 invariant_updates = true;
346
347 normalized_idx = 2;
348
349 add_constant = true;
350 audit = false;
351 reg.weight_vector = NULL;
352 pass_length = (size_t)-1;
353 passes_complete = 0;
354
355 save_per_pass = false;
356
357 stdin_off = false;
358 do_reset_source = false;
359 holdout_set_off = true;
360 holdout_period = 10;
361 holdout_after = 0;
362 check_holdout_every_n_passes = 1;
363 early_terminate = false;
364
365 max_examples = (size_t)-1;
366
367 hash_inv = false;
368 print_invert = false;
369
370 // Set by the '--progress <arg>' option and affect sd->dump_interval
371 progress_add = false; // default is multiplicative progress dumps
372 progress_arg = 2.0; // next update progress dump multiplier
373 }
374
375