1 #include "best_constant.h"
2
// Label-tracking state consulted by get_best_constant(). FLT_MAX acts as the
// "no label recorded yet" sentinel for the two observed-label slots.
// NOTE(review): these are presumably updated by label-parsing code elsewhere
// (nothing in this file writes them) -- verify against callers.
bool is_more_than_two_labels_observed(false);
float first_observed_label(FLT_MAX);
float second_observed_label(FLT_MAX);
6
get_best_constant(vw & all,float & best_constant,float & best_constant_loss)7 bool get_best_constant(vw& all, float& best_constant, float& best_constant_loss)
8 {
9 if ( first_observed_label == FLT_MAX || // no non-test labels observed or function was never called
10 (all.loss == NULL) || (all.sd == NULL)) return false;
11
12 float label1 = first_observed_label; // observed labels might be inside [sd->Min_label, sd->Max_label], so can't use Min/Max
13 float label2 = (second_observed_label == FLT_MAX)?0:second_observed_label; // if only one label observed, second might be 0
14 if (label1 > label2) {float tmp = label1; label1 = label2; label2 = tmp;} // as don't use min/max - make sure label1 < label2
15
16 float label1_cnt;
17 float label2_cnt;
18
19 if (label1 != label2)
20 {
21 float weighted_labeled_examples = (float)(all.sd->weighted_examples - all.sd->weighted_unlabeled_examples + all.initial_t);
22 label1_cnt = (float) (all.sd->weighted_labels - label2*weighted_labeled_examples)/(label1 - label2);
23 label2_cnt = weighted_labeled_examples - label1_cnt;
24 } else
25 return false;
26
27 if ( (label1_cnt + label2_cnt) <= 0.) return false;
28
29 po::parsed_options pos = po::command_line_parser(all.args).
30 style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
31 options(all.opts).allow_unregistered().run();
32
33 po::variables_map vm = po::variables_map();
34
35 po::store(pos, vm);
36 po::notify(vm);
37
38 string funcName;
39 if(vm.count("loss_function"))
40 funcName = vm["loss_function"].as<string>();
41 else
42 funcName = "squared";
43
44 if(funcName.compare("squared") == 0 || funcName.compare("Huber") == 0 || funcName.compare("classic") == 0)
45 {
46 best_constant = (float) all.sd->weighted_labels / (float) (all.sd->weighted_examples - all.sd->weighted_unlabeled_examples + all.initial_t); //GENERIC. WAS: (label1*label1_cnt + label2*label2_cnt) / (label1_cnt + label2_cnt);
47
48 } else if (is_more_than_two_labels_observed) {
49 //loss functions below don't have generic formuas for constant yet.
50 return false;
51
52 } else if(funcName.compare("hinge") == 0) {
53
54 best_constant = label2_cnt <= label1_cnt ? -1.f: 1.f;
55
56 } else if(funcName.compare("logistic") == 0) {
57
58 label1 = -1.; //override {-50, 50} to get proper loss
59 label2 = 1.;
60
61 if (label1_cnt <= 0) best_constant = 1.;
62 else
63 if (label2_cnt <= 0) best_constant = -1.;
64 else
65 best_constant = log(label2_cnt/label1_cnt);
66
67 } else if(funcName.compare("quantile") == 0 || funcName.compare("pinball") == 0 || funcName.compare("absolute") == 0) {
68
69 float tau = 0.5;
70 if(vm.count("quantile_tau"))
71 tau = vm["quantile_tau"].as<float>();
72
73 float q = tau*(label1_cnt + label2_cnt);
74 if (q < label2_cnt) best_constant = label2;
75 else best_constant = label1;
76 } else
77 return false;
78
79 if (!is_more_than_two_labels_observed)
80 best_constant_loss = ( all.loss->getLoss(all.sd, best_constant, label1) * label1_cnt +
81 all.loss->getLoss(all.sd, best_constant, label2) * label2_cnt )
82 / (label1_cnt + label2_cnt);
83 else best_constant_loss = FLT_MIN;
84
85 return true;
86 }
87