1 #include "float.h"
2 #include "gd.h"
3 #include "vw.h"
4 
5 namespace COST_SENSITIVE {
6 
name_value(substring & s,v_array<substring> & name,float & v)7   void name_value(substring &s, v_array<substring>& name, float &v)
8   {
9     tokenize(':', s, name);
10 
11     switch (name.size()) {
12     case 0:
13     case 1:
14       v = 1.;
15       break;
16     case 2:
17       v = float_of_substring(name[1]);
18       if ( nanpattern(v))
19 	{
20 	  cerr << "error NaN value for: ";
21 	  cerr.write(name[0].begin, name[0].end - name[0].begin);
22 	  cerr << " terminating." << endl;
23 	  throw exception();
24 	}
25       break;
26     default:
27       cerr << "example with a wierd name.  What is '";
28       cerr.write(s.begin, s.end - s.begin);
29       cerr << "'?\n";
30     }
31   }
32 
is_test_label(label & ld)33   bool is_test_label(label& ld)
34   {
35     if (ld.costs.size() == 0)
36       return true;
37     for (unsigned int i=0; i<ld.costs.size(); i++)
38       if (FLT_MAX != ld.costs[i].x)
39         return false;
40     return true;
41   }
42 
bufread_label(label * ld,char * c,io_buf & cache)43   char* bufread_label(label* ld, char* c, io_buf& cache)
44   {
45     size_t num = *(size_t *)c;
46     ld->costs.erase();
47     c += sizeof(size_t);
48     size_t total = sizeof(wclass)*num;
49     if (buf_read(cache, c, (int)total) < total)
50       {
51         cout << "error in demarshal of cost data" << endl;
52         return c;
53       }
54     for (size_t i = 0; i<num; i++)
55       {
56         wclass temp = *(wclass *)c;
57         c += sizeof(wclass);
58         ld->costs.push_back(temp);
59       }
60 
61     return c;
62   }
63 
read_cached_label(shared_data *,void * v,io_buf & cache)64   size_t read_cached_label(shared_data*, void* v, io_buf& cache)
65   {
66     label* ld = (label*) v;
67     ld->costs.erase();
68     char *c;
69     size_t total = sizeof(size_t);
70     if (buf_read(cache, c, (int)total) < total)
71       return 0;
72     c = bufread_label(ld,c, cache);
73 
74     return total;
75   }
76 
// Returns 1 for every label; the argument is not consulted.
float weight(void* v)
{
  (void)v;
  const float unit_weight = 1.;
  return unit_weight;
}
81 
bufcache_label(label * ld,char * c)82   char* bufcache_label(label* ld, char* c)
83   {
84     *(size_t *)c = ld->costs.size();
85     c += sizeof(size_t);
86     for (unsigned int i = 0; i< ld->costs.size(); i++)
87       {
88         *(wclass *)c = ld->costs[i];
89         c += sizeof(wclass);
90       }
91     return c;
92   }
93 
cache_label(void * v,io_buf & cache)94   void cache_label(void* v, io_buf& cache)
95   {
96     char *c;
97     label* ld = (label*) v;
98     buf_write(cache, c, sizeof(size_t)+sizeof(wclass)*ld->costs.size());
99     bufcache_label(ld,c);
100   }
101 
default_label(void * v)102   void default_label(void* v)
103   {
104     label* ld = (label*) v;
105     ld->costs.erase();
106   }
107 
delete_label(void * v)108   void delete_label(void* v)
109   {
110     label* ld = (label*)v;
111     if (ld) ld->costs.delete_v();
112   }
113 
copy_label(void * dst,void * src)114   void copy_label(void*dst, void*src)
115   {
116     if (dst && src) {
117       label* ldD = (label*)dst;
118       label* ldS = (label*)src;
119       copy_array(ldD->costs, ldS->costs);
120     }
121   }
122 
substring_eq(substring ss,const char * str)123   bool substring_eq(substring ss, const char* str) {
124     size_t len_ss  = ss.end - ss.begin;
125     size_t len_str = strlen(str);
126     if (len_ss != len_str) return false;
127     return (strncmp(ss.begin, str, len_ss) == 0);
128   }
129 
parse_label(parser * p,shared_data * sd,void * v,v_array<substring> & words)130   void parse_label(parser* p, shared_data* sd, void* v, v_array<substring>& words)
131   {
132     label* ld = (label*)v;
133 
134     ld->costs.erase();
135     for (unsigned int i = 0; i < words.size(); i++) {
136       wclass f = {0.,0,0.,0.};
137       name_value(words[i], p->parse_name, f.x);
138 
139       if (p->parse_name.size() == 0)
140         cerr << "invalid cost: specification -- no names!" << endl;
141       else {
142         if (substring_eq(p->parse_name[0], "shared")) {
143           if (p->parse_name.size() == 1) {
144             f.class_index = 0;
145             f.x = -1.f;
146           } else
147             cerr << "shared feature vectors should not have costs" << endl;
148         } else if (substring_eq(p->parse_name[0], "label")) {
149           if (p->parse_name.size() == 2) {
150             f.class_index = 0;
151             // f.x is already set properly
152           } else
153             cerr << "label feature vectors must have label ids" << endl;
154         } else {
155           if (p->parse_name.size() == 1 || p->parse_name.size() == 2 || p->parse_name.size() == 3) {
156             f.class_index = (uint32_t)hashstring(p->parse_name[0], 0);
157             if (p->parse_name.size() == 1 && f.x >= 0)  // test examples are specified just by un-valued class #s
158               f.x = FLT_MAX;
159           } else
160             cerr << "malformed cost specification on '" << (p->parse_name[0].begin) << "'" << endl;
161         }
162         ld->costs.push_back(f);
163       }
164     }
165   }
166 
167   label_parser cs_label = {default_label, parse_label,
168 				  cache_label, read_cached_label,
169 				  delete_label, weight,
170 				  copy_label,
171 				  sizeof(label)};
172 
print_update(vw & all,bool is_test,example & ec,const v_array<example * > * ec_seq)173   void print_update(vw& all, bool is_test, example& ec, const v_array<example*>* ec_seq)
174   {
175     if (all.sd->weighted_examples >= all.sd->dump_interval && !all.quiet && !all.bfgs)
176       {
177         size_t num_current_features = ec.num_features;
178         // for csoaa_ldf we want features from the whole (multiline example),
179         // not only from one line (the first one) represented by ec
180         if (ec_seq != NULL)
181 	  {
182           num_current_features = 0;
183           // If the first example is "shared", don't include its features.
184           // These should be already included in each example (TODO: including quadratic and cubic).
185           // TODO: code duplication csoaa.cc LabelDict::ec_is_example_header
186           example** ecc=ec_seq->begin;
187           example& first_ex = **ecc;
188 
189 	  v_array<COST_SENSITIVE::wclass> costs = first_ex.l.cs.costs;
190 	  if (costs.size() == 1 && costs[0].class_index == 0 && costs[0].x < 0) ecc++;
191 
192           for (; ecc!=ec_seq->end; ecc++)
193 	    num_current_features += (*ecc)->num_features;
194         }
195 
196         char label_buf[32];
197         if (is_test)
198           strcpy(label_buf," unknown");
199         else
200           sprintf(label_buf," known");
201 	char pred_buf[32];
202 	sprintf(pred_buf,"%8lu",(long unsigned int)ec.pred.multiclass);
203 
204 	all.sd->print_update(all.holdout_set_off, all.current_pass, label_buf, pred_buf,
205 			     num_current_features, all.progress_add, all.progress_arg);
206       }
207   }
208 
output_example(vw & all,example & ec)209   void output_example(vw& all, example& ec)
210   {
211     label& ld = ec.l.cs;
212 
213     float loss = 0.;
214     if (!is_test_label(ld))
215       {//need to compute exact loss
216         size_t pred = (size_t)ec.pred.multiclass;
217 
218         float chosen_loss = FLT_MAX;
219         float min = FLT_MAX;
220         for (wclass *cl = ld.costs.begin; cl != ld.costs.end; cl ++) {
221           if (cl->class_index == pred)
222             chosen_loss = cl->x;
223           if (cl->x < min)
224             min = cl->x;
225         }
226         if (chosen_loss == FLT_MAX)
227           cerr << "warning: csoaa predicted an invalid class" << endl;
228 
229         loss = chosen_loss - min;
230       }
231 
232     all.sd->update(ec.test_only, loss, 1.f, ec.num_features);
233 
234     for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++)
235       all.print(*sink, (float)ec.pred.multiclass, 0, ec.tag);
236 
237     if (all.raw_prediction > 0) {
238       stringstream outputStringStream;
239       for (unsigned int i = 0; i < ld.costs.size(); i++) {
240         wclass cl = ld.costs[i];
241         if (i > 0) outputStringStream << ' ';
242         outputStringStream << cl.class_index << ':' << cl.partial_prediction;
243       }
244       all.print_text(all.raw_prediction, outputStringStream.str(), ec.tag);
245     }
246 
247     print_update(all, is_test_label(ec.l.cs), ec, NULL);
248   }
249 
example_is_test(example & ec)250   bool example_is_test(example& ec)
251   {
252     v_array<COST_SENSITIVE::wclass> costs = ec.l.cs.costs;
253     if (costs.size() == 0) return true;
254     for (size_t j=0; j<costs.size(); j++)
255       if (costs[j].x != FLT_MAX) return false;
256     return true;
257   }
258 }
259