1 #include "best_constant.h"
2 
// Label-tracking state consumed by get_best_constant().
// first_observed_label / second_observed_label start at FLT_MAX, which
// get_best_constant() treats as "no such label recorded yet".
// NOTE(review): presumably these are updated by the example-processing code
// elsewhere (is_more_than_two_labels_observed being raised once a third
// distinct label value is seen) — confirm against the writers.
bool  is_more_than_two_labels_observed(false);
float first_observed_label(FLT_MAX);
float second_observed_label(FLT_MAX);
6 
get_best_constant(vw & all,float & best_constant,float & best_constant_loss)7 bool get_best_constant(vw& all, float& best_constant, float& best_constant_loss)
8 {
9     if (    first_observed_label == FLT_MAX || // no non-test labels observed or function was never called
10             (all.loss == NULL) || (all.sd == NULL)) return false;
11 
12     float label1 = first_observed_label; // observed labels might be inside [sd->Min_label, sd->Max_label], so can't use Min/Max
13     float label2 = (second_observed_label == FLT_MAX)?0:second_observed_label; // if only one label observed, second might be 0
14     if (label1 > label2) {float tmp = label1; label1 = label2; label2 = tmp;} // as don't use min/max - make sure label1 < label2
15 
16     float label1_cnt;
17     float label2_cnt;
18 
19     if (label1 != label2)
20     {
21         float weighted_labeled_examples = (float)(all.sd->weighted_examples - all.sd->weighted_unlabeled_examples + all.initial_t);
22         label1_cnt = (float) (all.sd->weighted_labels - label2*weighted_labeled_examples)/(label1 - label2);
23         label2_cnt = weighted_labeled_examples - label1_cnt;
24     } else
25         return false;
26 
27     if ( (label1_cnt + label2_cnt) <= 0.) return false;
28 
29     po::parsed_options pos = po::command_line_parser(all.args).
30             style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
31             options(all.opts).allow_unregistered().run();
32 
33     po::variables_map vm = po::variables_map();
34 
35     po::store(pos, vm);
36     po::notify(vm);
37 
38     string funcName;
39     if(vm.count("loss_function"))
40         funcName = vm["loss_function"].as<string>();
41     else
42         funcName = "squared";
43 
44     if(funcName.compare("squared") == 0 || funcName.compare("Huber") == 0 || funcName.compare("classic") == 0)
45     {
46         best_constant = (float) all.sd->weighted_labels / (float) (all.sd->weighted_examples - all.sd->weighted_unlabeled_examples + all.initial_t); //GENERIC. WAS: (label1*label1_cnt + label2*label2_cnt) / (label1_cnt + label2_cnt);
47 
48     } else if (is_more_than_two_labels_observed) {
49         //loss functions below don't have generic formuas for constant yet.
50         return false;
51 
52     } else if(funcName.compare("hinge") == 0) {
53 
54         best_constant = label2_cnt <= label1_cnt ? -1.f: 1.f;
55 
56     } else if(funcName.compare("logistic") == 0) {
57 
58         label1 = -1.; //override {-50, 50} to get proper loss
59         label2 =  1.;
60 
61         if (label1_cnt <= 0) best_constant = 1.;
62         else
63             if (label2_cnt <= 0) best_constant = -1.;
64             else
65                 best_constant = log(label2_cnt/label1_cnt);
66 
67     } else if(funcName.compare("quantile") == 0 || funcName.compare("pinball") == 0 || funcName.compare("absolute") == 0) {
68 
69         float tau = 0.5;
70         if(vm.count("quantile_tau"))
71             tau = vm["quantile_tau"].as<float>();
72 
73         float q = tau*(label1_cnt + label2_cnt);
74         if (q < label2_cnt) best_constant = label2;
75         else best_constant = label1;
76     } else
77         return false;
78 
79     if (!is_more_than_two_labels_observed)
80     best_constant_loss = ( all.loss->getLoss(all.sd, best_constant, label1) * label1_cnt +
81                            all.loss->getLoss(all.sd, best_constant, label2) * label2_cnt )
82             / (label1_cnt + label2_cnt);
83     else best_constant_loss = FLT_MIN;
84 
85     return true;
86 }
87