1 #include "float.h" 2 #include "gd.h" 3 #include "vw.h" 4 5 namespace COST_SENSITIVE { 6 name_value(substring & s,v_array<substring> & name,float & v)7 void name_value(substring &s, v_array<substring>& name, float &v) 8 { 9 tokenize(':', s, name); 10 11 switch (name.size()) { 12 case 0: 13 case 1: 14 v = 1.; 15 break; 16 case 2: 17 v = float_of_substring(name[1]); 18 if ( nanpattern(v)) 19 { 20 cerr << "error NaN value for: "; 21 cerr.write(name[0].begin, name[0].end - name[0].begin); 22 cerr << " terminating." << endl; 23 throw exception(); 24 } 25 break; 26 default: 27 cerr << "example with a wierd name. What is '"; 28 cerr.write(s.begin, s.end - s.begin); 29 cerr << "'?\n"; 30 } 31 } 32 is_test_label(label & ld)33 bool is_test_label(label& ld) 34 { 35 if (ld.costs.size() == 0) 36 return true; 37 for (unsigned int i=0; i<ld.costs.size(); i++) 38 if (FLT_MAX != ld.costs[i].x) 39 return false; 40 return true; 41 } 42 bufread_label(label * ld,char * c,io_buf & cache)43 char* bufread_label(label* ld, char* c, io_buf& cache) 44 { 45 size_t num = *(size_t *)c; 46 ld->costs.erase(); 47 c += sizeof(size_t); 48 size_t total = sizeof(wclass)*num; 49 if (buf_read(cache, c, (int)total) < total) 50 { 51 cout << "error in demarshal of cost data" << endl; 52 return c; 53 } 54 for (size_t i = 0; i<num; i++) 55 { 56 wclass temp = *(wclass *)c; 57 c += sizeof(wclass); 58 ld->costs.push_back(temp); 59 } 60 61 return c; 62 } 63 read_cached_label(shared_data *,void * v,io_buf & cache)64 size_t read_cached_label(shared_data*, void* v, io_buf& cache) 65 { 66 label* ld = (label*) v; 67 ld->costs.erase(); 68 char *c; 69 size_t total = sizeof(size_t); 70 if (buf_read(cache, c, (int)total) < total) 71 return 0; 72 c = bufread_label(ld,c, cache); 73 74 return total; 75 } 76 weight(void * v)77 float weight(void* v) 78 { 79 return 1.; 80 } 81 bufcache_label(label * ld,char * c)82 char* bufcache_label(label* ld, char* c) 83 { 84 *(size_t *)c = ld->costs.size(); 85 c += sizeof(size_t); 86 for (unsigned int i = 0; i< ld->costs.size(); i++) 87 { 88 *(wclass *)c = ld->costs[i]; 89 c += sizeof(wclass); 90 } 91 return c; 92 } 93 cache_label(void * v,io_buf & cache)94 void cache_label(void* v, io_buf& cache) 95 { 96 char *c; 97 label* ld = (label*) v; 98 buf_write(cache, c, sizeof(size_t)+sizeof(wclass)*ld->costs.size()); 99 bufcache_label(ld,c); 100 } 101 default_label(void * v)102 void default_label(void* v) 103 { 104 label* ld = (label*) v; 105 ld->costs.erase(); 106 } 107 delete_label(void * v)108 void delete_label(void* v) 109 { 110 label* ld = (label*)v; 111 if (ld) ld->costs.delete_v(); 112 } 113 copy_label(void * dst,void * src)114 void copy_label(void*dst, void*src) 115 { 116 if (dst && src) { 117 label* ldD = (label*)dst; 118 label* ldS = (label*)src; 119 copy_array(ldD->costs, ldS->costs); 120 } 121 } 122 substring_eq(substring ss,const char * str)123 bool substring_eq(substring ss, const char* str) { 124 size_t len_ss = ss.end - ss.begin; 125 size_t len_str = strlen(str); 126 if (len_ss != len_str) return false; 127 return (strncmp(ss.begin, str, len_ss) == 0); 128 } 129 parse_label(parser * p,shared_data * sd,void * v,v_array<substring> & words)130 void parse_label(parser* p, shared_data* sd, void* v, v_array<substring>& words) 131 { 132 label* ld = (label*)v; 133 134 ld->costs.erase(); 135 for (unsigned int i = 0; i < words.size(); i++) { 136 wclass f = {0.,0,0.,0.}; 137 name_value(words[i], p->parse_name, f.x); 138 139 if (p->parse_name.size() == 0) 140 cerr << "invalid cost: specification -- no names!" << endl; 141 else { 142 if (substring_eq(p->parse_name[0], "shared")) { 143 if (p->parse_name.size() == 1) { 144 f.class_index = 0; 145 f.x = -1.f; 146 } else 147 cerr << "shared feature vectors should not have costs" << endl; 148 } else if (substring_eq(p->parse_name[0], "label")) { 149 if (p->parse_name.size() == 2) { 150 f.class_index = 0; 151 // f.x is already set properly 152 } else 153 cerr << "label feature vectors must have label ids" << endl; 154 } else { 155 if (p->parse_name.size() == 1 || p->parse_name.size() == 2 || p->parse_name.size() == 3) { 156 f.class_index = (uint32_t)hashstring(p->parse_name[0], 0); 157 if (p->parse_name.size() == 1 && f.x >= 0) // test examples are specified just by un-valued class #s 158 f.x = FLT_MAX; 159 } else 160 cerr << "malformed cost specification on '" << (p->parse_name[0].begin) << "'" << endl; 161 } 162 ld->costs.push_back(f); 163 } 164 } 165 } 166 167 label_parser cs_label = {default_label, parse_label, 168 cache_label, read_cached_label, 169 delete_label, weight, 170 copy_label, 171 sizeof(label)}; 172 print_update(vw & all,bool is_test,example & ec,const v_array<example * > * ec_seq)173 void print_update(vw& all, bool is_test, example& ec, const v_array<example*>* ec_seq) 174 { 175 if (all.sd->weighted_examples >= all.sd->dump_interval && !all.quiet && !all.bfgs) 176 { 177 size_t num_current_features = ec.num_features; 178 // for csoaa_ldf we want features from the whole (multiline example), 179 // not only from one line (the first one) represented by ec 180 if (ec_seq != NULL) 181 { 182 num_current_features = 0; 183 // If the first example is "shared", don't include its features. 184 // These should be already included in each example (TODO: including quadratic and cubic). 185 // TODO: code duplication csoaa.cc LabelDict::ec_is_example_header 186 example** ecc=ec_seq->begin; 187 example& first_ex = **ecc; 188 189 v_array<COST_SENSITIVE::wclass> costs = first_ex.l.cs.costs; 190 if (costs.size() == 1 && costs[0].class_index == 0 && costs[0].x < 0) ecc++; 191 192 for (; ecc!=ec_seq->end; ecc++) 193 num_current_features += (*ecc)->num_features; 194 } 195 196 char label_buf[32]; 197 if (is_test) 198 strcpy(label_buf," unknown"); 199 else 200 sprintf(label_buf," known"); 201 char pred_buf[32]; 202 sprintf(pred_buf,"%8lu",(long unsigned int)ec.pred.multiclass); 203 204 all.sd->print_update(all.holdout_set_off, all.current_pass, label_buf, pred_buf, 205 num_current_features, all.progress_add, all.progress_arg); 206 } 207 } 208 output_example(vw & all,example & ec)209 void output_example(vw& all, example& ec) 210 { 211 label& ld = ec.l.cs; 212 213 float loss = 0.; 214 if (!is_test_label(ld)) 215 {//need to compute exact loss 216 size_t pred = (size_t)ec.pred.multiclass; 217 218 float chosen_loss = FLT_MAX; 219 float min = FLT_MAX; 220 for (wclass *cl = ld.costs.begin; cl != ld.costs.end; cl ++) { 221 if (cl->class_index == pred) 222 chosen_loss = cl->x; 223 if (cl->x < min) 224 min = cl->x; 225 } 226 if (chosen_loss == FLT_MAX) 227 cerr << "warning: csoaa predicted an invalid class" << endl; 228 229 loss = chosen_loss - min; 230 } 231 232 all.sd->update(ec.test_only, loss, 1.f, ec.num_features); 233 234 for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++) 235 all.print(*sink, (float)ec.pred.multiclass, 0, ec.tag); 236 237 if (all.raw_prediction > 0) { 238 stringstream outputStringStream; 239 for (unsigned int i = 0; i < ld.costs.size(); i++) { 240 wclass cl = ld.costs[i]; 241 if (i > 0) outputStringStream << ' '; 242 outputStringStream << cl.class_index << ':' << cl.partial_prediction; 243 } 244 all.print_text(all.raw_prediction, outputStringStream.str(), ec.tag); 245 } 246 247 print_update(all, is_test_label(ec.l.cs), ec, NULL); 248 } 249 example_is_test(example & ec)250 bool example_is_test(example& ec) 251 { 252 v_array<COST_SENSITIVE::wclass> costs = ec.l.cs.costs; 253 if (costs.size() == 0) return true; 254 for (size_t j=0; j<costs.size(); j++) 255 if (costs[j].x != FLT_MAX) return false; 256 return true; 257 } 258 } 259