1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5 */
6 /*
7 The algorithm here is generally based on Nocedal 1980, Liu and Nocedal 1989.
8 Implementation by Miro Dudik.
9 */
10 #include <fstream>
11 #include <float.h>
12 #include <exception>
13 #ifndef _WIN32
14 #include <netdb.h>
15 #endif
16 #include <string.h>
17 #include <stdio.h>
18 #include <assert.h>
19 #include <sys/timeb.h>
20 #include "accumulate.h"
21 #include "gd.h"
22
23 using namespace std;
24 using namespace LEARNER;
25
// Floats of mem per feature when m == 0: conjugate gradient only keeps the
// previous gradient (see save_load: mem_stride = (m==0) ? CG_EXTRA : 2*m).
#define CG_EXTRA 1

// Slot offsets inside a feature's mem block, always taken relative to
// b.origin and modulo b.mem_stride. MEM_GT/MEM_XT (previous gradient and
// previous weight) share their slots with MEM_YT/MEM_ST: bfgs_iter_middle
// overwrites them with y = g_new - g_old and s = x_new - x_old.
#define MEM_GT 0
#define MEM_XT 1
#define MEM_YT 0
#define MEM_ST 1

// Slot offsets inside a feature's weight block (stride 1 << all.reg.stride_shift).
#define W_XT 0   // weight
#define W_GT 1   // accumulated first derivative
#define W_DIR 2  // search direction
#define W_COND 3 // preconditioner (inverted in finalize_preconditioner)

// Status codes returned by process_pass.
#define LEARN_OK 0   // keep optimizing
#define LEARN_CURV 1 // zero/negative curvature detected
#define LEARN_CONV 2 // converged (zero derivative or small relative loss decrease)
41
// Signals that the curvature condition failed (y's <= 0 or y'Hy <= 0) inside
// bfgs_iter_middle; process_pass catches it and reports curv_message.
class curv_exception : public std::exception
{
};

// The single instance that bfgs_iter_middle throws.
curv_exception curv_ex;
43
/********************************************************************/
/* mem & w definition ***********************************************/
/********************************************************************/
// mem holds b.mem_stride floats per feature and is used as a ring buffer of
// correction pairs, rotated by b.origin (j = 0 is the newest pair):
//   mem[(2*j + MEM_YT + origin) % mem_stride] = y_j (gradient difference)
//   mem[(2*j + MEM_ST + origin) % mem_stride] = s_j (weight difference)
// Between iterations the head slots instead hold the previous gradient
// (MEM_GT) and previous weight (MEM_XT).
//
// w[0] = weight
// w[1] = accumulated first derivative
// w[2] = step direction
// w[3] = preconditioner

// finalize_preconditioner caps every preconditioner entry at
// max_precond_ratio / max_hessian (max_hessian = largest regularized
// second-derivative sum before inversion).
const float max_precond_ratio = 10000.f;
56
// State of the BFGS / conjugate-gradient learner. Allocated zero-filled by
// calloc_or_die in bfgs_setup.
struct bfgs {
  vw* all;//prediction, regressor
  int m; // number of (y, s) pairs kept ("--mem"); 0 selects conjugate gradient
  float rel_threshold; // termination threshold on the relative loss decrease

  double wolfe1_bound; // line search is revised when wolfe1 falls below this

  size_t final_pass; // index of the last pass to run (may shrink on convergence)
  struct timeb t_start, t_end; // NOTE(review): not referenced in this file
  double net_comm_time;        // NOTE(review): not referenced in this file

  struct timeb t_start_global, t_end_global;
  double net_time; // milliseconds since t_start_global, refreshed each pass

  v_array<float> predictions; // predictions from the gradient pass, reused by the curvature pass
  size_t example_number;      // read position in predictions during the curvature pass
  size_t current_pass;
  size_t no_win_counter;   // holdout bookkeeping, updated by summarize_holdout_set
  size_t early_stop_thres; // stop when no_win_counter reaches this

  // default transition behavior
  bool first_hessian_on; // after the first pass, compute curvature instead of a fixed step
  bool backstep_on;      // allow revising (halving) a failed step

  // set by initializer
  int mem_stride;          // floats of mem per feature
  bool output_regularizer; // write per-feature regularizers instead of the regressor
  float* mem;              // ring buffer described above, mem_stride floats per feature
  double* rho;             // rho[j] = 1 / (y_j' s_j)
  double* alpha;           // two-loop recursion coefficients

  weight* regularizers; // optional, 2 per feature: [2*i] strength, [2*i+1] prior mean
  // the below needs to be included when resetting, in addition to preconditioner and derivative
  int lastj, origin; // newest stored pair index limit and ring-buffer rotation
  double loss_sum, previous_loss_sum;
  float step_size;
  double importance_weight_sum;
  double curvature; // accumulated d'Hd during a curvature pass

  // first pass specification
  bool first_pass;
  bool gradient_pass;       // false while computing curvature along the direction
  bool preconditioner_pass; // accumulate the diagonal preconditioner in W_COND
};
101
// Explanation printed (to stdout) when non-positive curvature stops the optimizer.
const char* curv_message =
  "Zero or negative curvature detected.\n"
  "To increase curvature you can increase regularization "
  "or rescale features.\n"
  "It is also possible that you have reached numerical accuracy\n"
  "and further decrease in the objective "
  "cannot be reliably detected.\n";
