/*
Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved.  Released under a BSD (revised)
license as described in the file LICENSE.
 */
#include <float.h>
#ifdef _WIN32
#include <WinSock2.h>
#else
#include <netdb.h>
#endif

#if defined(__SSE2__) && !defined(VW_LDA_NO_SSE)
#include <xmmintrin.h>
#endif

#include "gd.h"
#include "accumulate.h"
#include "reductions.h"
#include "vw.h"

using namespace std;
using namespace LEARNER;
//todo:
//4. Factor various state out of vw&
namespace GD
{
  struct gd{
    //double normalized_sum_norm_x;
    double total_weight;
    size_t no_win_counter;
    size_t early_stop_thres;
    float initial_constant;
    float neg_norm_power;
    float neg_power_t;
    float update_multiplier;
    void (*predict)(gd&, base_learner&, example&);
    void (*learn)(gd&, base_learner&, example&);
    void (*update)(gd&, base_learner&, example&);

    vw* all; //parallel, features, parameters
  };

  void sync_weights(vw& all);

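  // Fast approximate 1/sqrt(x): the classic bit-level initial guess plus
  // one Newton step. Used below for adaptive learning rates when the SSE
  // rsqrt instruction is unavailable; the approximation is accurate to
  // roughly a fraction of a percent, which is ample for a learning rate.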
  float InvSqrt(float x){
    float xhalf = 0.5f * x;
    int i = *(int*)&x; // store floating-point bits in integer
    i = 0x5f3759d5 - (i >> 1); // initial guess for Newton's method
    x = *(float*)&i; // convert new bits into float
    x = x*(1.5f - xhalf*x*x); // One round of Newton's method
    return x;
  }

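  // Weight-vector layout per feature (selected at compile time by the
  // template parameters below): w[0] is the weight itself; when adaptive
  // updates are on, w[adaptive] holds the accumulated squared gradient;
  // when normalized updates are on, w[normalized] holds the largest |x|
  // seen for the feature; w[spare] caches the per-feature rate decay
  // computed in compute_rate_decay.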
  template<bool sqrt_rate, bool feature_mask_off, size_t adaptive, size_t normalized, size_t spare>
  inline void update_feature(float& update, float x, float& fw)
  {
    weight* w = &fw;
    if(feature_mask_off || fw != 0.)
      {
        if (spare != 0)
          x *= w[spare];
        w[0] += update * x;
      }
  }

  //This corrects for the difference between examples with few nonzero features and examples where every feature is nonzero.
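  // Note: the multiplier is (total_weight / normalized_sum_norm_x) raised
  // to the power the update rule expects: the square root of the ratio in
  // the adaptive sqrt-rate case, the ratio itself in the plain sqrt-rate
  // case, and (normalized_sum_norm_x / total_weight)^neg_norm_power
  // otherwise.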
  template<bool sqrt_rate, size_t adaptive, size_t normalized>
  float average_update(gd& g, float update)
  {
    if (normalized) {
      if (sqrt_rate)
        {
          float avg_norm = (float) g.total_weight / (float) g.all->normalized_sum_norm_x;
          if (adaptive)
            return sqrt(avg_norm);
          else
            return avg_norm;
        }
      else
        return powf( (float) g.all->normalized_sum_norm_x / (float) g.total_weight, g.neg_norm_power);
    }
    return 1.f;
  }

  template<bool sqrt_rate, bool feature_mask_off, size_t adaptive, size_t normalized, size_t spare>
  void train(gd& g, example& ec, float update)
  {
    if (normalized)
      update *= g.update_multiplier;

    foreach_feature<float, update_feature<sqrt_rate, feature_mask_off, adaptive, normalized, spare> >(*g.all, ec, update);
  }

  void end_pass(gd& g)
  {
    vw& all = *g.all;

    sync_weights(all);
    if(all.span_server != "") {
      if(all.adaptive)
        accumulate_weighted_avg(all, all.span_server, all.reg);
      else
        accumulate_avg(all, all.span_server, all.reg, 0);
    }

    all.eta *= all.eta_decay_rate;
    if (all.save_per_pass)
      save_predictor(all, all.final_regressor_name, all.current_pass);

    all.current_pass++;

    if(!all.holdout_set_off)
      {
        if(summarize_holdout_set(all, g.no_win_counter))
          finalize_regressor(all, all.final_regressor_name);
        if((g.early_stop_thres == g.no_win_counter) &&
           ((all.check_holdout_every_n_passes <= 1) ||
            ((all.current_pass % all.check_holdout_every_n_passes) == 0)))
          set_done(all);
      }
  }

struct string_value {
  float v;
  string s;
  friend bool operator<(const string_value& first, const string_value& second);
};

 inline float sign(float w){ if (w < 0.) return -1.; else return 1.;}

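 // Soft-threshold operator for lazy L1 regularization: shrinks w toward
 // zero by gravity, truncating to exactly zero once |w| <= gravity.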
 inline float trunc_weight(const float w, const float gravity){
   return (gravity < fabsf(w)) ? w - sign(w) * gravity : 0.f;
 }

bool operator<(const string_value& first, const string_value& second)
{
  return fabs(first.v) > fabs(second.v);
}

#include <algorithm>

  void audit_feature(vw& all, feature* f, audit_data* a, vector<string_value>& results, string prepend, string& ns_pre, size_t offset = 0, float mult = 1)
{
  ostringstream tempstream;
  size_t index = (f->weight_index + offset) & all.reg.weight_mask;
  weight* weights = all.reg.weight_vector;
  size_t stride_shift = all.reg.stride_shift;

  if(all.audit) tempstream << prepend;

  string tmp = "";

  if (a != NULL){
    tmp += a->space;
    tmp += '^';
    tmp += a->feature;
  }

  if (a != NULL && all.audit){
    tempstream << tmp << ':';
  }
  else if ( index == (((constant << stride_shift) * all.wpp + offset) & all.reg.weight_mask) && all.audit){
    tempstream << "Constant:";
  }
  if(all.audit){
    tempstream << ((index >> stride_shift) & all.parse_mask) << ':' << mult*f->x;
    tempstream << ':' << trunc_weight(weights[index], (float)all.sd->gravity) * (float)all.sd->contraction;
  }
  if(all.current_pass == 0 && all.inv_hash_regressor_name != ""){ //for invert_hash
    if ( index == (((constant << stride_shift) * all.wpp + offset) & all.reg.weight_mask))
      tmp = "Constant";

    ostringstream convert;
    convert << ((index >> stride_shift) & all.parse_mask);
    tmp = ns_pre + tmp + ":" + convert.str();

    if(!all.name_index_map.count(tmp)){
      all.name_index_map.insert(std::map< std::string, size_t>::value_type(tmp, ((index >> stride_shift) & all.parse_mask)));
    }
  }

  if(all.adaptive && all.audit)
    tempstream << '@' << weights[index+1];
  string_value sv = {weights[index]*f->x, tempstream.str()};
  results.push_back(sv);
}

  void audit_features(vw& all, v_array<feature>& fs, v_array<audit_data>& as, vector<string_value>& results, string prepend, string& ns_pre, size_t offset = 0, float mult = 1)
{
  for (size_t j = 0; j < fs.size(); j++)
    if (as.begin != as.end)
      audit_feature(all, &fs[j], &as[j], results, prepend, ns_pre, offset, mult);
    else
      audit_feature(all, &fs[j], NULL, results, prepend, ns_pre, offset, mult);
}

void audit_quad(vw& all, feature& left_feature, audit_data* left_audit, v_array<feature>& right_features, v_array<audit_data>& audit_right, vector<string_value>& results, string& ns_pre, uint32_t offset = 0)
{
  size_t halfhash = quadratic_constant * (left_feature.weight_index + offset);

  ostringstream tempstream;
  if (audit_right.size() != 0 && left_audit && all.audit)
    tempstream << left_audit->space << '^' << left_audit->feature << '^';
  string prepend = tempstream.str();

  if(all.current_pass == 0 && audit_right.size() != 0 && left_audit)//for invert_hash
  {
    ns_pre = left_audit->space;
    ns_pre = ns_pre + '^' + left_audit->feature + '^';
  }

  audit_features(all, right_features, audit_right, results, prepend, ns_pre, halfhash + offset, left_audit ? left_audit->x : 1);
}

void audit_triple(vw& all, feature& f0, audit_data* f0_audit, feature& f1, audit_data* f1_audit,
                  v_array<feature>& right_features, v_array<audit_data>& audit_right, vector<string_value>& results, string& ns_pre, uint32_t offset = 0)
{
  size_t halfhash = cubic_constant2 * (cubic_constant * (f0.weight_index + offset) + f1.weight_index + offset);

  ostringstream tempstream;
  if (audit_right.size() > 0 && f0_audit && f1_audit && all.audit)
    tempstream << f0_audit->space << '^' << f0_audit->feature << '^'
               << f1_audit->space << '^' << f1_audit->feature << '^';
  string prepend = tempstream.str();

  if(all.current_pass == 0 && audit_right.size() != 0 && f0_audit && f1_audit)//for invert_hash
  {
    ns_pre = f0_audit->space;
    ns_pre = ns_pre + '^' + f0_audit->feature + '^' + f1_audit->space + '^' + f1_audit->feature + '^';
  }
  audit_features(all, right_features, audit_right, results, prepend, ns_pre, halfhash + offset);
}

void print_features(vw& all, example& ec)
{
  weight* weights = all.reg.weight_vector;

  if (all.lda > 0)
    {
      size_t count = 0;
      for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++)
        count += ec.atomics[*i].size();
      for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++)
        for (audit_data *f = ec.audit_features[*i].begin; f != ec.audit_features[*i].end; f++)
          {
            cout << '\t' << f->space << '^' << f->feature << ':' << ((f->weight_index >> all.reg.stride_shift) & all.parse_mask) << ':' << f->x;
            for (size_t k = 0; k < all.lda; k++)
              cout << ':' << weights[(f->weight_index+k) & all.reg.weight_mask];
          }
      cout << " total of " << count << " features." << endl;
    }
  else
    {
      vector<string_value> features;
      string empty;
      string ns_pre;

      for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++){
        ns_pre = "";
        audit_features(all, ec.atomics[*i], ec.audit_features[*i], features, empty, ns_pre, ec.ft_offset);
        ns_pre = "";
      }
      for (vector<string>::iterator i = all.pairs.begin(); i != all.pairs.end(); i++)
        {
          unsigned char fst = (*i)[0];
          unsigned char snd = (*i)[1];
          for (size_t j = 0; j < ec.atomics[fst].size(); j++)
            {
              audit_data* a = NULL;
              if (ec.audit_features[fst].size() > 0)
                a = &ec.audit_features[fst][j];
              audit_quad(all, ec.atomics[fst][j], a, ec.atomics[snd], ec.audit_features[snd], features, ns_pre);
            }
        }

      for (vector<string>::iterator i = all.triples.begin(); i != all.triples.end(); i++)
        {
          unsigned char fst = (*i)[0];
          unsigned char snd = (*i)[1];
          unsigned char trd = (*i)[2];
          for (size_t j = 0; j < ec.atomics[fst].size(); j++)
            {
              audit_data* a1 = NULL;
              if (ec.audit_features[fst].size() > 0)
                a1 = &ec.audit_features[fst][j];
              for (size_t k = 0; k < ec.atomics[snd].size(); k++)
                {
                  audit_data* a2 = NULL;
                  if (ec.audit_features[snd].size() > 0)
                    a2 = &ec.audit_features[snd][k];
                  audit_triple(all, ec.atomics[fst][j], a1, ec.atomics[snd][k], a2, ec.atomics[trd], ec.audit_features[trd], features, ns_pre);
                }
            }
        }

      sort(features.begin(), features.end());
      if(all.audit){
        for (vector<string_value>::iterator sv = features.begin(); sv != features.end(); sv++)
          cout << '\t' << (*sv).s;
        cout << endl;
      }
    }
}

void print_audit_features(vw& all, example& ec)
{
  if(all.audit)
    print_result(all.stdout_fileno, ec.pred.scalar, -1, ec.tag);
  fflush(stdout);
  print_features(all, ec);
}

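// Replace NaN predictions with 0 and clip the prediction to the range of
// labels observed so far.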
float finalize_prediction(shared_data* sd, float ret)
{
  if ( nanpattern(ret))
    {
      cerr << "NAN prediction in example " << sd->example_number + 1 << ", forcing 0.0" << endl;
      return 0.;
    }
  if ( ret > sd->max_label )
    return (float)sd->max_label;
  if (ret < sd->min_label)
    return (float)sd->min_label;
  return ret;
}

 struct trunc_data {
   float prediction;
   float gravity;
 };

 inline void vec_add_trunc(trunc_data& p, const float fx, float& fw) {
   p.prediction += trunc_weight(fw, p.gravity) * fx;
 }

 inline float trunc_predict(vw& all, example& ec, double gravity)
 {
   trunc_data temp = {ec.l.simple.initial, (float)gravity};
   foreach_feature<trunc_data, vec_add_trunc>(all, ec, temp);
   return temp.prediction;
 }

template<bool l1, bool audit>
void predict(gd& g, base_learner& base, example& ec)
{
  vw& all = *g.all;

  if (l1)
    ec.partial_prediction = trunc_predict(all, ec, all.sd->gravity);
  else
    ec.partial_prediction = inline_predict(all, ec);

  ec.partial_prediction *= (float)all.sd->contraction;
  ec.pred.scalar = finalize_prediction(all.sd, ec.partial_prediction);
  if (audit)
    print_audit_features(all, ec);
}

  struct power_data {
    float minus_power_t;
    float neg_norm_power;
  };

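  // Per-feature learning-rate decay. With g = the accumulated squared
  // gradient (w[adaptive]) and s = the largest |x| seen (w[normalized]),
  // this computes roughly g^{-power_t} * s^{-2(1-power_t)}; the sqrt_rate
  // specialization is the power_t = 0.5 case, which reduces to
  // 1/(sqrt(g)*s) when both adaptive and normalized updates are on.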
  template<bool sqrt_rate, size_t adaptive, size_t normalized>
  inline float compute_rate_decay(power_data& s, float& fw)
  {
    weight* w = &fw;
    float rate_decay = 1.f;
    if(adaptive) {
      if (sqrt_rate)
        {
#if defined(__SSE2__) && !defined(VW_LDA_NO_SSE)
          __m128 eta = _mm_load_ss(&w[adaptive]);
          eta = _mm_rsqrt_ss(eta);
          _mm_store_ss(&rate_decay, eta);
#else
          rate_decay = InvSqrt(w[adaptive]);
#endif
        }
      else
        rate_decay = powf(w[adaptive], s.minus_power_t);
    }
    if(normalized) {
      if (sqrt_rate)
        {
          float inv_norm = 1.f / w[normalized];
          if (adaptive)
            rate_decay *= inv_norm;
          else
            rate_decay *= inv_norm*inv_norm;
        }
      else
        rate_decay *= powf(w[normalized]*w[normalized], s.neg_norm_power);
    }
    return rate_decay;
  }

  struct norm_data {
    float grad_squared;
    float pred_per_update;
    float norm_x;
    power_data pd;
  };

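// When a feature's magnitude exceeds the largest |x| seen so far, its
// weight is rescaled so past updates look as if they had been made at the
// new scale; this is what keeps the normalized update insensitive to
// feature scaling. The accumulated squared gradient and norm_x are also
// maintained here, as a side effect of the prediction pass.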
template<bool sqrt_rate, bool feature_mask_off, size_t adaptive, size_t normalized, size_t spare>
inline void pred_per_update_feature(norm_data& nd, float x, float& fw) {
  if(feature_mask_off || fw != 0.){
    weight* w = &fw;
    float x2 = x * x;
    if(adaptive)
      w[adaptive] += nd.grad_squared * x2;
    if(normalized) {
      float x_abs = fabsf(x);
      if( x_abs > w[normalized] ) {// new scale discovered
        if( w[normalized] > 0. ) {//If the normalizer is > 0 then rescale the weight so it's as if the new scale was the old scale.
          if (sqrt_rate) {
            float rescale = w[normalized]/x_abs;
            w[0] *= (adaptive ? rescale : rescale*rescale);
          }
          else {
            float rescale = x_abs/w[normalized];
            w[0] *= powf(rescale*rescale, nd.pd.neg_norm_power);
          }
        }
        w[normalized] = x_abs;
      }
      nd.norm_x += x2 / (w[normalized] * w[normalized]);
    }
    w[spare] = compute_rate_decay<sqrt_rate, adaptive, normalized>(nd.pd, fw);

    nd.pred_per_update += x2 * w[spare];
  }
}

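// pred_per_update accumulates sum_i x_i^2 * rate_decay_i, i.e. roughly how
// far the prediction moves per unit of update; the loss functions take it
// as their last argument when computing the step size.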
template<bool sqrt_rate, bool feature_mask_off, size_t adaptive, size_t normalized, size_t spare>
  float get_pred_per_update(gd& g, example& ec)
  {//We must traverse the features in _precisely_ the same order as during training.
    label_data& ld = ec.l.simple;
    vw& all = *g.all;
    float grad_squared = all.loss->getSquareGrad(ec.pred.scalar, ld.label) * ld.weight;
    if (grad_squared == 0) return 1.;

    norm_data nd = {grad_squared, 0., 0., {g.neg_power_t, g.neg_norm_power}};

    foreach_feature<norm_data, pred_per_update_feature<sqrt_rate, feature_mask_off, adaptive, normalized, spare> >(all, ec, nd);

    if(normalized) {
      g.all->normalized_sum_norm_x += ld.weight * nd.norm_x;
      g.total_weight += ld.weight;

      g.update_multiplier = average_update<sqrt_rate, adaptive, normalized>(g, nd.pred_per_update);
      nd.pred_per_update *= g.update_multiplier;
    }

    return nd.pred_per_update;
  }

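// Computes the scalar update for this example. With the invariant flag the
// loss's importance-aware getUpdate is used, which remains safe under large
// importance weights; otherwise the plain gradient step getUnsafeUpdate is
// taken. L1 and L2 regularization are handled lazily: gravity accumulates
// the L1 shrinkage and contraction the L2 scaling, and both are applied to
// the weights later in sync_weights.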
template<bool invariant, bool sqrt_rate, bool feature_mask_off, size_t adaptive, size_t normalized, size_t spare>
float compute_update(gd& g, example& ec)
{//invariant: not a test label, importance weight > 0
  label_data& ld = ec.l.simple;
  vw& all = *g.all;

  float ret = 0.;
  ec.updated_prediction = ec.pred.scalar;
  if (all.loss->getLoss(all.sd, ec.pred.scalar, ld.label) > 0.)
    {
      float pred_per_update;
      if(adaptive || normalized)
        pred_per_update = get_pred_per_update<sqrt_rate, feature_mask_off, adaptive, normalized, spare>(g, ec);
      else
        pred_per_update = ec.total_sum_feat_sq;

      float delta_pred = pred_per_update * all.eta * ld.weight;
      if(!adaptive)
        {
          float t = (float)(ec.example_t - all.sd->weighted_holdout_examples);
          delta_pred *= powf(t, g.neg_power_t);
        }

      float update;
      if(invariant)
        update = all.loss->getUpdate(ec.pred.scalar, ld.label, delta_pred, pred_per_update);
      else
        update = all.loss->getUnsafeUpdate(ec.pred.scalar, ld.label, delta_pred, pred_per_update);

      // changed from ec.partial_prediction to ld.prediction
      ec.updated_prediction += pred_per_update * update;

      if (all.reg_mode && fabs(update) > 1e-8) {
        double dev1 = all.loss->first_derivative(all.sd, ec.pred.scalar, ld.label);
        double eta_bar = (fabs(dev1) > 1e-8) ? (-update / dev1) : 0.0;
        if (fabs(dev1) > 1e-8)
          all.sd->contraction *= (1. - all.l2_lambda * eta_bar);
        update /= (float)all.sd->contraction;
        all.sd->gravity += eta_bar * all.l1_lambda;
      }
      ret = update;
    }

  return ret;
}

template<bool invariant, bool sqrt_rate, bool feature_mask_off, size_t adaptive, size_t normalized, size_t spare>
void update(gd& g, base_learner& base, example& ec)
{//invariant: not a test label, importance weight > 0
  float update;
  if ( (update = compute_update<invariant, sqrt_rate, feature_mask_off, adaptive, normalized, spare>(g, ec)) != 0.)
    train<sqrt_rate, feature_mask_off, adaptive, normalized, spare>(g, ec, update);

  if (g.all->sd->contraction < 1e-10)  // updating weights now to avoid numerical instability
    sync_weights(*g.all);
}

template<bool invariant, bool sqrt_rate, bool feature_mask_off, size_t adaptive, size_t normalized, size_t spare>
void learn(gd& g, base_learner& base, example& ec)
{//invariant: not a test label, importance weight > 0
  assert(ec.in_use);
  assert(ec.l.simple.label != FLT_MAX);
  assert(ec.l.simple.weight > 0.);

  g.predict(g, base, ec);
  update<invariant, sqrt_rate, feature_mask_off, adaptive, normalized, spare>(g, base, ec);
}

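// Flushes the lazily accumulated regularization into the weight vector:
// every weight is L1-truncated by gravity and scaled by contraction, after
// which both accumulators are reset.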
void sync_weights(vw& all) {
  if (all.sd->gravity == 0. && all.sd->contraction == 1.)  // to avoid unnecessary weight synchronization
    return;
  uint32_t length = 1 << all.num_bits;
  size_t stride = 1 << all.reg.stride_shift;
  for(uint32_t i = 0; i < length && all.reg_mode; i++)
    all.reg.weight_vector[stride*i] = trunc_weight(all.reg.weight_vector[stride*i], (float)all.sd->gravity) * (float)all.sd->contraction;
  all.sd->gravity = 0.;
  all.sd->contraction = 1.;
}

void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text)
{
  uint32_t length = 1 << all.num_bits;
  uint32_t stride = 1 << all.reg.stride_shift;
  int c = 0;
  uint32_t i = 0;
  size_t brw = 1;

  if(all.print_invert){ //write readable model with feature names
    weight* v;
    char buff[512];
    int text_len;
    typedef std::map< std::string, size_t> str_int_map;

    for(str_int_map::iterator it = all.name_index_map.begin(); it != all.name_index_map.end(); ++it){
      v = &(all.reg.weight_vector[stride*(it->second)]);
      if(*v != 0.){
        text_len = sprintf(buff, "%s", (char*)it->first.c_str());
        brw = bin_text_write_fixed(model_file, (char*)it->first.c_str(), sizeof(*it->first.c_str()),
                                   buff, text_len, true);
        text_len = sprintf(buff, ":%f\n", *v);
        brw += bin_text_write_fixed(model_file, (char*)v, sizeof(*v),
                                    buff, text_len, true);
      }
    }
    return;
  }

  do
    {
      brw = 1;
      weight* v;
      if (read)
        {
          c++;
          brw = bin_read_fixed(model_file, (char*)&i, sizeof(i), "");
          if (brw > 0)
            {
              assert (i < length);
              v = &(all.reg.weight_vector[stride*i]);
              brw += bin_read_fixed(model_file, (char*)v, sizeof(*v), "");
            }
        }
      else // write binary or text
        {
          v = &(all.reg.weight_vector[stride*i]);
          if (*v != 0.)
            {
              c++;
              char buff[512];
              int text_len;

              text_len = sprintf(buff, "%u", i);
              brw = bin_text_write_fixed(model_file, (char*)&i, sizeof(i),
                                         buff, text_len, text);

              text_len = sprintf(buff, ":%f\n", *v);
              brw += bin_text_write_fixed(model_file, (char*)v, sizeof(*v),
                                          buff, text_len, text);
            }
        }

      if (!read)
        i++;
    }
  while ((!read && i < length) || (read && brw > 0));
}

//void save_load_online_state(gd& g, io_buf& model_file, bool read, bool text)
void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text)
{
  //vw& all = *g.all;

  char buff[512];

  uint32_t text_len = sprintf(buff, "initial_t %f\n", all.initial_t);
  bin_text_read_write_fixed(model_file, (char*)&all.initial_t, sizeof(all.initial_t),
                            "", read,
                            buff, text_len, text);

  text_len = sprintf(buff, "norm normalizer %f\n", all.normalized_sum_norm_x);
  bin_text_read_write_fixed(model_file, (char*)&all.normalized_sum_norm_x, sizeof(all.normalized_sum_norm_x),
                            "", read,
                            buff, text_len, text);

  text_len = sprintf(buff, "t %f\n", all.sd->t);
  bin_text_read_write_fixed(model_file, (char*)&all.sd->t, sizeof(all.sd->t),
                            "", read,
                            buff, text_len, text);

  text_len = sprintf(buff, "sum_loss %f\n", all.sd->sum_loss);
  bin_text_read_write_fixed(model_file, (char*)&all.sd->sum_loss, sizeof(all.sd->sum_loss),
                            "", read,
                            buff, text_len, text);

  text_len = sprintf(buff, "sum_loss_since_last_dump %f\n", all.sd->sum_loss_since_last_dump);
  bin_text_read_write_fixed(model_file, (char*)&all.sd->sum_loss_since_last_dump, sizeof(all.sd->sum_loss_since_last_dump),
                            "", read,
                            buff, text_len, text);

  text_len = sprintf(buff, "dump_interval %f\n", all.sd->dump_interval);
  bin_text_read_write_fixed(model_file, (char*)&all.sd->dump_interval, sizeof(all.sd->dump_interval),
                            "", read,
                            buff, text_len, text);

  text_len = sprintf(buff, "min_label %f\n", all.sd->min_label);
  bin_text_read_write_fixed(model_file, (char*)&all.sd->min_label, sizeof(all.sd->min_label),
                            "", read,
                            buff, text_len, text);

  text_len = sprintf(buff, "max_label %f\n", all.sd->max_label);
  bin_text_read_write_fixed(model_file, (char*)&all.sd->max_label, sizeof(all.sd->max_label),
                            "", read,
                            buff, text_len, text);

  text_len = sprintf(buff, "weighted_examples %f\n", all.sd->weighted_examples);
  bin_text_read_write_fixed(model_file, (char*)&all.sd->weighted_examples, sizeof(all.sd->weighted_examples),
                            "", read,
                            buff, text_len, text);

  text_len = sprintf(buff, "weighted_labels %f\n", all.sd->weighted_labels);
  bin_text_read_write_fixed(model_file, (char*)&all.sd->weighted_labels, sizeof(all.sd->weighted_labels),
                            "", read,
                            buff, text_len, text);

  text_len = sprintf(buff, "weighted_unlabeled_examples %f\n", all.sd->weighted_unlabeled_examples);
  bin_text_read_write_fixed(model_file, (char*)&all.sd->weighted_unlabeled_examples, sizeof(all.sd->weighted_unlabeled_examples),
                            "", read,
                            buff, text_len, text);

  text_len = sprintf(buff, "example_number %u\n", (uint32_t)all.sd->example_number);
  bin_text_read_write_fixed(model_file, (char*)&all.sd->example_number, sizeof(all.sd->example_number),
                            "", read,
                            buff, text_len, text);

  text_len = sprintf(buff, "total_features %u\n", (uint32_t)all.sd->total_features);
  bin_text_read_write_fixed(model_file, (char*)&all.sd->total_features, sizeof(all.sd->total_features),
                            "", read,
                            buff, text_len, text);
  if (!all.training) // reset various things so that we report test set performance properly
    {
      all.sd->sum_loss = 0;
      all.sd->sum_loss_since_last_dump = 0;
      all.sd->dump_interval = 1.;
      all.sd->weighted_examples = 0.;
      all.sd->weighted_labels = 0.;
      all.sd->weighted_unlabeled_examples = 0.;
      all.sd->example_number = 0;
      all.sd->total_features = 0;
    }

  uint32_t length = 1 << all.num_bits;
  uint32_t stride = 1 << all.reg.stride_shift;
  int c = 0;
  uint32_t i = 0;
  size_t brw = 1;
  do
    {
      brw = 1;
      weight* v;
      if (read)
        {
          c++;
          brw = bin_read_fixed(model_file, (char*)&i, sizeof(i), "");
          if (brw > 0)
            {
              assert (i < length);
              v = &(all.reg.weight_vector[stride*i]);
              if (stride == 2) //either adaptive or normalized
                brw += bin_read_fixed(model_file, (char*)v, sizeof(*v)*2, "");
              else //adaptive and normalized
                brw += bin_read_fixed(model_file, (char*)v, sizeof(*v)*3, "");
              if (!all.training)
                v[1] = v[2] = 0.;
            }
        }
      else // write binary or text
        {
          v = &(all.reg.weight_vector[stride*i]);
          if (*v != 0.)
            {
              c++;
              char buff[512];
              int text_len = sprintf(buff, "%u", i);
              brw = bin_text_write_fixed(model_file, (char*)&i, sizeof(i),
                                         buff, text_len, text);

              if (stride == 2)
                {//either adaptive or normalized
                  text_len = sprintf(buff, ":%f %f\n", *v, *(v+1));
                  brw += bin_text_write_fixed(model_file, (char*)v, 2*sizeof(*v),
                                              buff, text_len, text);
                }
              else
                {//adaptive and normalized
                  text_len = sprintf(buff, ":%f %f %f\n", *v, *(v+1), *(v+2));
                  brw += bin_text_write_fixed(model_file, (char*)v, 3*sizeof(*v),
                                              buff, text_len, text);
                }
            }
        }
      if (!read)
        i++;
    }
  while ((!read && i < length) || (read && brw > 0));
}

void save_load(gd& g, io_buf& model_file, bool read, bool text)
{
  vw& all = *g.all;
  if(read)
    {
      initialize_regressor(all);

      if(all.adaptive && all.initial_t > 0)
        {
          uint32_t length = 1 << all.num_bits;
          uint32_t stride = 1 << all.reg.stride_shift;
          for (size_t j = 1; j < stride*length; j += stride)
            {
              all.reg.weight_vector[j] = all.initial_t;   //for the adaptive update, we interpret initial_t as previously seeing initial_t fake datapoints, all with squared gradient=1
              //NOTE: this is not invariant to the scaling of the data (i.e. when combined with normalized). Since scaling the data scales the gradient, this should ideally be
              //feature_range*initial_t, or something like that. We could potentially fix this by just adding this base quantity times the current range to the sum of gradients
              //stored in memory at each update, and always starting the sum of gradients at 0, at the price of additional additions and multiplications during the update...
            }
        }

      if (g.initial_constant != 0.0)
        VW::set_weight(all, constant, 0, g.initial_constant);
    }

  if (model_file.files.size() > 0)
    {
      bool resume = all.save_resume;
      char buff[512];
      uint32_t text_len = sprintf(buff, ":%d\n", resume);
      bin_text_read_write_fixed(model_file, (char*)&resume, sizeof(resume),
                                "", read,
                                buff, text_len, text);
      if (resume)
        //save_load_online_state(g, model_file, read, text);
        save_load_online_state(all, model_file, read, text);
      else
        save_load_regressor(all, model_file, read, text);
    }
}

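// The set_learn overloads below resolve the runtime flags (invariant,
// power_t == 0.5, adaptive, normalized, feature mask) into compile-time
// template parameters, one flag per layer, and return the number of weight
// slots needed per feature, which becomes the stride.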
template<bool invariant, bool sqrt_rate, uint32_t adaptive, uint32_t normalized, uint32_t spare, uint32_t next>
uint32_t set_learn(vw& all, bool feature_mask_off, gd& g)
{
  all.normalized_idx = normalized;
  if (feature_mask_off)
    {
      g.learn = learn<invariant, sqrt_rate, true, adaptive, normalized, spare>;
      g.update = update<invariant, sqrt_rate, true, adaptive, normalized, spare>;
      return next;
    }
  else
    {
      g.learn = learn<invariant, sqrt_rate, false, adaptive, normalized, spare>;
      g.update = update<invariant, sqrt_rate, false, adaptive, normalized, spare>;
      return next;
    }
}

template<bool sqrt_rate, uint32_t adaptive, uint32_t normalized, uint32_t spare, uint32_t next>
uint32_t set_learn(vw& all, bool feature_mask_off, gd& g)
{
  if (all.invariant_updates)
    return set_learn<true, sqrt_rate, adaptive, normalized, spare, next>(all, feature_mask_off, g);
  else
    return set_learn<false, sqrt_rate, adaptive, normalized, spare, next>(all, feature_mask_off, g);
}

template<bool sqrt_rate, uint32_t adaptive, uint32_t spare>
uint32_t set_learn(vw& all, bool feature_mask_off, gd& g)
{
  // select the appropriate learn function based on adaptive, normalization, and feature mask
  if (all.normalized_updates)
    return set_learn<sqrt_rate, adaptive, adaptive+1, adaptive+2, adaptive+3>(all, feature_mask_off, g);
  else
    return set_learn<sqrt_rate, adaptive, 0, spare, spare+1>(all, feature_mask_off, g);
}

template<bool sqrt_rate>
uint32_t set_learn(vw& all, bool feature_mask_off, gd& g)
{
  if (all.adaptive)
    return set_learn<sqrt_rate, 1, 2>(all, feature_mask_off, g);
  else
    return set_learn<sqrt_rate, 0, 0>(all, feature_mask_off, g);
}

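// Returns the number of bits needed to represent v (floor(log2(v)) + 1 for
// v > 0), so ceil_log_2(stride - 1) below yields ceil(log2(stride)).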
uint32_t ceil_log_2(uint32_t v)
{
  if (v == 0)
    return 0;
  else
    return 1 + ceil_log_2(v >> 1);
}

base_learner* setup(vw& all)
{
  new_options(all, "Gradient Descent options")
    ("sgd", "use regular stochastic gradient descent update.")
    ("adaptive", "use adaptive, individual learning rates.")
    ("invariant", "use safe/importance aware updates.")
    ("normalized", "use per feature normalized updates.")
    ("exact_adaptive_norm", "use current default invariant normalized adaptive update rule");
  add_options(all);
  po::variables_map& vm = all.vm;
  gd& g = calloc_or_die<gd>();
  g.all = &all;
  g.all->normalized_sum_norm_x = 0;
  g.no_win_counter = 0;
  g.total_weight = 0.;
  g.early_stop_thres = 3;
  g.neg_norm_power = (all.adaptive ? (all.power_t - 1.f) : -1.f);
  g.neg_power_t = - all.power_t;

  if(all.initial_t > 0)//for the normalized update: a positive initial_t is interpreted as having already seen initial_t fake datapoints, all with norm 1
    {
      g.all->normalized_sum_norm_x = all.initial_t;
      g.total_weight = all.initial_t;
    }

  bool feature_mask_off = true;
  if(vm.count("feature_mask"))
    feature_mask_off = false;

  if(!all.holdout_set_off)
  {
    all.sd->holdout_best_loss = FLT_MAX;
    if(vm.count("early_terminate"))
      g.early_stop_thres = vm["early_terminate"].as<size_t>();
  }

  if (vm.count("constant")) {
    g.initial_constant = vm["constant"].as<float>();
  }

  if( !all.training || ( ( vm.count("sgd") || vm.count("adaptive") || vm.count("invariant") || vm.count("normalized") ) && !vm.count("exact_adaptive_norm")) )
    {//nondefault
      all.adaptive = all.training && vm.count("adaptive");
      all.invariant_updates = all.training && vm.count("invariant");
      all.normalized_updates = all.training && vm.count("normalized");

      if(!vm.count("learning_rate") && !vm.count("l") && !(all.adaptive && all.normalized_updates))
        all.eta = 10; //default the learning rate to 10 for the non-default update rules

      //if not using normalized or adaptive updates, default initial_t to 1 instead of 0
      if(!all.adaptive && !all.normalized_updates){
        if (!vm.count("initial_t")) {
          all.sd->t = 1.f;
          all.sd->weighted_unlabeled_examples = 1.f;
          all.initial_t = 1.f;
        }
        all.eta *= powf((float)(all.sd->t), all.power_t);
      }
    }

  if (pow((double)all.eta_decay_rate, (double)all.numpasses) < 0.0001 )
    cerr << "Warning: the learning rate for the last pass is multiplied by " << pow((double)all.eta_decay_rate, (double)all.numpasses)
         << "; increase --decay_learning_rate to avoid this." << endl;

  if (all.reg_mode % 2) {
    if (all.audit || all.hash_inv)
      g.predict = predict<true, true>;
    else
      g.predict = predict<true, false>;
  }
  else if (all.audit || all.hash_inv)
    g.predict = predict<false, true>;
  else
    g.predict = predict<false, false>;

  uint32_t stride;
  if (all.power_t == 0.5)
    stride = set_learn<true>(all, feature_mask_off, g);
  else
    stride = set_learn<false>(all, feature_mask_off, g);
  all.reg.stride_shift = ceil_log_2(stride-1);

  learner<gd>& ret = init_learner(&g, g.learn, ((uint64_t)1 << all.reg.stride_shift));
  ret.set_predict(g.predict);
  ret.set_update(g.update);
  ret.set_save_load(save_load);
  ret.set_end_pass(end_pass);
  return make_base(ret);
}
}