1 #include <limits.h> 2 #include "global_data.h" 3 #include "vw.h" 4 5 namespace MULTICLASS { 6 bufread_label(label_t * ld,char * c)7 char* bufread_label(label_t* ld, char* c) 8 { 9 ld->label = *(uint32_t *)c; 10 c += sizeof(ld->label); 11 ld->weight = *(float *)c; 12 c += sizeof(ld->weight); 13 return c; 14 } 15 read_cached_label(shared_data *,void * v,io_buf & cache)16 size_t read_cached_label(shared_data*, void* v, io_buf& cache) 17 { 18 label_t* ld = (label_t*) v; 19 char *c; 20 size_t total = sizeof(ld->label)+sizeof(ld->weight); 21 if (buf_read(cache, c, total) < total) 22 return 0; 23 c = bufread_label(ld,c); 24 25 return total; 26 } 27 weight(void * v)28 float weight(void* v) 29 { 30 label_t* ld = (label_t*) v; 31 return (ld->weight > 0) ? ld->weight : 0.f; 32 } 33 bufcache_label(label_t * ld,char * c)34 char* bufcache_label(label_t* ld, char* c) 35 { 36 *(uint32_t *)c = ld->label; 37 c += sizeof(ld->label); 38 *(float *)c = ld->weight; 39 c += sizeof(ld->weight); 40 return c; 41 } 42 cache_label(void * v,io_buf & cache)43 void cache_label(void* v, io_buf& cache) 44 { 45 char *c; 46 label_t* ld = (label_t*) v; 47 buf_write(cache, c, sizeof(ld->label)+sizeof(ld->weight)); 48 c = bufcache_label(ld,c); 49 } 50 default_label(void * v)51 void default_label(void* v) 52 { 53 label_t* ld = (label_t*) v; 54 ld->label = (uint32_t)-1; 55 ld->weight = 1.; 56 } 57 delete_label(void * v)58 void delete_label(void* v) {} 59 parse_label(parser * p,shared_data *,void * v,v_array<substring> & words)60 void parse_label(parser* p, shared_data*, void* v, v_array<substring>& words) 61 { 62 label_t* ld = (label_t*)v; 63 64 switch(words.size()) { 65 case 0: 66 break; 67 case 1: 68 ld->label = int_of_substring(words[0]); 69 ld->weight = 1.0; 70 break; 71 case 2: 72 ld->label = int_of_substring(words[0]); 73 ld->weight = float_of_substring(words[1]); 74 break; 75 default: 76 cerr << "malformed example!\n"; 77 cerr << "words.size() = " << words.size() << endl; 78 } 79 if (ld->label == 0) 80 { 81 cout << "label 0 is not allowed for multiclass. Valid labels are {1,k}" << endl; 82 throw exception(); 83 } 84 } 85 86 label_parser mc_label = {default_label, parse_label, 87 cache_label, read_cached_label, 88 delete_label, weight, 89 NULL, 90 sizeof(label_t)}; 91 print_update(vw & all,example & ec)92 void print_update(vw& all, example &ec) 93 { 94 if (all.sd->weighted_examples >= all.sd->dump_interval && !all.quiet && !all.bfgs) 95 { 96 label_t ld = ec.l.multi; 97 char label_buf[32]; 98 if (ld.label == INT_MAX) 99 strcpy(label_buf," unknown"); 100 else 101 sprintf(label_buf,"%8ld",(long int)ld.label); 102 char pred_buf[32]; 103 sprintf(pred_buf,"%8lu",(long unsigned int)ec.pred.multiclass); 104 105 all.sd->print_update(all.holdout_set_off, all.current_pass, label_buf, pred_buf, 106 ec.num_features, all.progress_add, all.progress_arg); 107 } 108 } 109 finish_example(vw & all,example & ec)110 void finish_example(vw& all, example& ec) 111 { 112 float loss = 1; 113 if (ec.l.multi.label == (uint32_t)ec.pred.multiclass) 114 loss = 0; 115 116 all.sd->update(ec.test_only, loss, ec.l.multi.weight, ec.num_features); 117 118 for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++) 119 all.print(*sink, (float)ec.pred.multiclass, 0, ec.tag); 120 121 MULTICLASS::print_update(all, ec); 122 VW::finish_example(all, &ec); 123 } 124 } 125