/*
Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved.  Released under a BSD (revised)
license as described in the file LICENSE.
 */
/*
The algorithm here is generally based on Nocedal 1980, Liu and Nocedal 1989.
Implementation by Miro Dudik.
 */
#include <fstream>
#include <float.h>
#include <exception>
#ifndef _WIN32
#include <netdb.h>
#endif
#include <string.h>
#include <stdio.h>
#include <assert.h>
#include <sys/timeb.h>
#include "accumulate.h"
#include "gd.h"

using namespace std;
using namespace LEARNER;

#define CG_EXTRA 1

#define MEM_GT 0
#define MEM_XT 1
#define MEM_YT 0
#define MEM_ST 1

#define W_XT 0
#define W_GT 1
#define W_DIR 2
#define W_COND 3

#define LEARN_OK 0
#define LEARN_CURV 1
#define LEARN_CONV 2

class curv_exception: public exception {} curv_ex;

/********************************************************************/
/* mem & w definition ***********************************************/
/********************************************************************/
// mem[2*i] = y_t
// mem[2*i+1] = s_t
//
// w[0] = weight
// w[1] = accumulated first derivative
// w[2] = step direction
// w[3] = preconditioner
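//
// mem is a per-weight ring buffer holding the last m (y_t, s_t) pairs:
// mem_stride = 2*m floats per weight, indexed modulo mem_stride relative
// to a rotating origin that is shifted back by 2 after each update (see
// the shift step at the end of bfgs_iter_middle). Before an update
// completes, the current slot temporarily holds (g_t, x_t) via the
// MEM_GT/MEM_XT aliases, which share indices 0/1 with MEM_YT/MEM_ST.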

  const float max_precond_ratio = 10000.f;

  struct bfgs {
    vw* all;//prediction, regressor
    int m;
    float rel_threshold; // termination threshold

    double wolfe1_bound;

    size_t final_pass;
    struct timeb t_start, t_end;
    double net_comm_time;

    struct timeb t_start_global, t_end_global;
    double net_time;

    v_array<float> predictions;
    size_t example_number;
    size_t current_pass;
    size_t no_win_counter;
    size_t early_stop_thres;

    // default transition behavior
    bool first_hessian_on;
    bool backstep_on;

    // set by initializer
    int mem_stride;
    bool output_regularizer;
    float* mem;
    double* rho;
    double* alpha;

    weight* regularizers;
    // the below needs to be included when resetting, in addition to preconditioner and derivative
    int lastj, origin;
    double loss_sum, previous_loss_sum;
    float step_size;
    double importance_weight_sum;
    double curvature;

    // first pass specification
    bool first_pass;
    bool gradient_pass;
    bool preconditioner_pass;
  };

const char* curv_message = "Zero or negative curvature detected.\n"
      "To increase curvature you can increase regularization or rescale features.\n"
      "It is also possible that you have reached numerical accuracy\n"
      "and further decrease in the objective cannot be reliably detected.\n";

void zero_derivative(vw& all)
{//set derivative to 0.
  uint32_t length = 1 << all.num_bits;
  size_t stride_shift = all.reg.stride_shift;
  weight* weights = all.reg.weight_vector;
  for(uint32_t i = 0; i < length; i++)
    weights[(i << stride_shift)+W_GT] = 0;
}

void zero_preconditioner(vw& all)
{//set preconditioner to 0.
  uint32_t length = 1 << all.num_bits;
  size_t stride_shift = all.reg.stride_shift;
  weight* weights = all.reg.weight_vector;
  for(uint32_t i = 0; i < length; i++)
    weights[(i << stride_shift)+W_COND] = 0;
}

void reset_state(vw& all, bfgs& b, bool zero)
{
  b.lastj = b.origin = 0;
  b.loss_sum = b.previous_loss_sum = 0.;
  b.importance_weight_sum = 0.;
  b.curvature = 0.;
  b.first_pass = true;
  b.gradient_pass = true;
  b.preconditioner_pass = true;
  if (zero)
    {
      zero_derivative(all);
      zero_preconditioner(all);
    }
}

// w[0] = weight
// w[1] = accumulated first derivative
// w[2] = step direction
// w[3] = preconditioner

bool test_example(example& ec)
{
  return ec.l.simple.label == FLT_MAX;
}

  float bfgs_predict(vw& all, example& ec)
  {
    ec.partial_prediction = GD::inline_predict(all,ec);
    return GD::finalize_prediction(all.sd, ec.partial_prediction);
  }

inline void add_grad(float& d, float f, float& fw)
{
  fw += d * f;
}

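// Gradient accumulation: inside foreach_feature the example's ft_offset is
// temporarily shifted by W_GT, so fw in add_grad aliases w[W_GT] for each
// feature and the per-feature gradient loss_grad * x_i is summed into the
// accumulated-derivative slot. The same offset trick targets w[W_COND] in
// update_preconditioner and w[W_DIR] in dot_with_direction below.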
float predict_and_gradient(vw& all, example &ec)
{
  float fp = bfgs_predict(all, ec);

  label_data& ld = ec.l.simple;
  all.set_minmax(all.sd, ld.label);

  float loss_grad = all.loss->first_derivative(all.sd, fp,ld.label)*ld.weight;

  ec.ft_offset += W_GT;
  GD::foreach_feature<float,add_grad>(all, ec, loss_grad);
  ec.ft_offset -= W_GT;

  return fp;
}

inline void add_precond(float& d, float f, float& fw)
{
  fw += d * f * f;
}

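// The preconditioner accumulates the diagonal of the (Gauss-Newton
// approximation of the) Hessian: for each feature, ell''(p, y) * x_i^2 is
// summed into w[W_COND]. finalize_preconditioner later inverts these
// entries so that w[W_COND] approximates the inverse diagonal Hessian.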
void update_preconditioner(vw& all, example& ec)
{
  label_data& ld = ec.l.simple;
  float curvature = all.loss->second_derivative(all.sd, ec.pred.scalar, ld.label) * ld.weight;

  ec.ft_offset += W_COND;
  GD::foreach_feature<float,add_precond>(all, ec, curvature);
  ec.ft_offset -= W_COND;
}


float dot_with_direction(vw& all, example& ec)
{
  ec.ft_offset+= W_DIR;
  float ret = GD::inline_predict(all, ec);
  ec.ft_offset-= W_DIR;

  return ret;
}

double regularizer_direction_magnitude(vw& all, bfgs& b, float regularizer)
{//compute the regularizer-weighted magnitude of the direction
  double ret = 0.;

  if (regularizer == 0.)
    return ret;

  uint32_t length = 1 << all.num_bits;
  size_t stride_shift = all.reg.stride_shift;
  weight* weights = all.reg.weight_vector;
  if (b.regularizers == NULL)
    for(uint32_t i = 0; i < length; i++)
      ret += regularizer*weights[(i << stride_shift)+W_DIR]*weights[(i << stride_shift)+W_DIR];
  else
    for(uint32_t i = 0; i < length; i++)
      ret += b.regularizers[2*i]*weights[(i << stride_shift)+W_DIR]*weights[(i << stride_shift)+W_DIR];

  return ret;
}

float direction_magnitude(vw& all)
{//compute direction magnitude
  double ret = 0.;
  uint32_t length = 1 << all.num_bits;
  size_t stride_shift = all.reg.stride_shift;
  weight* weights = all.reg.weight_vector;
  for(uint32_t i = 0; i < length; i++)
    ret += weights[(i << stride_shift)+W_DIR]*weights[(i << stride_shift)+W_DIR];

  return (float)ret;
}

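// First iteration: save x_0 and g_0 into mem, set the initial search
// direction to the preconditioned steepest-descent direction
// d_0 = -H_0 g_0 (with H_0 the diagonal preconditioner in w[W_COND]),
// and report g'g and g'Hg normalized by the importance weight sum.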
void bfgs_iter_start(vw& all, bfgs& b, float* mem, int& lastj, double importance_weight_sum, int&origin)
{
  uint32_t length = 1 << all.num_bits;
  size_t stride = 1 << all.reg.stride_shift;
  weight* w = all.reg.weight_vector;

  double g1_Hg1 = 0.;
  double g1_g1 = 0.;

  origin = 0;
  for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
    if (b.m>0)
      mem[(MEM_XT+origin)%b.mem_stride] = w[W_XT];
    mem[(MEM_GT+origin)%b.mem_stride] = w[W_GT];
    g1_Hg1 += w[W_GT] * w[W_GT] * w[W_COND];
    g1_g1 += w[W_GT] * w[W_GT];
    w[W_DIR] = -w[W_COND]*w[W_GT];
    w[W_GT] = 0;
  }
  lastj = 0;
  if (!all.quiet)
    fprintf(stderr, "%-10.5f\t%-10.5f\t%-10s\t%-10s\t%-10s\t",
	    g1_g1/(importance_weight_sum*importance_weight_sum),
	    g1_Hg1/importance_weight_sum, "", "", "");
}

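// Subsequent iterations. Two cases:
//
//  * m == 0: nonlinear conjugate gradient. The mixing coefficient
//    beta = g_k' H (g_k - g_{k-1}) / (g_{k-1}' H g_{k-1}) is a
//    preconditioned Polak-Ribiere-style formula, clipped at zero
//    (which restarts with steepest descent when beta would go negative).
//
//  * m > 0: L-BFGS direction via the standard two-loop recursion
//    (Nocedal 1980). In conventional notation, with s_j = x_{j+1} - x_j,
//    y_j = g_{j+1} - g_j and rho_j = 1 / (y_j' s_j):
//
//      q = g_k
//      for j = k-1 ... k-m:            // first loop
//        alpha_j = rho_j * s_j' q
//        q -= alpha_j * y_j
//      r = H_0 q                       // H_0 = gamma * diag preconditioner,
//                                      // gamma = y's / y'Hy
//      for j = k-m ... k-1:            // second loop
//        beta_j = rho_j * y_j' r
//        r += (alpha_j - beta_j) * s_j
//      d_k = -r
//
//    Here the loops are fused with sweeps over the weight array, so the
//    dot product for step j is accumulated while the update for the
//    previous step is applied; the final sign flip happens in the last
//    sweep.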
void bfgs_iter_middle(vw& all, bfgs& b, float* mem, double* rho, double* alpha, int& lastj, int &origin)
{
  uint32_t length = 1 << all.num_bits;
  size_t stride = 1 << all.reg.stride_shift;
  weight* w = all.reg.weight_vector;

  float* mem0 = mem;
  float* w0 = w;

  // implement conjugate gradient
  if (b.m==0) {
    double g_Hy = 0.;
    double g_Hg = 0.;
    double y = 0.;

    for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
      y = w[W_GT]-mem[(MEM_GT+origin)%b.mem_stride];
      g_Hy += w[W_GT] * w[W_COND] * y;
      g_Hg += mem[(MEM_GT+origin)%b.mem_stride] * w[W_COND] * mem[(MEM_GT+origin)%b.mem_stride];
    }

    float beta = (float) (g_Hy/g_Hg);

    if (beta<0.f || nanpattern(beta))
      beta = 0.f;

    mem = mem0;
    w = w0;
    for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
      mem[(MEM_GT+origin)%b.mem_stride] = w[W_GT];

      w[W_DIR] *= beta;
      w[W_DIR] -= w[W_COND]*w[W_GT];
      w[W_GT] = 0;
    }
    if (!all.quiet)
      fprintf(stderr, "%f\t", beta);
    return;
  }
  else {
    if (!all.quiet)
      fprintf(stderr, "%-10s\t","");
  }

  // implement bfgs
  double y_s = 0.;
  double y_Hy = 0.;
  double s_q = 0.;

  for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
    mem[(MEM_YT+origin)%b.mem_stride] = w[W_GT] - mem[(MEM_GT+origin)%b.mem_stride];
    mem[(MEM_ST+origin)%b.mem_stride] = w[W_XT] - mem[(MEM_XT+origin)%b.mem_stride];
    w[W_DIR] = w[W_GT];
    y_s += mem[(MEM_YT+origin)%b.mem_stride]*mem[(MEM_ST+origin)%b.mem_stride];
    y_Hy += mem[(MEM_YT+origin)%b.mem_stride]*mem[(MEM_YT+origin)%b.mem_stride]*w[W_COND];
    s_q += mem[(MEM_ST+origin)%b.mem_stride]*w[W_GT];
  }

  if (y_s <= 0. || y_Hy <= 0.)
    throw curv_ex;
  rho[0] = 1/y_s;

  float gamma = (float) (y_s/y_Hy);

  for (int j=0; j<lastj; j++) {
    alpha[j] = rho[j] * s_q;
    s_q = 0.;
    mem = mem0;
    w = w0;
    for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
      w[W_DIR] -= (float)alpha[j]*mem[(2*j+MEM_YT+origin)%b.mem_stride];
      s_q += mem[(2*j+2+MEM_ST+origin)%b.mem_stride]*w[W_DIR];
    }
  }

  alpha[lastj] = rho[lastj] * s_q;
  double y_r = 0.;
  mem = mem0;
  w = w0;
  for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
    w[W_DIR] -= (float)alpha[lastj]*mem[(2*lastj+MEM_YT+origin)%b.mem_stride];
    w[W_DIR] *= gamma*w[W_COND];
    y_r += mem[(2*lastj+MEM_YT+origin)%b.mem_stride]*w[W_DIR];
  }

  double coef_j;

  for (int j=lastj; j>0; j--) {
    coef_j = alpha[j] - rho[j] * y_r;
    y_r = 0.;
    mem = mem0;
    w = w0;
    for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
      w[W_DIR] += (float)coef_j*mem[(2*j+MEM_ST+origin)%b.mem_stride];
      y_r += mem[(2*j-2+MEM_YT+origin)%b.mem_stride]*w[W_DIR];
    }
  }


  coef_j = alpha[0] - rho[0] * y_r;
  mem = mem0;
  w = w0;
  for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
    w[W_DIR] = -w[W_DIR]-(float)coef_j*mem[(MEM_ST+origin)%b.mem_stride];
  }

  /*********************
   ** shift
   ********************/

  mem = mem0;
  w = w0;
  lastj = (lastj<b.m-1) ? lastj+1 : b.m-1;
  origin = (origin+b.mem_stride-2)%b.mem_stride;
  for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
    mem[(MEM_GT+origin)%b.mem_stride] = w[W_GT];
    mem[(MEM_XT+origin)%b.mem_stride] = w[W_XT];
    w[W_GT] = 0;
  }
  for (int j=lastj; j>0; j--)
    rho[j] = rho[j-1];
}

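// Line search bookkeeping. With d the search direction and step size t:
//   wolfe1 = (f(x + t d) - f(x)) / (t g_0'd)  -- the sufficient-decrease
//            (Armijo) ratio; the step is accepted in process_pass only if
//            it is at least wolfe1_bound.
//   wolfe2 = g_1'd / g_0'd                    -- the curvature ratio.
// The function reports both and returns half the current step size as the
// suggested backtracking step.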
double wolfe_eval(vw& all, bfgs& b, float* mem, double loss_sum, double previous_loss_sum, double step_size, double importance_weight_sum, int &origin, double& wolfe1) {
  uint32_t length = 1 << all.num_bits;
  size_t stride = 1 << all.reg.stride_shift;
  weight* w = all.reg.weight_vector;

  double g0_d = 0.;
  double g1_d = 0.;
  double g1_Hg1 = 0.;
  double g1_g1 = 0.;

  for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
    g0_d += mem[(MEM_GT+origin)%b.mem_stride] * w[W_DIR];
    g1_d += w[W_GT] * w[W_DIR];
    g1_Hg1 += w[W_GT] * w[W_GT] * w[W_COND];
    g1_g1 += w[W_GT] * w[W_GT];
  }

  wolfe1 = (loss_sum-previous_loss_sum)/(step_size*g0_d);
  double wolfe2 = g1_d/g0_d;
  // double new_step_cross = (loss_sum-previous_loss_sum-g1_d*step)/(g0_d-g1_d);

  if (!all.quiet)
    fprintf(stderr, "%-10.5f\t%-10.5f\t%s%-10f\t%-10f\t", g1_g1/(importance_weight_sum*importance_weight_sum), g1_Hg1/importance_weight_sum, " ", wolfe1, wolfe2);
  return 0.5*step_size;
}


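// L2 regularization: adds lambda * w to the accumulated gradient and
// returns the regularization term 0.5 * lambda * ||w||^2 to be added to
// the loss. With per-feature regularizers, lambda_i is b.regularizers[2*i]
// and the penalty is centered at the prior mean b.regularizers[2*i+1].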
double add_regularization(vw& all, bfgs& b, float regularization)
{//add the regularization gradient to w[W_GT] and return the regularization loss
  double ret = 0.;
  uint32_t length = 1 << all.num_bits;
  size_t stride_shift = all.reg.stride_shift;
  weight* weights = all.reg.weight_vector;
  if (b.regularizers == NULL)
    {
      for(uint32_t i = 0; i < length; i++) {
	weights[(i << stride_shift)+W_GT] += regularization*weights[i << stride_shift];
	ret += 0.5*regularization*weights[i << stride_shift]*weights[i << stride_shift];
      }
    }
  else
    {
      for(uint32_t i = 0; i < length; i++) {
	weight delta_weight = weights[i << stride_shift] - b.regularizers[2*i+1];
	weights[(i << stride_shift)+W_GT] += b.regularizers[2*i]*delta_weight;
	ret += 0.5*b.regularizers[2*i]*delta_weight*delta_weight;
      }
    }

  return ret;
}

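// Invert the accumulated diagonal Hessian so that w[W_COND] holds the
// diagonal preconditioner H_0 = 1 / (hessian_ii + lambda_i). Entries whose
// inverse would exceed max_precond_ratio / max_hessian (including infinite
// ones arising from zero curvature) are clamped to that cap.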
void finalize_preconditioner(vw& all, bfgs& b, float regularization)
{
  uint32_t length = 1 << all.num_bits;
  size_t stride = 1 << all.reg.stride_shift;
  weight* weights = all.reg.weight_vector;
  float max_hessian = 0.f;

  if (b.regularizers == NULL)
    for(uint32_t i = 0; i < length; i++) {
      weights[stride*i+W_COND] += regularization;
      if (weights[stride*i+W_COND] > max_hessian)
	max_hessian = weights[stride*i+W_COND];
      if (weights[stride*i+W_COND] > 0)
	weights[stride*i+W_COND] = 1.f / weights[stride*i+W_COND];
    }
  else
    for(uint32_t i = 0; i < length; i++) {
      weights[stride*i+W_COND] += b.regularizers[2*i];
      if (weights[stride*i+W_COND] > max_hessian)
	max_hessian = weights[stride*i+W_COND];
      if (weights[stride*i+W_COND] > 0)
	weights[stride*i+W_COND] = 1.f / weights[stride*i+W_COND];
    }

  float max_precond = (max_hessian==0.f) ? 0.f : max_precond_ratio / max_hessian;
  weights = all.reg.weight_vector;
  for(uint32_t i = 0; i < length; i++) {
    if (infpattern(weights[stride*i+W_COND]) || weights[stride*i+W_COND]>max_precond)
      weights[stride*i+W_COND] = max_precond;
  }
}

void preconditioner_to_regularizer(vw& all, bfgs& b, float regularization)
{
  uint32_t length = 1 << all.num_bits;
  size_t stride = 1 << all.reg.stride_shift;
  weight* weights = all.reg.weight_vector;
  if (b.regularizers == NULL)
    {
      b.regularizers = calloc_or_die<weight>(2*length);

      if (b.regularizers == NULL)
	{
	  cerr << all.program_name << ": Failed to allocate weight array: try decreasing -b <bits>" << endl;
	  throw exception();
	}
      for(uint32_t i = 0; i < length; i++)
	b.regularizers[2*i] = weights[stride*i+W_COND] + regularization;
    }
  else
    for(uint32_t i = 0; i < length; i++)
      b.regularizers[2*i] = weights[stride*i+W_COND] + b.regularizers[2*i];
  for(uint32_t i = 0; i < length; i++)
      b.regularizers[2*i+1] = weights[stride*i];
}

void zero_state(vw& all)
{
  uint32_t length = 1 << all.num_bits;
  size_t stride = 1 << all.reg.stride_shift;
  weight* weights = all.reg.weight_vector;
  for(uint32_t i = 0; i < length; i++)
    {
      weights[stride*i+W_GT] = 0;
      weights[stride*i+W_DIR] = 0;
      weights[stride*i+W_COND] = 0;
    }
}

double derivative_in_direction(vw& all, bfgs& b, float* mem, int &origin)
{
  double ret = 0.;
  uint32_t length = 1 << all.num_bits;
  size_t stride = 1 << all.reg.stride_shift;
  weight* w = all.reg.weight_vector;

  for(uint32_t i = 0; i < length; i++, w+=stride, mem+=b.mem_stride)
    ret += mem[(MEM_GT+origin)%b.mem_stride]*w[W_DIR];
  return ret;
}

void update_weight(vw& all, float step_size, size_t current_pass)
{
  uint32_t length = 1 << all.num_bits;
  size_t stride = 1 << all.reg.stride_shift;
  weight* w = all.reg.weight_vector;

  for(uint32_t i = 0; i < length; i++, w+=stride)
    w[W_XT] += step_size * w[W_DIR];
}

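// One call per completed pass over the data. The pass sequence alternates
// between gradient passes and (if hessian_on) curvature passes:
//   A) first pass: finalize the preconditioner and take the initial step;
//   B) gradient pass finished: evaluate the Wolfe ratios, then either
//      back-step (B1) or compute the next direction and step (B2);
//   C) curvature pass finished: set the step size from the curvature
//      along the direction and apply the update.
// Returns LEARN_OK, LEARN_CURV (curvature trouble) or LEARN_CONV
// (converged).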
int process_pass(vw& all, bfgs& b) {
  int status = LEARN_OK;

  /********************************************************************/
  /* A) FIRST PASS FINISHED: INITIALIZE FIRST LINE SEARCH *************/
  /********************************************************************/
    if (b.first_pass) {
      if(all.span_server != "")
	{
	  accumulate(all, all.span_server, all.reg, W_COND); //Accumulate preconditioner
	  float temp = (float)b.importance_weight_sum;
	  b.importance_weight_sum = accumulate_scalar(all, all.span_server, temp);
	}
      finalize_preconditioner(all, b, all.l2_lambda);
      if(all.span_server != "") {
	float temp = (float)b.loss_sum;
	b.loss_sum = accumulate_scalar(all, all.span_server, temp);  //Accumulate loss_sums
	accumulate(all, all.span_server, all.reg, 1); //Accumulate gradients from all nodes
      }
      if (all.l2_lambda > 0.)
	b.loss_sum += add_regularization(all, b, all.l2_lambda);
      if (!all.quiet)
	fprintf(stderr, "%2lu %-10.5f\t", (long unsigned int)b.current_pass+1, b.loss_sum / b.importance_weight_sum);

      b.previous_loss_sum = b.loss_sum;
      b.loss_sum = 0.;
      b.example_number = 0;
      b.curvature = 0;
      bfgs_iter_start(all, b, b.mem, b.lastj, b.importance_weight_sum, b.origin);
      if (b.first_hessian_on) {
	b.gradient_pass = false;//now start computing curvature
      }
      else {
	b.step_size = 0.5;
	float d_mag = direction_magnitude(all);
	ftime(&b.t_end_global);
	b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm));
	if (!all.quiet)
	  fprintf(stderr, "%-10s\t%-10.5f\t%-10.5f\n", "", d_mag, b.step_size);
	b.predictions.erase();
	update_weight(all, b.step_size, b.current_pass);
      }
    }
    else
  /********************************************************************/
  /* B) GRADIENT CALCULATED *******************************************/
  /********************************************************************/
	      if (b.gradient_pass) // We just finished computing all gradients
		{
		  if(all.span_server != "") {
		    float t = (float)b.loss_sum;
		    b.loss_sum = accumulate_scalar(all, all.span_server, t);  //Accumulate loss_sums
		    accumulate(all, all.span_server, all.reg, 1); //Accumulate gradients from all nodes
		  }
		  if (all.l2_lambda > 0.)
		    b.loss_sum += add_regularization(all, b, all.l2_lambda);
		  if (!all.quiet){
                    if(!all.holdout_set_off && b.current_pass >= 1){
                      if(all.sd->holdout_sum_loss_since_last_pass == 0. && all.sd->weighted_holdout_examples_since_last_pass == 0.){
                        fprintf(stderr, "%2lu ", (long unsigned int)b.current_pass+1);
                        fprintf(stderr, "h unknown    ");
                      }
                      else
                        fprintf(stderr, "%2lu h%-10.5f\t", (long unsigned int)b.current_pass+1, all.sd->holdout_sum_loss_since_last_pass / all.sd->weighted_holdout_examples_since_last_pass);
                    }
                    else
                      fprintf(stderr, "%2lu %-10.5f\t", (long unsigned int)b.current_pass+1, b.loss_sum / b.importance_weight_sum);
                  }
		  double wolfe1;
		  double new_step = wolfe_eval(all, b, b.mem, b.loss_sum, b.previous_loss_sum, b.step_size, b.importance_weight_sum, b.origin, wolfe1);

  /********************************************************************/
  /* B0) DERIVATIVE ZERO: MINIMUM FOUND *******************************/
  /********************************************************************/
		  if (nanpattern((float)wolfe1))
		    {
		      fprintf(stderr, "\n");
		      fprintf(stdout, "Derivative 0 detected.\n");
		      b.step_size=0.0;
		      status = LEARN_CONV;
		    }
  /********************************************************************/
  /* B1) LINE SEARCH FAILED *******************************************/
  /********************************************************************/
		  else if (b.backstep_on && (wolfe1<b.wolfe1_bound || b.loss_sum > b.previous_loss_sum))
		    {// curvature violated, or we stepped too far last time: step back
		      ftime(&b.t_end_global);
		      b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm));
		      float ratio = (b.step_size==0.f) ? 0.f : (float)new_step/(float)b.step_size;
		      if (!all.quiet)
			fprintf(stderr, "%-10s\t%-10s\t(revise x %.1f)\t%-10.5f\n",
				"","",ratio,
				new_step);
		      b.predictions.erase();
		      update_weight(all, (float)(-b.step_size+new_step), b.current_pass);
		      b.step_size = (float)new_step;
		      zero_derivative(all);
		      b.loss_sum = 0.;
		    }

  /********************************************************************/
  /* B2) LINE SEARCH SUCCESSFUL OR DISABLED          ******************/
  /*     DETERMINE NEXT SEARCH DIRECTION             ******************/
  /********************************************************************/
		  else {
		      double rel_decrease = (b.previous_loss_sum-b.loss_sum)/b.previous_loss_sum;
		      if (!nanpattern((float)rel_decrease) && b.backstep_on && fabs(rel_decrease)<b.rel_threshold) {
			fprintf(stdout, "\nTermination condition reached in pass %ld: decrease in loss less than %.3f%%.\n"
				"If you want to optimize further, decrease termination threshold.\n", (long int)b.current_pass+1, b.rel_threshold*100.0);
			status = LEARN_CONV;
		      }
		      b.previous_loss_sum = b.loss_sum;
		      b.loss_sum = 0.;
		      b.example_number = 0;
		      b.curvature = 0;
		      b.step_size = 1.0;

		      try {
			bfgs_iter_middle(all, b, b.mem, b.rho, b.alpha, b.lastj, b.origin);
		      }
		      catch (curv_exception e) {
			fprintf(stdout, "In bfgs_iter_middle: %s", curv_message);
			b.step_size=0.0;
			status = LEARN_CURV;
		      }

		      if (all.hessian_on) {
			b.gradient_pass = false;//now start computing curvature
		      }
		      else {
			float d_mag = direction_magnitude(all);
			ftime(&b.t_end_global);
			b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm));
			if (!all.quiet)
			  fprintf(stderr, "%-10s\t%-10.5f\t%-10.5f\n", "", d_mag, b.step_size);
			b.predictions.erase();
			update_weight(all, b.step_size, b.current_pass);
		      }
		    }
		}

  /********************************************************************/
  /* C) NOT FIRST PASS, CURVATURE CALCULATED **************************/
  /********************************************************************/
	      else // just finished all second gradients
		{
		  if(all.span_server != "") {
		    float t = (float)b.curvature;
		    b.curvature = accumulate_scalar(all, all.span_server, t);  //Accumulate curvatures
		  }
		  if (all.l2_lambda > 0.)
		    b.curvature += regularizer_direction_magnitude(all, b, all.l2_lambda);
		  float dd = (float)derivative_in_direction(all, b, b.mem, b.origin);
		  if (b.curvature == 0. && dd != 0.)
		    {
		      fprintf(stdout, "%s", curv_message);
		      b.step_size=0.0;
		      status = LEARN_CURV;
		    }
		  else if ( dd == 0.)
		    {
		      fprintf(stdout, "Derivative 0 detected.\n");
		      b.step_size=0.0;
		      status = LEARN_CONV;
		    }
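		  // Otherwise take the exact minimizer of the quadratic
		  // model along the direction d: step = -g'd / (d'H d),
		  // where d'H d is the curvature accumulated during the
		  // curvature pass.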
		  else
		    b.step_size = - dd/(float)b.curvature;

		  float d_mag = direction_magnitude(all);

		  b.predictions.erase();
		  update_weight(all, b.step_size, b.current_pass);
		  ftime(&b.t_end_global);
		  b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm));
		  if (!all.quiet)
		    fprintf(stderr, "%-10.5f\t%-10.5f\t%-10.5f\n", b.curvature / b.importance_weight_sum, d_mag, b.step_size);
		  b.gradient_pass = true;
		}//now start computing derivatives.
    b.current_pass++;
    b.first_pass = false;
    b.preconditioner_pass = false;

    if (b.output_regularizer)//need to accumulate and place the regularizer.
      {
	if(all.span_server != "")
	  accumulate(all, all.span_server, all.reg, W_COND); //Accumulate preconditioner
	//preconditioner_to_regularizer(all, b, all.l2_lambda);
      }
    ftime(&b.t_end_global);
    b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm));

    if (all.save_per_pass)
      save_predictor(all, all.final_regressor_name, b.current_pass);
    return status;
}

void process_example(vw& all, bfgs& b, example& ec)
 {
  label_data& ld = ec.l.simple;
  if (b.first_pass)
    b.importance_weight_sum += ld.weight;

  /********************************************************************/
  /* I) GRADIENT CALCULATION ******************************************/
  /********************************************************************/
  if (b.gradient_pass)
    {
      ec.pred.scalar = predict_and_gradient(all, ec);//w[0] & w[1]
      ec.loss = all.loss->getLoss(all.sd, ec.pred.scalar, ld.label) * ld.weight;
      b.loss_sum += ec.loss;
      b.predictions.push_back(ec.pred.scalar);
    }
  /********************************************************************/
  /* II) CURVATURE CALCULATION ****************************************/
  /********************************************************************/
  else //computing curvature
    {
      float d_dot_x = dot_with_direction(all, ec);//w[2]
      if (b.example_number >= b.predictions.size())//Make things safe in case example source is strange.
	b.example_number = b.predictions.size()-1;
      ec.pred.scalar = b.predictions[b.example_number];
      ec.partial_prediction = b.predictions[b.example_number];
      ec.loss = all.loss->getLoss(all.sd, ec.pred.scalar, ld.label) * ld.weight;
      float sd = all.loss->second_derivative(all.sd, b.predictions[b.example_number++],ld.label);
      b.curvature += d_dot_x*d_dot_x*sd*ld.weight;
    }
  ec.updated_prediction = ec.pred.scalar;

  if (b.preconditioner_pass)
    update_preconditioner(all, ec);//w[3]
 }

void end_pass(bfgs& b)
{
  vw* all = b.all;

  if (b.current_pass <= b.final_pass)
  {
       if(b.current_pass < b.final_pass)
       {
          int status = process_pass(*all, b);

          //reaching the max number of passes regardless of convergence
          if(b.final_pass == b.current_pass)
          {
             cerr<<"Maximum number of passes reached. ";
             if(!b.output_regularizer)
                cerr<<"If you want to optimize further, increase the number of passes\n";
             if(b.output_regularizer)
             {
               cerr<<"\nRegular model file has been created. ";
               cerr<<"Output feature regularizer file is created only when the convergence is reached. Try increasing the number of passes for convergence\n";
               b.output_regularizer = false;
             }

          }

          //attain convergence before reaching max iterations
	   if (status != LEARN_OK && b.final_pass > b.current_pass) {
	      b.final_pass = b.current_pass;
	   }

	   if (b.output_regularizer && b.final_pass == b.current_pass) {
	     zero_preconditioner(*all);
	     b.preconditioner_pass = true;
	   }

	   if(!all->holdout_set_off)
	   {
	     if(summarize_holdout_set(*all, b.no_win_counter))
               finalize_regressor(*all, all->final_regressor_name);
	     if(b.early_stop_thres == b.no_win_counter)
	     {
               set_done(*all);
               cerr<<"Early termination reached w.r.t. holdout set error";
             }
	   }
	   if (b.final_pass == b.current_pass) {
	     finalize_regressor(*all, all->final_regressor_name);
	     set_done(*all);
	   }

       }else{//reaching convergence in the previous pass
        if(b.output_regularizer)
           preconditioner_to_regularizer(*all, b, (*all).l2_lambda);
        b.current_pass ++;
      }

  }
}

// placeholder
void predict(bfgs& b, base_learner& base, example& ec)
{
  vw* all = b.all;
  ec.pred.scalar = bfgs_predict(*all,ec);
}

void learn(bfgs& b, base_learner& base, example& ec)
{
  vw* all = b.all;
  assert(ec.in_use);

  if (b.current_pass <= b.final_pass)
    {
      if (test_example(ec))
	predict(b, base, ec);
      else
	process_example(*all, b, ec);
    }
}

void finish(bfgs& b)
{
  b.predictions.delete_v();
  free(b.mem);
  free(b.rho);
  free(b.alpha);
}

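// The per-feature regularizer file is sparse: it stores (index, value)
// pairs for the nonzero entries of b.regularizers, where even indices
// hold the per-feature lambda and odd indices hold the prior mean.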
void save_load_regularizer(vw& all, bfgs& b, io_buf& model_file, bool read, bool text)
{

  char buff[512];
  int c = 0;
  uint32_t stride = 1 << all.reg.stride_shift;
  uint32_t length = 2*(1 << all.num_bits);
  uint32_t i = 0;
  size_t brw = 1;
  do
    {
      brw = 1;
      weight* v;
      if (read)
	{
	  c++;
	  brw = bin_read_fixed(model_file, (char*)&i, sizeof(i),"");
	  if (brw > 0)
	    {
	      assert (i< length);
	      v = &(b.regularizers[i]);
	      if (brw > 0)
		brw += bin_read_fixed(model_file, (char*)v, sizeof(*v), "");
	      if (i%2 == 1) // This is the prior mean
		all.reg.weight_vector[(i/2*stride)] = *v;
	    }
	}
      else // write binary or text
	{
	  v = &(b.regularizers[i]);
	  if (*v != 0.)
	    {
	      c++;
	      int text_len = sprintf(buff, "%d", i);
	      brw = bin_text_write_fixed(model_file,(char *)&i, sizeof (i),
					 buff, text_len, text);

	      text_len = sprintf(buff, ":%f\n", *v);
	      brw+= bin_text_write_fixed(model_file,(char *)v, sizeof (*v),
					 buff, text_len, text);
	    }
	}
      if (!read)
	i++;
    }
  while ((!read && i < length) || (read && brw >0));
}


void save_load(bfgs& b, io_buf& model_file, bool read, bool text)
{
  vw* all = b.all;

  uint32_t length = 1 << all->num_bits;

  if (read)
    {
      initialize_regressor(*all);
      if (all->per_feature_regularizer_input != "")
	{
	  b.regularizers = calloc_or_die<weight>(2*length);
	  if (b.regularizers == NULL)
	    {
	      cerr << all->program_name << ": Failed to allocate regularizers array: try decreasing -b <bits>" << endl;
	      throw exception();
	    }
	}
      int m = b.m;

      b.mem_stride = (m==0) ? CG_EXTRA : 2*m;
      b.mem = (float*) malloc(sizeof(float)*all->length()*(b.mem_stride));
      b.rho = (double*) malloc(sizeof(double)*m);
      b.alpha = (double*) malloc(sizeof(double)*m);

      if (!all->quiet)
	{
	  fprintf(stderr, "m = %d\nAllocated %luM for weights and mem\n", m, (long unsigned int)all->length()*(sizeof(float)*(b.mem_stride)+(sizeof(weight) << all->reg.stride_shift)) >> 20);
	}

      b.net_time = 0.0;
      ftime(&b.t_start_global);

      if (!all->quiet)
	{
	  const char * header_fmt = "%2s %-10s\t%-10s\t%-10s\t %-10s\t%-10s\t%-10s\t%-10s\t%-10s\t%-10s\n";
	  fprintf(stderr, header_fmt,
		  "##", "avg. loss", "der. mag.", "d. m. cond.", "wolfe1", "wolfe2", "mix fraction", "curvature", "dir. magnitude", "step size");
	  cerr.precision(5);
	}

      if (b.regularizers != NULL)
	all->l2_lambda = 1; // To make sure we are adding the regularization
      b.output_regularizer =  (all->per_feature_regularizer_output != "" || all->per_feature_regularizer_text != "");
      reset_state(*all, b, false);
    }

  //bool reg_vector = b.output_regularizer || all->per_feature_regularizer_input.length() > 0;
  bool reg_vector = (b.output_regularizer && !read) || (all->per_feature_regularizer_input.length() > 0 && read);

  if (model_file.files.size() > 0)
    {
      char buff[512];
      uint32_t text_len = sprintf(buff, ":%d\n", reg_vector);
      bin_text_read_write_fixed(model_file,(char *)&reg_vector, sizeof (reg_vector),
				"", read,
				buff, text_len, text);

      if (reg_vector)
	save_load_regularizer(*all, b, model_file, read, text);
      else
	GD::save_load_regressor(*all, model_file, read, text);
    }
}

  void init_driver(bfgs& b)
  {
    b.backstep_on = true;
  }

base_learner* bfgs_setup(vw& all)
{
  if (missing_option(all, false, "bfgs", "use bfgs optimization") &&
      missing_option(all, false, "conjugate_gradient", "use conjugate gradient based optimization"))
    return NULL;
  new_options(all, "LBFGS options")
    ("hessian_on", "use second derivative in line search")
    ("mem", po::value<uint32_t>()->default_value(15), "memory in bfgs")
    ("termination", po::value<float>()->default_value(0.001f),"Termination threshold");
  add_options(all);

  po::variables_map& vm = all.vm;
  bfgs& b = calloc_or_die<bfgs>();
  b.all = &all;
  b.m = vm["mem"].as<uint32_t>();
  b.rel_threshold = vm["termination"].as<float>();
  b.wolfe1_bound = 0.01;
  b.first_hessian_on=true;
  b.first_pass = true;
  b.gradient_pass = true;
  b.preconditioner_pass = true;
  b.backstep_on = false;
  b.final_pass=all.numpasses;
  b.no_win_counter = 0;
  b.early_stop_thres = 3;

  if(!all.holdout_set_off)
  {
    all.sd->holdout_best_loss = FLT_MAX;
    if(vm.count("early_terminate"))
      b.early_stop_thres = vm["early_terminate"].as< size_t>();
  }

  if (vm.count("hessian_on") || b.m==0) {
    all.hessian_on = true;
  }
  if (!all.quiet) {
    if (b.m>0)
      cerr << "enabling BFGS based optimization ";
    else
      cerr << "enabling conjugate gradient optimization via BFGS ";
    if (all.hessian_on)
      cerr << "with curvature calculation" << endl;
    else
      cerr << "**without** curvature calculation" << endl;
  }
  if (all.numpasses < 2)
    {
      cout << "you must make at least 2 passes to use BFGS" << endl;
      throw exception();
    }

  all.bfgs = true;
  all.reg.stride_shift = 2;
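  // stride_shift = 2 gives 4 floats per feature index, matching the
  // W_XT/W_GT/W_DIR/W_COND layout documented at the top of the file.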

  learner<bfgs>& l = init_learner(&b, learn, 1 << all.reg.stride_shift);
  l.set_predict(predict);
  l.set_save_load(save_load);
  l.set_init_driver(init_driver);
  l.set_end_pass(end_pass);
  l.set_finish(finish);

  return make_base(l);
}