106
zero_derivative(vw & all)107 void zero_derivative(vw& all)
108 {//set derivative to 0.
109 uint32_t length = 1 << all.num_bits;
110 size_t stride_shift = all.reg.stride_shift;
111 weight* weights = all.reg.weight_vector;
112 for(uint32_t i = 0; i < length; i++)
113 weights[(i << stride_shift) +W_GT] = 0;
114 }
115
zero_preconditioner(vw & all)116 void zero_preconditioner(vw& all)
117 {//set derivative to 0.
118 uint32_t length = 1 << all.num_bits;
119 size_t stride_shift = all.reg.stride_shift;
120 weight* weights = all.reg.weight_vector;
121 for(uint32_t i = 0; i < length; i++)
122 weights[(i << stride_shift)+W_COND] = 0;
123 }
124
reset_state(vw & all,bfgs & b,bool zero)125 void reset_state(vw& all, bfgs& b, bool zero)
126 {
127 b.lastj = b.origin = 0;
128 b.loss_sum = b.previous_loss_sum = 0.;
129 b.importance_weight_sum = 0.;
130 b.curvature = 0.;
131 b.first_pass = true;
132 b.gradient_pass = true;
133 b.preconditioner_pass = true;
134 if (zero)
135 {
136 zero_derivative(all);
137 zero_preconditioner(all);
138 }
139 }
140
141 // w[0] = weight
142 // w[1] = accumulated first derivative
143 // w[2] = step direction
144 // w[3] = preconditioner
145
test_example(example & ec)146 bool test_example(example& ec)
147 {
148 return ec.l.simple.label == FLT_MAX;
149 }
150
bfgs_predict(vw & all,example & ec)151 float bfgs_predict(vw& all, example& ec)
152 {
153 ec.partial_prediction = GD::inline_predict(all,ec);
154 return GD::finalize_prediction(all.sd, ec.partial_prediction);
155 }
156
// foreach_feature callback: adds d * f (scaled loss gradient times feature value)
// into the addressed weight slot fw.
inline void add_grad(float& d, float f, float& fw)
{
  const float contribution = d * f;
  fw += contribution;
}
161
predict_and_gradient(vw & all,example & ec)162 float predict_and_gradient(vw& all, example &ec)
163 {
164 float fp = bfgs_predict(all, ec);
165
166 label_data& ld = ec.l.simple;
167 all.set_minmax(all.sd, ld.label);
168
169 float loss_grad = all.loss->first_derivative(all.sd, fp,ld.label)*ld.weight;
170
171 ec.ft_offset += W_GT;
172 GD::foreach_feature<float,add_grad>(all, ec, loss_grad);
173 ec.ft_offset -= W_GT;
174
175 return fp;
176 }
177
// foreach_feature callback: adds d * f * f (scaled curvature times squared
// feature value) into the addressed weight slot fw.
inline void add_precond(float& d, float f, float& fw)
{
  const float scaled = d * f;
  fw += scaled * f;
}
182
update_preconditioner(vw & all,example & ec)183 void update_preconditioner(vw& all, example& ec)
184 {
185 label_data& ld = ec.l.simple;
186 float curvature = all.loss->second_derivative(all.sd, ec.pred.scalar, ld.label) * ld.weight;
187
188 ec.ft_offset += W_COND;
189 GD::foreach_feature<float,add_precond>(all, ec, curvature);
190 ec.ft_offset -= W_COND;
191 }
192
193
dot_with_direction(vw & all,example & ec)194 float dot_with_direction(vw& all, example& ec)
195 {
196 ec.ft_offset+= W_DIR;
197 float ret = GD::inline_predict(all, ec);
198 ec.ft_offset-= W_DIR;
199
200 return ret;
201 }
202
regularizer_direction_magnitude(vw & all,bfgs & b,float regularizer)203 double regularizer_direction_magnitude(vw& all, bfgs& b, float regularizer)
204 {//compute direction magnitude
205 double ret = 0.;
206
207 if (regularizer == 0.)
208 return ret;
209
210 uint32_t length = 1 << all.num_bits;
211 size_t stride_shift = all.reg.stride_shift;
212 weight* weights = all.reg.weight_vector;
213 if (b.regularizers == NULL)
214 for(uint32_t i = 0; i < length; i++)
215 ret += regularizer*weights[(i << stride_shift)+W_DIR]*weights[(i << stride_shift)+W_DIR];
216 else
217 for(uint32_t i = 0; i < length; i++)
218 ret += b.regularizers[2*i]*weights[(i << stride_shift)+W_DIR]*weights[(i << stride_shift)+W_DIR];
219
220 return ret;
221 }
222
direction_magnitude(vw & all)223 float direction_magnitude(vw& all)
224 {//compute direction magnitude
225 double ret = 0.;
226 uint32_t length = 1 << all.num_bits;
227 size_t stride_shift = all.reg.stride_shift;
228 weight* weights = all.reg.weight_vector;
229 for(uint32_t i = 0; i < length; i++)
230 ret += weights[(i << stride_shift)+W_DIR]*weights[(i << stride_shift)+W_DIR];
231
232 return (float)ret;
233 }
234
// Starts the optimization after the first pass. For every feature it:
//  - saves the current weight (only when m > 0; with m == 0 mem_stride is 1,
//    so the MEM_XT slot would alias MEM_GT) and the gradient into mem;
//  - sets the first search direction to the preconditioned negative gradient,
//    d = -W_COND * W_GT;
//  - clears W_GT so the next pass accumulates a fresh gradient.
// Resets origin and lastj to 0. Unless quiet, prints g'g / iws^2 and
// g'Hg / iws, where H is the diagonal preconditioner and iws is
// importance_weight_sum.
void bfgs_iter_start(vw& all, bfgs& b, float* mem, int& lastj, double importance_weight_sum, int&origin)
{
  uint32_t length = 1 << all.num_bits;
  size_t stride = 1 << all.reg.stride_shift;
  weight* w = all.reg.weight_vector;

  double g1_Hg1 = 0.;
  double g1_g1 = 0.;

  origin = 0;
  // mem and w are walked in lockstep: one mem block and one weight block per feature.
  for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
    if (b.m>0)
      mem[(MEM_XT+origin)%b.mem_stride] = w[W_XT];
    mem[(MEM_GT+origin)%b.mem_stride] = w[W_GT];
    g1_Hg1 += w[W_GT] * w[W_GT] * w[W_COND];
    g1_g1 += w[W_GT] * w[W_GT];
    w[W_DIR] = -w[W_COND]*w[W_GT];
    w[W_GT] = 0;
  }
  lastj = 0;
  if (!all.quiet)
    fprintf(stderr, "%-10.5f\t%-10.5f\t%-10s\t%-10s\t%-10s\t",
            g1_g1/(importance_weight_sum*importance_weight_sum),
            g1_Hg1/importance_weight_sum, "", "", "");
}
260
// Computes the next search direction into W_DIR from the new gradient (W_GT),
// the new weights (W_XT) and the history in mem.
//
// m == 0: preconditioned conjugate gradient. It uses
//   beta = g'H(g - g_prev) / g_prev'H g_prev, clamped to 0 when negative or NaN,
// and sets d = beta*d - H g. The gradient is saved in mem and W_GT is cleared.
//
// m > 0: L-BFGS two-loop recursion. It first forms the newest pair
//   y_0 = g - g_prev, s_0 = x - x_prev
// in the head slots. The initial inverse Hessian is gamma * diag(W_COND) with
// gamma = y's / y'Hy. W_DIR doubles as the q/r scratch vector. Throws curv_ex
// when y's <= 0 or y'Hy <= 0. Afterwards it rotates the ring buffer:
//   origin moves back by one pair;
//   lastj grows up to m-1;
//   the current gradient and weight go into the new head slots;
//   rho is shifted;
//   W_GT is cleared.
void bfgs_iter_middle(vw& all, bfgs& b, float* mem, double* rho, double* alpha, int& lastj, int &origin)
{
  uint32_t length = 1 << all.num_bits;
  size_t stride = 1 << all.reg.stride_shift;
  weight* w = all.reg.weight_vector;

  // Base pointers, restored before every sweep over all features.
  float* mem0 = mem;
  float* w0 = w;

  // implement conjugate gradient
  if (b.m==0) {
    double g_Hy = 0.;
    double g_Hg = 0.;
    double y = 0.;

    for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
      y = w[W_GT]-mem[(MEM_GT+origin)%b.mem_stride];
      g_Hy += w[W_GT] * w[W_COND] * y;
      g_Hg += mem[(MEM_GT+origin)%b.mem_stride] * w[W_COND] * mem[(MEM_GT+origin)%b.mem_stride];
    }

    float beta = (float) (g_Hy/g_Hg);

    // Restart with steepest descent when beta is negative or undefined.
    if (beta<0.f || nanpattern(beta))
      beta = 0.f;

    mem = mem0;
    w = w0;
    for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
      mem[(MEM_GT+origin)%b.mem_stride] = w[W_GT];

      w[W_DIR] *= beta;
      w[W_DIR] -= w[W_COND]*w[W_GT];
      w[W_GT] = 0;
    }
    if (!all.quiet)
      fprintf(stderr, "%f\t", beta);
    return;
  }
  else {
    if (!all.quiet)
      fprintf(stderr, "%-10s\t","");
  }

  // implement bfgs
  double y_s = 0.;
  double y_Hy = 0.;
  double s_q = 0.;

  // Overwrite the saved gradient/weight with y_0/s_0 (same slots), and start q = g in W_DIR.
  for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
    mem[(MEM_YT+origin)%b.mem_stride] = w[W_GT] - mem[(MEM_GT+origin)%b.mem_stride];
    mem[(MEM_ST+origin)%b.mem_stride] = w[W_XT] - mem[(MEM_XT+origin)%b.mem_stride];
    w[W_DIR] = w[W_GT];
    y_s += mem[(MEM_YT+origin)%b.mem_stride]*mem[(MEM_ST+origin)%b.mem_stride];
    y_Hy += mem[(MEM_YT+origin)%b.mem_stride]*mem[(MEM_YT+origin)%b.mem_stride]*w[W_COND];
    s_q += mem[(MEM_ST+origin)%b.mem_stride]*w[W_GT];
  }

  if (y_s <= 0. || y_Hy <= 0.)
    throw curv_ex;
  rho[0] = 1/y_s;

  float gamma = (float) (y_s/y_Hy);

  // First loop (newest to oldest): alpha_j = rho_j s_j'q, q -= alpha_j y_j.
  // s_q is carried forward as s_{j+1}'q for the next iteration.
  for (int j=0; j<lastj; j++) {
    alpha[j] = rho[j] * s_q;
    s_q = 0.;
    mem = mem0;
    w = w0;
    for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
      w[W_DIR] -= (float)alpha[j]*mem[(2*j+MEM_YT+origin)%b.mem_stride];
      s_q += mem[(2*j+2+MEM_ST+origin)%b.mem_stride]*w[W_DIR];
    }
  }

  // Oldest pair, then r = gamma * diag(W_COND) * q.
  alpha[lastj] = rho[lastj] * s_q;
  double y_r = 0.;
  mem = mem0;
  w = w0;
  for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
    w[W_DIR] -= (float)alpha[lastj]*mem[(2*lastj+MEM_YT+origin)%b.mem_stride];
    w[W_DIR] *= gamma*w[W_COND];
    y_r += mem[(2*lastj+MEM_YT+origin)%b.mem_stride]*w[W_DIR];
  }

  double coef_j;

  // Second loop (oldest to newest): r += (alpha_j - rho_j y_j'r) s_j.
  for (int j=lastj; j>0; j--) {
    coef_j = alpha[j] - rho[j] * y_r;
    y_r = 0.;
    mem = mem0;
    w = w0;
    for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
      w[W_DIR] += (float)coef_j*mem[(2*j+MEM_ST+origin)%b.mem_stride];
      y_r += mem[(2*j-2+MEM_YT+origin)%b.mem_stride]*w[W_DIR];
    }
  }


  // Newest pair, folded together with the final negation: d = -r.
  coef_j = alpha[0] - rho[0] * y_r;
  mem = mem0;
  w = w0;
  for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
    w[W_DIR] = -w[W_DIR]-(float)coef_j*mem[(MEM_ST+origin)%b.mem_stride];
  }

  /*********************
   ** shift
   ********************/

  mem = mem0;
  w = w0;
  lastj = (lastj<b.m-1) ? lastj+1 : b.m-1;
  origin = (origin+b.mem_stride-2)%b.mem_stride;
  for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
    mem[(MEM_GT+origin)%b.mem_stride] = w[W_GT];
    mem[(MEM_XT+origin)%b.mem_stride] = w[W_XT];
    w[W_GT] = 0;
  }
  for (int j=lastj; j>0; j--)
    rho[j] = rho[j-1];
}
383
wolfe_eval(vw & all,bfgs & b,float * mem,double loss_sum,double previous_loss_sum,double step_size,double importance_weight_sum,int & origin,double & wolfe1)384 double wolfe_eval(vw& all, bfgs& b, float* mem, double loss_sum, double previous_loss_sum, double step_size, double importance_weight_sum, int &origin, double& wolfe1) {
385 uint32_t length = 1 << all.num_bits;
386 size_t stride = 1 << all.reg.stride_shift;
387 weight* w = all.reg.weight_vector;
388
389 double g0_d = 0.;
390 double g1_d = 0.;
391 double g1_Hg1 = 0.;
392 double g1_g1 = 0.;
393
394 for(uint32_t i = 0; i < length; i++, mem+=b.mem_stride, w+=stride) {
395 g0_d += mem[(MEM_GT+origin)%b.mem_stride] * w[W_DIR];
396 g1_d += w[W_GT] * w[W_DIR];
397 g1_Hg1 += w[W_GT] * w[W_GT] * w[W_COND];
398 g1_g1 += w[W_GT] * w[W_GT];
399 }
400
401 wolfe1 = (loss_sum-previous_loss_sum)/(step_size*g0_d);
402 double wolfe2 = g1_d/g0_d;
403 // double new_step_cross = (loss_sum-previous_loss_sum-g1_d*step)/(g0_d-g1_d);
404
405 if (!all.quiet)
406 fprintf(stderr, "%-10.5f\t%-10.5f\t%s%-10f\t%-10f\t", g1_g1/(importance_weight_sum*importance_weight_sum), g1_Hg1/importance_weight_sum, " ", wolfe1, wolfe2);
407 return 0.5*step_size;
408 }
409
410
add_regularization(vw & all,bfgs & b,float regularization)411 double add_regularization(vw& all, bfgs& b, float regularization)
412 {//compute the derivative difference
413 double ret = 0.;
414 uint32_t length = 1 << all.num_bits;
415 size_t stride_shift = all.reg.stride_shift;
416 weight* weights = all.reg.weight_vector;
417 if (b.regularizers == NULL)
418 {
419 for(uint32_t i = 0; i < length; i++) {
420 weights[(i << stride_shift)+W_GT] += regularization*weights[i << stride_shift];
421 ret += 0.5*regularization*weights[i << stride_shift]*weights[i << stride_shift];
422 }
423 }
424 else
425 {
426 for(uint32_t i = 0; i < length; i++) {
427 weight delta_weight = weights[i << stride_shift] - b.regularizers[2*i+1];
428 weights[(i << stride_shift)+W_GT] += b.regularizers[2*i]*delta_weight;
429 ret += 0.5*b.regularizers[2*i]*delta_weight*delta_weight;
430 }
431 }
432
433 return ret;
434 }
435
finalize_preconditioner(vw & all,bfgs & b,float regularization)436 void finalize_preconditioner(vw& all, bfgs& b, float regularization)
437 {
438 uint32_t length = 1 << all.num_bits;
439 size_t stride = 1 << all.reg.stride_shift;
440 weight* weights = all.reg.weight_vector;
441 float max_hessian = 0.f;
442
443 if (b.regularizers == NULL)
444 for(uint32_t i = 0; i < length; i++) {
445 weights[stride*i+W_COND] += regularization;
446 if (weights[stride*i+W_COND] > max_hessian)
447 max_hessian = weights[stride*i+W_COND];
448 if (weights[stride*i+W_COND] > 0)
449 weights[stride*i+W_COND] = 1.f / weights[stride*i+W_COND];
450 }
451 else
452 for(uint32_t i = 0; i < length; i++) {
453 weights[stride*i+W_COND] += b.regularizers[2*i];
454 if (weights[stride*i+W_COND] > max_hessian)
455 max_hessian = weights[stride*i+W_COND];
456 if (weights[stride*i+W_COND] > 0)
457 weights[stride*i+W_COND] = 1.f / weights[stride*i+W_COND];
458 }
459
460 float max_precond = (max_hessian==0.f) ? 0.f : max_precond_ratio / max_hessian;
461 weights = all.reg.weight_vector;
462 for(uint32_t i = 0; i < length; i++) {
463 if (infpattern(weights[stride*i+W_COND]) || weights[stride*i+W_COND]>max_precond)
464 weights[stride*i+W_COND] = max_precond;
465 }
466 }
467
preconditioner_to_regularizer(vw & all,bfgs & b,float regularization)468 void preconditioner_to_regularizer(vw& all, bfgs& b, float regularization)
469 {
470 uint32_t length = 1 << all.num_bits;
471 size_t stride = 1 << all.reg.stride_shift;
472 weight* weights = all.reg.weight_vector;
473 if (b.regularizers == NULL)
474 {
475 b.regularizers = calloc_or_die<weight>(2*length);
476
477 if (b.regularizers == NULL)
478 {
479 cerr << all.program_name << ": Failed to allocate weight array: try decreasing -b <bits>" << endl;
480 throw exception();
481 }
482 for(uint32_t i = 0; i < length; i++)
483 b.regularizers[2*i] = weights[stride*i+W_COND] + regularization;
484 }
485 else
486 for(uint32_t i = 0; i < length; i++)
487 b.regularizers[2*i] = weights[stride*i+W_COND] + b.regularizers[2*i];
488 for(uint32_t i = 0; i < length; i++)
489 b.regularizers[2*i+1] = weights[stride*i];
490 }
491
zero_state(vw & all)492 void zero_state(vw& all)
493 {
494 uint32_t length = 1 << all.num_bits;
495 size_t stride = 1 << all.reg.stride_shift;
496 weight* weights = all.reg.weight_vector;
497 for(uint32_t i = 0; i < length; i++)
498 {
499 weights[stride*i+W_GT] = 0;
500 weights[stride*i+W_DIR] = 0;
501 weights[stride*i+W_COND] = 0;
502 }
503 }
504
derivative_in_direction(vw & all,bfgs & b,float * mem,int & origin)505 double derivative_in_direction(vw& all, bfgs& b, float* mem, int &origin)
506 {
507 double ret = 0.;
508 uint32_t length = 1 << all.num_bits;
509 size_t stride = 1 << all.reg.stride_shift;
510 weight* w = all.reg.weight_vector;
511
512 for(uint32_t i = 0; i < length; i++, w+=stride, mem+=b.mem_stride)
513 ret += mem[(MEM_GT+origin)%b.mem_stride]*w[W_DIR];
514 return ret;
515 }
516
update_weight(vw & all,float step_size,size_t current_pass)517 void update_weight(vw& all, float step_size, size_t current_pass)
518 {
519 uint32_t length = 1 << all.num_bits;
520 size_t stride = 1 << all.reg.stride_shift;
521 weight* w = all.reg.weight_vector;
522
523 for(uint32_t i = 0; i < length; i++, w+=stride)
524 w[W_XT] += step_size * w[W_DIR];
525 }
526
// Performs the optimizer step that follows a completed pass over the data.
// Returns LEARN_OK, LEARN_CURV (zero/negative curvature) or LEARN_CONV
// (zero derivative or loss decrease below rel_threshold).
// What runs depends on what the finished pass computed:
//  A) first pass (gradient + preconditioner): finalize the preconditioner, add
//     regularization, build the first direction (bfgs_iter_start). Then either
//     request a curvature pass (first_hessian_on) or take a 0.5 step.
//  B) gradient pass: evaluate the step with wolfe_eval, then
//     B0) NaN wolfe1: report "Derivative 0" and converge;
//     B1) failed step (backstep_on and wolfe1 < wolfe1_bound or the loss rose):
//         move back to half the previous step and redo the gradient pass;
//     B2) otherwise test the termination threshold and compute the next
//         direction (bfgs_iter_middle). Then request a curvature pass
//         (all.hessian_on) or take a unit step.
//  C) curvature pass: step_size = -g'd / curvature, then apply the step.
// With all.span_server set, loss, gradients, preconditioner and curvature are
// first combined across nodes via accumulate / accumulate_scalar.
// Always advances current_pass, clears first_pass/preconditioner_pass, and
// saves the model when all.save_per_pass is set.
int process_pass(vw& all, bfgs& b) {
  int status = LEARN_OK;

  /********************************************************************/
  /* A) FIRST PASS FINISHED: INITIALIZE FIRST LINE SEARCH *************/
  /********************************************************************/
  if (b.first_pass) {
    if(all.span_server != "")
    {
      accumulate(all, all.span_server, all.reg, W_COND); //Accumulate preconditioner
      float temp = (float)b.importance_weight_sum;
      b.importance_weight_sum = accumulate_scalar(all, all.span_server, temp);
    }
    finalize_preconditioner(all, b, all.l2_lambda);
    if(all.span_server != "") {
      float temp = (float)b.loss_sum;
      b.loss_sum = accumulate_scalar(all, all.span_server, temp); //Accumulate loss_sums
      accumulate(all, all.span_server, all.reg, 1); //Accumulate gradients from all nodes
    }
    if (all.l2_lambda > 0.)
      b.loss_sum += add_regularization(all, b, all.l2_lambda);
    if (!all.quiet)
      fprintf(stderr, "%2lu %-10.5f\t", (long unsigned int)b.current_pass+1, b.loss_sum / b.importance_weight_sum);

    b.previous_loss_sum = b.loss_sum;
    b.loss_sum = 0.;
    b.example_number = 0;
    b.curvature = 0;
    bfgs_iter_start(all, b, b.mem, b.lastj, b.importance_weight_sum, b.origin);
    if (b.first_hessian_on) {
      b.gradient_pass = false;//now start computing curvature
    }
    else {
      b.step_size = 0.5;
      float d_mag = direction_magnitude(all);
      ftime(&b.t_end_global);
      b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm));
      if (!all.quiet)
        fprintf(stderr, "%-10s\t%-10.5f\t%-10.5f\n", "", d_mag, b.step_size);
      b.predictions.erase();
      update_weight(all, b.step_size, b.current_pass); }
  }
  else
    /********************************************************************/
    /* B) GRADIENT CALCULATED *******************************************/
    /********************************************************************/
    if (b.gradient_pass) // We just finished computing all gradients
    {
      if(all.span_server != "") {
        float t = (float)b.loss_sum;
        b.loss_sum = accumulate_scalar(all, all.span_server, t); //Accumulate loss_sums
        accumulate(all, all.span_server, all.reg, 1); //Accumulate gradients from all nodes
      }
      if (all.l2_lambda > 0.)
        b.loss_sum += add_regularization(all, b, all.l2_lambda);
      if (!all.quiet){
        // With a holdout set, report the holdout loss instead of the training loss.
        if(!all.holdout_set_off && b.current_pass >= 1){
          if(all.sd->holdout_sum_loss_since_last_pass == 0. && all.sd->weighted_holdout_examples_since_last_pass == 0.){
            fprintf(stderr, "%2lu ", (long unsigned int)b.current_pass+1);
            fprintf(stderr, "h unknown ");
          }
          else
            fprintf(stderr, "%2lu h%-10.5f\t", (long unsigned int)b.current_pass+1, all.sd->holdout_sum_loss_since_last_pass / all.sd->weighted_holdout_examples_since_last_pass);
        }
        else
          fprintf(stderr, "%2lu %-10.5f\t", (long unsigned int)b.current_pass+1, b.loss_sum / b.importance_weight_sum);
      }
      double wolfe1;
      double new_step = wolfe_eval(all, b, b.mem, b.loss_sum, b.previous_loss_sum, b.step_size, b.importance_weight_sum, b.origin, wolfe1);

      /********************************************************************/
      /* B0) DERIVATIVE ZERO: MINIMUM FOUND *******************************/
      /********************************************************************/
      if (nanpattern((float)wolfe1))
      {
        fprintf(stderr, "\n");
        fprintf(stdout, "Derivative 0 detected.\n");
        b.step_size=0.0;
        status = LEARN_CONV;
      }
      /********************************************************************/
      /* B1) LINE SEARCH FAILED *******************************************/
      /********************************************************************/
      else if (b.backstep_on && (wolfe1<b.wolfe1_bound || b.loss_sum > b.previous_loss_sum))
      {// curvature violated, or we stepped too far last time: step back
        ftime(&b.t_end_global);
        b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm));
        float ratio = (b.step_size==0.f) ? 0.f : (float)new_step/(float)b.step_size;
        if (!all.quiet)
          fprintf(stderr, "%-10s\t%-10s\t(revise x %.1f)\t%-10.5f\n",
                  "","",ratio,
                  new_step);
        b.predictions.erase();
        // Move from x + step*d to x + new_step*d.
        update_weight(all, (float)(-b.step_size+new_step), b.current_pass);
        b.step_size = (float)new_step;
        zero_derivative(all);
        b.loss_sum = 0.;
      }

      /********************************************************************/
      /* B2) LINE SEARCH SUCCESSFUL OR DISABLED ******************/
      /* DETERMINE NEXT SEARCH DIRECTION ******************/
      /********************************************************************/
      else {
        double rel_decrease = (b.previous_loss_sum-b.loss_sum)/b.previous_loss_sum;
        if (!nanpattern((float)rel_decrease) && b.backstep_on && fabs(rel_decrease)<b.rel_threshold) {
          fprintf(stdout, "\nTermination condition reached in pass %ld: decrease in loss less than %.3f%%.\n"
                  "If you want to optimize further, decrease termination threshold.\n", (long int)b.current_pass+1, b.rel_threshold*100.0);
          status = LEARN_CONV;
        }
        b.previous_loss_sum = b.loss_sum;
        b.loss_sum = 0.;
        b.example_number = 0;
        b.curvature = 0;
        b.step_size = 1.0;

        try {
          bfgs_iter_middle(all, b, b.mem, b.rho, b.alpha, b.lastj, b.origin);
        }
        catch (curv_exception e) {
          fprintf(stdout, "In bfgs_iter_middle: %s", curv_message);
          b.step_size=0.0;
          status = LEARN_CURV;
        }

        if (all.hessian_on) {
          b.gradient_pass = false;//now start computing curvature
        }
        else {
          float d_mag = direction_magnitude(all);
          ftime(&b.t_end_global);
          b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm));
          if (!all.quiet)
            fprintf(stderr, "%-10s\t%-10.5f\t%-10.5f\n", "", d_mag, b.step_size);
          b.predictions.erase();
          update_weight(all, b.step_size, b.current_pass);
        }
      }
    }

  /********************************************************************/
  /* C) NOT FIRST PASS, CURVATURE CALCULATED **************************/
  /********************************************************************/
    else // just finished all second gradients
    {
      if(all.span_server != "") {
        float t = (float)b.curvature;
        b.curvature = accumulate_scalar(all, all.span_server, t); //Accumulate curvatures
      }
      if (all.l2_lambda > 0.)
        b.curvature += regularizer_direction_magnitude(all, b, all.l2_lambda);
      float dd = (float)derivative_in_direction(all, b, b.mem, b.origin);
      if (b.curvature == 0. && dd != 0.)
      {
        fprintf(stdout, "%s", curv_message);
        b.step_size=0.0;
        status = LEARN_CURV;
      }
      else if ( dd == 0.)
      {
        fprintf(stdout, "Derivative 0 detected.\n");
        b.step_size=0.0;
        status = LEARN_CONV;
      }
      else
        // Newton step along d: minimizes the quadratic model g'd*t + 0.5*curvature*t^2.
        b.step_size = - dd/(float)b.curvature;

      float d_mag = direction_magnitude(all);

      b.predictions.erase();
      update_weight(all, b.step_size, b.current_pass);
      ftime(&b.t_end_global);
      b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm));
      if (!all.quiet)
        fprintf(stderr, "%-10.5f\t%-10.5f\t%-10.5f\n", b.curvature / b.importance_weight_sum, d_mag, b.step_size);
      b.gradient_pass = true;
    }//now start computing derivatives.
  b.current_pass++;
  b.first_pass = false;
  b.preconditioner_pass = false;

  if (b.output_regularizer)//need to accumulate and place the regularizer.
  {
    if(all.span_server != "")
      accumulate(all, all.span_server, all.reg, W_COND); //Accumulate preconditioner
    //preconditioner_to_regularizer(all, b, all.l2_lambda);
  }
  ftime(&b.t_end_global);
  b.net_time = (int) (1000.0 * (b.t_end_global.time - b.t_start_global.time) + (b.t_end_global.millitm - b.t_start_global.millitm));

  if (all.save_per_pass)
    save_predictor(all, all.final_regressor_name, b.current_pass);
  return status;
}
721
// Per-example learning work for the current pass type.
//  - First pass: also sums importance weights.
//  - Gradient pass: predict, add the loss gradient into W_GT, accumulate the
//    loss, and remember the prediction.
//  - Curvature pass: reuse the stored prediction and add (x.d)^2 * l'' * weight
//    into b.curvature.
//  - Preconditioner pass: additionally accumulate the diagonal preconditioner in W_COND.
void process_example(vw& all, bfgs& b, example& ec)
{
  label_data& ld = ec.l.simple;
  if (b.first_pass)
    b.importance_weight_sum += ld.weight;

  /********************************************************************/
  /* I) GRADIENT CALCULATION ******************************************/
  /********************************************************************/
  if (b.gradient_pass)
  {
    ec.pred.scalar = predict_and_gradient(all, ec);//w[0] & w[1]
    ec.loss = all.loss->getLoss(all.sd, ec.pred.scalar, ld.label) * ld.weight;
    b.loss_sum += ec.loss;
    b.predictions.push_back(ec.pred.scalar);
  }
  /********************************************************************/
  /* II) CURVATURE CALCULATION ****************************************/
  /********************************************************************/
  else //computing curvature
  {
    float d_dot_x = dot_with_direction(all, ec);//w[2]
    // NOTE(review): if predictions is empty, size()-1 wraps around and the
    // index below is out of bounds; this assumes a gradient pass filled it first.
    if (b.example_number >= b.predictions.size())//Make things safe in case example source is strange.
      b.example_number = b.predictions.size()-1;
    ec.pred.scalar = b.predictions[b.example_number];
    ec.partial_prediction = b.predictions[b.example_number];
    ec.loss = all.loss->getLoss(all.sd, ec.pred.scalar, ld.label) * ld.weight;
    float sd = all.loss->second_derivative(all.sd, b.predictions[b.example_number++],ld.label);
    b.curvature += d_dot_x*d_dot_x*sd*ld.weight;
  }
  ec.updated_prediction = ec.pred.scalar;

  if (b.preconditioner_pass)
    update_preconditioner(all, ec);//w[3]
}
757
// End-of-pass hook.
// While current_pass < final_pass it runs process_pass, which advances
// current_pass, and then:
//  - if the pass budget is now used up, warns the user and disables
//    regularizer output (that file is only written after convergence);
//  - if process_pass returned a non-OK status, shrinks final_pass so that this
//    becomes the last pass;
//  - on the last pass with regularizer output requested, clears W_COND and
//    re-enables the preconditioner pass so it is recomputed;
//  - with a holdout set, calls finalize_regressor when summarize_holdout_set
//    returns true, and set_done once no_win_counter reaches early_stop_thres;
//  - on the last pass, calls finalize_regressor and set_done.
// When current_pass == final_pass (one pass past convergence), it converts the
// preconditioner into per-feature regularizers if requested, and bumps
// current_pass so later calls do nothing.
void end_pass(bfgs& b)
{
  vw* all = b.all;

  if (b.current_pass <= b.final_pass)
  {
    if(b.current_pass < b.final_pass)
    {
      int status = process_pass(*all, b);

      //reaching the max number of passes regardless of convergence
      if(b.final_pass == b.current_pass)
      {
        cerr<<"Maximum number of passes reached. ";
        if(!b.output_regularizer)
          cerr<<"If you want to optimize further, increase the number of passes\n";
        if(b.output_regularizer)
        {
          cerr<<"\nRegular model file has been created. ";
          cerr<<"Output feature regularizer file is created only when the convergence is reached. Try increasing the number of passes for convergence\n";
          b.output_regularizer = false;
        }

      }

      //attain convergence before reaching max iterations
      if (status != LEARN_OK && b.final_pass > b.current_pass) {
        b.final_pass = b.current_pass;
      }

      if (b.output_regularizer && b.final_pass == b.current_pass) {
        zero_preconditioner(*all);
        b.preconditioner_pass = true;
      }

      if(!all->holdout_set_off)
      {
        if(summarize_holdout_set(*all, b.no_win_counter))
          finalize_regressor(*all, all->final_regressor_name);
        if(b.early_stop_thres == b.no_win_counter)
        {
          set_done(*all);
          cerr<<"Early termination reached w.r.t. holdout set error";
        }
      }
      // Independent of the holdout block above (not an else-branch).
      if (b.final_pass == b.current_pass) {
        finalize_regressor(*all, all->final_regressor_name);
        set_done(*all);
      }

    }else{//reaching convergence in the previous pass
      if(b.output_regularizer)
        preconditioner_to_regularizer(*all, b, (*all).l2_lambda);
      b.current_pass ++;
    }

  }
}
815
816 // placeholder
predict(bfgs & b,base_learner & base,example & ec)817 void predict(bfgs& b, base_learner& base, example& ec)
818 {
819 vw* all = b.all;
820 ec.pred.scalar = bfgs_predict(*all,ec);
821 }
822
learn(bfgs & b,base_learner & base,example & ec)823 void learn(bfgs& b, base_learner& base, example& ec)
824 {
825 vw* all = b.all;
826 assert(ec.in_use);
827
828 if (b.current_pass <= b.final_pass)
829 {
830 if (test_example(ec))
831 predict(b, base, ec);
832 else
833 process_example(*all, b, ec);
834 }
835 }
836
finish(bfgs & b)837 void finish(bfgs& b)
838 {
839 b.predictions.delete_v();
840 free(b.mem);
841 free(b.rho);
842 free(b.alpha);
843 }
844
save_load_regularizer(vw & all,bfgs & b,io_buf & model_file,bool read,bool text)845 void save_load_regularizer(vw& all, bfgs& b, io_buf& model_file, bool read, bool text)
846 {
847
848 char buff[512];
849 int c = 0;
850 uint32_t stride = 1 << all.reg.stride_shift;
851 uint32_t length = 2*(1 << all.num_bits);
852 uint32_t i = 0;
853 size_t brw = 1;
854 do
855 {
856 brw = 1;
857 weight* v;
858 if (read)
859 {
860 c++;
861 brw = bin_read_fixed(model_file, (char*)&i, sizeof(i),"");
862 if (brw > 0)
863 {
864 assert (i< length);
865 v = &(b.regularizers[i]);
866 if (brw > 0)
867 brw += bin_read_fixed(model_file, (char*)v, sizeof(*v), "");
868 }
869 }
870 else // write binary or text
871 {
872 v = &(b.regularizers[i]);
873 if (*v != 0.)
874 {
875 c++;
876 int text_len = sprintf(buff, "%d", i);
877 brw = bin_text_write_fixed(model_file,(char *)&i, sizeof (i),
878 buff, text_len, text);
879
880 text_len = sprintf(buff, ":%f\n", *v);
881 brw+= bin_text_write_fixed(model_file,(char *)v, sizeof (*v),
882 buff, text_len, text);
883 if (read && i%2 == 1) // This is the prior mean
884 all.reg.weight_vector[(i/2*stride)] = *v;
885 }
886 }
887 if (!read)
888 i++;
889 }
890 while ((!read && i < length) || (read && brw >0));
891 }
892
893
save_load(bfgs & b,io_buf & model_file,bool read,bool text)894 void save_load(bfgs& b, io_buf& model_file, bool read, bool text)
895 {
896 vw* all = b.all;
897
898 uint32_t length = 1 << all->num_bits;
899
900 if (read)
901 {
902 initialize_regressor(*all);
903 if (all->per_feature_regularizer_input != "")
904 {
905 b.regularizers = calloc_or_die<weight>(2*length);
906 if (b.regularizers == NULL)
907 {
908 cerr << all->program_name << ": Failed to allocate regularizers array: try decreasing -b <bits>" << endl;
909 throw exception();
910 }
911 }
912 int m = b.m;
913
914 b.mem_stride = (m==0) ? CG_EXTRA : 2*m;
915 b.mem = (float*) malloc(sizeof(float)*all->length()*(b.mem_stride));
916 b.rho = (double*) malloc(sizeof(double)*m);
917 b.alpha = (double*) malloc(sizeof(double)*m);
918
919 if (!all->quiet)
920 {
921 fprintf(stderr, "m = %d\nAllocated %luM for weights and mem\n", m, (long unsigned int)all->length()*(sizeof(float)*(b.mem_stride)+(sizeof(weight) << all->reg.stride_shift)) >> 20);
922 }
923
924 b.net_time = 0.0;
925 ftime(&b.t_start_global);
926
927 if (!all->quiet)
928 {
929 const char * header_fmt = "%2s %-10s\t%-10s\t%-10s\t %-10s\t%-10s\t%-10s\t%-10s\t%-10s\t%-10s\n";
930 fprintf(stderr, header_fmt,
931 "##", "avg. loss", "der. mag.", "d. m. cond.", "wolfe1", "wolfe2", "mix fraction", "curvature", "dir. magnitude", "step size");
932 cerr.precision(5);
933 }
934
935 if (b.regularizers != NULL)
936 all->l2_lambda = 1; // To make sure we are adding the regularization
937 b.output_regularizer = (all->per_feature_regularizer_output != "" || all->per_feature_regularizer_text != "");
938 reset_state(*all, b, false);
939 }
940
941 //bool reg_vector = b.output_regularizer || all->per_feature_regularizer_input.length() > 0;
942 bool reg_vector = (b.output_regularizer && !read) || (all->per_feature_regularizer_input.length() > 0 && read);
943
944 if (model_file.files.size() > 0)
945 {
946 char buff[512];
947 uint32_t text_len = sprintf(buff, ":%d\n", reg_vector);
948 bin_text_read_write_fixed(model_file,(char *)®_vector, sizeof (reg_vector),
949 "", read,
950 buff, text_len, text);
951
952 if (reg_vector)
953 save_load_regularizer(*all, b, model_file, read, text);
954 else
955 GD::save_load_regressor(*all, model_file, read, text);
956 }
957 }
958
init_driver(bfgs & b)959 void init_driver(bfgs& b)
960 {
961 b.backstep_on = true;
962 }
963
bfgs_setup(vw & all)964 base_learner* bfgs_setup(vw& all)
965 {
966 if (missing_option(all, false, "bfgs", "use bfgs optimization") &&
967 missing_option(all, false, "conjugate_gradient", "use conjugate gradient based optimization"))
968 return NULL;
969 new_options(all, "LBFGS options")
970 ("hessian_on", "use second derivative in line search")
971 ("mem", po::value<uint32_t>()->default_value(15), "memory in bfgs")
972 ("termination", po::value<float>()->default_value(0.001f),"Termination threshold");
973 add_options(all);
974
975 po::variables_map& vm = all.vm;
976 bfgs& b = calloc_or_die<bfgs>();
977 b.all = &all;
978 b.m = vm["mem"].as<uint32_t>();
979 b.rel_threshold = vm["termination"].as<float>();
980 b.wolfe1_bound = 0.01;
981 b.first_hessian_on=true;
982 b.first_pass = true;
983 b.gradient_pass = true;
984 b.preconditioner_pass = true;
985 b.backstep_on = false;
986 b.final_pass=all.numpasses;
987 b.no_win_counter = 0;
988 b.early_stop_thres = 3;
989
990 if(!all.holdout_set_off)
991 {
992 all.sd->holdout_best_loss = FLT_MAX;
993 if(vm.count("early_terminate"))
994 b.early_stop_thres = vm["early_terminate"].as< size_t>();
995 }
996
997 if (vm.count("hessian_on") || b.m==0) {
998 all.hessian_on = true;
999 }
1000 if (!all.quiet) {
1001 if (b.m>0)
1002 cerr << "enabling BFGS based optimization ";
1003 else
1004 cerr << "enabling conjugate gradient optimization via BFGS ";
1005 if (all.hessian_on)
1006 cerr << "with curvature calculation" << endl;
1007 else
1008 cerr << "**without** curvature calculation" << endl;
1009 }
1010 if (all.numpasses < 2)
1011 {
1012 cout << "you must make at least 2 passes to use BFGS" << endl;
1013 throw exception();
1014 }
1015
1016 all.bfgs = true;
1017 all.reg.stride_shift = 2;
1018
1019 learner<bfgs>& l = init_learner(&b, learn, 1 << all.reg.stride_shift);
1020 l.set_predict(predict);
1021 l.set_save_load(save_load);
1022 l.set_init_driver(init_driver);
1023 l.set_end_pass(end_pass);
1024 l.set_finish(finish);
1025
1026 return make_base(l);
1027 }
1028
1029