1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved.  Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
#include <float.h>
#include <string.h>

#include "example.h"
#include "parse_primitives.h"
#include "vw.h"
11 
12 using namespace LEARNER;
13 
14 namespace CB
15 {
bufread_label(CB::label * ld,char * c,io_buf & cache)16   char* bufread_label(CB::label* ld, char* c, io_buf& cache)
17   {
18     size_t num = *(size_t *)c;
19     ld->costs.erase();
20     c += sizeof(size_t);
21     size_t total = sizeof(cb_class)*num;
22     if (buf_read(cache, c, total) < total)
23       {
24         cout << "error in demarshal of cost data" << endl;
25         return c;
26       }
27     for (size_t i = 0; i<num; i++)
28       {
29         cb_class temp = *(cb_class *)c;
30         c += sizeof(cb_class);
31         ld->costs.push_back(temp);
32       }
33 
34     return c;
35   }
36 
read_cached_label(shared_data *,void * v,io_buf & cache)37   size_t read_cached_label(shared_data*, void* v, io_buf& cache)
38   {
39     CB::label* ld = (CB::label*) v;
40     ld->costs.erase();
41     char *c;
42     size_t total = sizeof(size_t);
43     if (buf_read(cache, c, total) < total)
44       return 0;
45     c = bufread_label(ld,c, cache);
46 
47     return total;
48   }
49 
weight(void * v)50   float weight(void* v)
51   {
52     return 1.;
53   }
54 
bufcache_label(CB::label * ld,char * c)55   char* bufcache_label(CB::label* ld, char* c)
56   {
57     *(size_t *)c = ld->costs.size();
58     c += sizeof(size_t);
59     for (size_t i = 0; i< ld->costs.size(); i++)
60       {
61         *(cb_class *)c = ld->costs[i];
62         c += sizeof(cb_class);
63       }
64     return c;
65   }
66 
cache_label(void * v,io_buf & cache)67   void cache_label(void* v, io_buf& cache)
68   {
69     char *c;
70     CB::label* ld = (CB::label*) v;
71     buf_write(cache, c, sizeof(size_t)+sizeof(cb_class)*ld->costs.size());
72     bufcache_label(ld,c);
73   }
74 
default_label(void * v)75   void default_label(void* v)
76   {
77     CB::label* ld = (CB::label*) v;
78     ld->costs.erase();
79   }
80 
delete_label(void * v)81   void delete_label(void* v)
82   {
83     CB::label* ld = (CB::label*)v;
84     ld->costs.delete_v();
85   }
86 
copy_label(void * dst,void * src)87   void copy_label(void*dst, void*src)
88   {
89     CB::label* ldD = (CB::label*)dst;
90     CB::label* ldS = (CB::label*)src;
91     copy_array(ldD->costs, ldS->costs);
92   }
93 
parse_label(parser * p,shared_data * sd,void * v,v_array<substring> & words)94   void parse_label(parser* p, shared_data* sd, void* v, v_array<substring>& words)
95   {
96     CB::label* ld = (CB::label*)v;
97 
98     for (size_t i = 0; i < words.size(); i++)
99       {
100         cb_class f;
101 	tokenize(':', words[i], p->parse_name);
102 
103         if( p->parse_name.size() < 1 || p->parse_name.size() > 3 )
104         {
105           cerr << "malformed cost specification!" << endl;
106 	  cerr << "terminating." << endl;
107           throw exception();
108         }
109 
110         f.partial_prediction = 0.;
111         f.action = (uint32_t)hashstring(p->parse_name[0], 0);
112         f.cost = FLT_MAX;
113         if(p->parse_name.size() > 1)
114           f.cost = float_of_substring(p->parse_name[1]);
115 
116         if ( nanpattern(f.cost))
117         {
118 	  cerr << "error NaN cost for action: ";
119 	  cerr.write(p->parse_name[0].begin, p->parse_name[0].end - p->parse_name[0].begin);
120 	  cerr << " terminating." << endl;
121 	  throw exception();
122         }
123 
124         f.probability = .0;
125         if(p->parse_name.size() > 2)
126           f.probability = float_of_substring(p->parse_name[2]);
127 
128         if ( nanpattern(f.probability))
129         {
130 	  cerr << "error NaN probability for action: ";
131 	  cerr.write(p->parse_name[0].begin, p->parse_name[0].end - p->parse_name[0].begin);
132 	  cerr << " terminating." << endl;
133 	  throw exception();
134         }
135 
136         if( f.probability > 1.0 )
137         {
138           cerr << "invalid probability > 1 specified for an action, resetting to 1." << endl;
139           f.probability = 1.0;
140         }
141         if( f.probability < 0.0 )
142         {
143           cerr << "invalid probability < 0 specified for an action, resetting to 0." << endl;
144           f.probability = .0;
145         }
146 
147         ld->costs.push_back(f);
148       }
149   }
150 
151   label_parser cb_label = {default_label, parse_label,
152 				  cache_label, read_cached_label,
153 				  delete_label, weight,
154 				  copy_label,
155 				  sizeof(label)};
156 
157 }
158 
159 namespace CB_EVAL
160 {
read_cached_label(shared_data * sd,void * v,io_buf & cache)161   size_t read_cached_label(shared_data*sd, void* v, io_buf& cache)
162   {
163     CB_EVAL::label* ld = (CB_EVAL::label*) v;
164     char* c;
165     size_t total = sizeof(uint32_t);
166     if (buf_read(cache, c, total) < total)
167       return 0;
168     ld->action = *(uint32_t*)c;
169     c += sizeof(uint32_t);
170 
171     return total + CB::read_cached_label(sd, &(ld->event), cache);
172   }
173 
cache_label(void * v,io_buf & cache)174   void cache_label(void* v, io_buf& cache)
175   {
176     char *c;
177     CB_EVAL::label* ld = (CB_EVAL::label*) v;
178     buf_write(cache, c, sizeof(uint32_t));
179     *(uint32_t *)c = ld->action;
180     c+= sizeof(uint32_t);
181 
182     CB::cache_label(&(ld->event), cache);
183   }
184 
default_label(void * v)185   void default_label(void* v)
186   {
187     CB_EVAL::label* ld = (CB_EVAL::label*) v;
188     CB::default_label(&(ld->event));
189     ld->action = 0;
190   }
191 
delete_label(void * v)192   void delete_label(void* v)
193   {
194     CB_EVAL::label* ld = (CB_EVAL::label*)v;
195     CB::delete_label(&(ld->event));
196   }
197 
copy_label(void * dst,void * src)198   void copy_label(void*dst, void*src)
199   {
200     CB_EVAL::label* ldD = (CB_EVAL::label*)dst;
201     CB_EVAL::label* ldS = (CB_EVAL::label*)src;
202     CB::copy_label(&(ldD->event), &(ldS)->event);
203     ldD->action = ldS->action;
204   }
205 
parse_label(parser * p,shared_data * sd,void * v,v_array<substring> & words)206   void parse_label(parser* p, shared_data* sd, void* v, v_array<substring>& words)
207   {
208     CB_EVAL::label* ld = (CB_EVAL::label*)v;
209 
210     if (words.size() < 2)
211       {
212 	cout << "Evaluation can not happen without an action and an exploration" << endl;
213 	throw exception();
214       }
215 
216     ld->action = (uint32_t)hashstring(words[0], 0);
217 
218     words.begin++;
219 
220     CB::parse_label(p, sd, &(ld->event), words);
221 
222     words.begin--;
223   }
224 
225   label_parser cb_eval = {default_label, parse_label,
226 			  cache_label, read_cached_label,
227 			  delete_label, CB::weight,
228 			  copy_label,
229 			  sizeof(CB_EVAL::label)};
230 }
231