#include <limits.h>
#include <string.h>
#include "global_data.h"
#include "vw.h"
4 
5 namespace MULTICLASS {
6 
bufread_label(label_t * ld,char * c)7   char* bufread_label(label_t* ld, char* c)
8   {
9     ld->label = *(uint32_t *)c;
10     c += sizeof(ld->label);
11     ld->weight = *(float *)c;
12     c += sizeof(ld->weight);
13     return c;
14   }
15 
read_cached_label(shared_data *,void * v,io_buf & cache)16   size_t read_cached_label(shared_data*, void* v, io_buf& cache)
17   {
18     label_t* ld = (label_t*) v;
19     char *c;
20     size_t total = sizeof(ld->label)+sizeof(ld->weight);
21     if (buf_read(cache, c, total) < total)
22       return 0;
23     c = bufread_label(ld,c);
24 
25     return total;
26   }
27 
weight(void * v)28   float weight(void* v)
29   {
30     label_t* ld = (label_t*) v;
31     return (ld->weight > 0) ? ld->weight : 0.f;
32   }
33 
bufcache_label(label_t * ld,char * c)34   char* bufcache_label(label_t* ld, char* c)
35   {
36     *(uint32_t *)c = ld->label;
37     c += sizeof(ld->label);
38     *(float *)c = ld->weight;
39     c += sizeof(ld->weight);
40     return c;
41   }
42 
cache_label(void * v,io_buf & cache)43   void cache_label(void* v, io_buf& cache)
44   {
45     char *c;
46     label_t* ld = (label_t*) v;
47     buf_write(cache, c, sizeof(ld->label)+sizeof(ld->weight));
48     c = bufcache_label(ld,c);
49   }
50 
default_label(void * v)51   void default_label(void* v)
52   {
53     label_t* ld = (label_t*) v;
54     ld->label = (uint32_t)-1;
55     ld->weight = 1.;
56   }
57 
delete_label(void * v)58   void delete_label(void* v) {}
59 
parse_label(parser * p,shared_data *,void * v,v_array<substring> & words)60   void parse_label(parser* p, shared_data*, void* v, v_array<substring>& words)
61   {
62     label_t* ld = (label_t*)v;
63 
64     switch(words.size()) {
65     case 0:
66       break;
67     case 1:
68       ld->label = int_of_substring(words[0]);
69       ld->weight = 1.0;
70       break;
71     case 2:
72       ld->label = int_of_substring(words[0]);
73       ld->weight = float_of_substring(words[1]);
74       break;
75     default:
76       cerr << "malformed example!\n";
77       cerr << "words.size() = " << words.size() << endl;
78     }
79     if (ld->label == 0)
80       {
81 	cout << "label 0 is not allowed for multiclass.  Valid labels are {1,k}" << endl;
82 	throw exception();
83       }
84   }
85 
86   label_parser mc_label = {default_label, parse_label,
87 				  cache_label, read_cached_label,
88 				  delete_label, weight,
89 				  NULL,
90 				  sizeof(label_t)};
91 
print_update(vw & all,example & ec)92   void print_update(vw& all, example &ec)
93   {
94     if (all.sd->weighted_examples >= all.sd->dump_interval && !all.quiet && !all.bfgs)
95       {
96         label_t ld = ec.l.multi;
97         char label_buf[32];
98         if (ld.label == INT_MAX)
99           strcpy(label_buf," unknown");
100         else
101           sprintf(label_buf,"%8ld",(long int)ld.label);
102 	char pred_buf[32];
103 	sprintf(pred_buf,"%8lu",(long unsigned int)ec.pred.multiclass);
104 
105 	all.sd->print_update(all.holdout_set_off, all.current_pass, label_buf, pred_buf,
106 			     ec.num_features, all.progress_add, all.progress_arg);
107       }
108   }
109 
finish_example(vw & all,example & ec)110   void finish_example(vw& all, example& ec)
111   {
112     float loss = 1;
113     if (ec.l.multi.label == (uint32_t)ec.pred.multiclass)
114       loss = 0;
115 
116     all.sd->update(ec.test_only, loss, ec.l.multi.weight, ec.num_features);
117 
118     for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++)
119       all.print(*sink, (float)ec.pred.multiclass, 0, ec.tag);
120 
121     MULTICLASS::print_update(all, ec);
122     VW::finish_example(all, &ec);
123   }
124 }
125