1 /* 2 Copyright (c) by respective owners including Yahoo!, Microsoft, and 3 individual contributors. All rights reserved. Released under a BSD (revised) 4 license as described in the file LICENSE. 5 */ 6 #include <float.h> 7 8 #include "example.h" 9 #include "parse_primitives.h" 10 #include "vw.h" 11 12 using namespace LEARNER; 13 14 namespace CB 15 { bufread_label(CB::label * ld,char * c,io_buf & cache)16 char* bufread_label(CB::label* ld, char* c, io_buf& cache) 17 { 18 size_t num = *(size_t *)c; 19 ld->costs.erase(); 20 c += sizeof(size_t); 21 size_t total = sizeof(cb_class)*num; 22 if (buf_read(cache, c, total) < total) 23 { 24 cout << "error in demarshal of cost data" << endl; 25 return c; 26 } 27 for (size_t i = 0; i<num; i++) 28 { 29 cb_class temp = *(cb_class *)c; 30 c += sizeof(cb_class); 31 ld->costs.push_back(temp); 32 } 33 34 return c; 35 } 36 read_cached_label(shared_data *,void * v,io_buf & cache)37 size_t read_cached_label(shared_data*, void* v, io_buf& cache) 38 { 39 CB::label* ld = (CB::label*) v; 40 ld->costs.erase(); 41 char *c; 42 size_t total = sizeof(size_t); 43 if (buf_read(cache, c, total) < total) 44 return 0; 45 c = bufread_label(ld,c, cache); 46 47 return total; 48 } 49 weight(void * v)50 float weight(void* v) 51 { 52 return 1.; 53 } 54 bufcache_label(CB::label * ld,char * c)55 char* bufcache_label(CB::label* ld, char* c) 56 { 57 *(size_t *)c = ld->costs.size(); 58 c += sizeof(size_t); 59 for (size_t i = 0; i< ld->costs.size(); i++) 60 { 61 *(cb_class *)c = ld->costs[i]; 62 c += sizeof(cb_class); 63 } 64 return c; 65 } 66 cache_label(void * v,io_buf & cache)67 void cache_label(void* v, io_buf& cache) 68 { 69 char *c; 70 CB::label* ld = (CB::label*) v; 71 buf_write(cache, c, sizeof(size_t)+sizeof(cb_class)*ld->costs.size()); 72 bufcache_label(ld,c); 73 } 74 default_label(void * v)75 void default_label(void* v) 76 { 77 CB::label* ld = (CB::label*) v; 78 ld->costs.erase(); 79 } 80 delete_label(void * v)81 void delete_label(void* v) 82 { 83 CB::label* ld = (CB::label*)v; 84 ld->costs.delete_v(); 85 } 86 copy_label(void * dst,void * src)87 void copy_label(void*dst, void*src) 88 { 89 CB::label* ldD = (CB::label*)dst; 90 CB::label* ldS = (CB::label*)src; 91 copy_array(ldD->costs, ldS->costs); 92 } 93 parse_label(parser * p,shared_data * sd,void * v,v_array<substring> & words)94 void parse_label(parser* p, shared_data* sd, void* v, v_array<substring>& words) 95 { 96 CB::label* ld = (CB::label*)v; 97 98 for (size_t i = 0; i < words.size(); i++) 99 { 100 cb_class f; 101 tokenize(':', words[i], p->parse_name); 102 103 if( p->parse_name.size() < 1 || p->parse_name.size() > 3 ) 104 { 105 cerr << "malformed cost specification!" << endl; 106 cerr << "terminating." << endl; 107 throw exception(); 108 } 109 110 f.partial_prediction = 0.; 111 f.action = (uint32_t)hashstring(p->parse_name[0], 0); 112 f.cost = FLT_MAX; 113 if(p->parse_name.size() > 1) 114 f.cost = float_of_substring(p->parse_name[1]); 115 116 if ( nanpattern(f.cost)) 117 { 118 cerr << "error NaN cost for action: "; 119 cerr.write(p->parse_name[0].begin, p->parse_name[0].end - p->parse_name[0].begin); 120 cerr << " terminating." << endl; 121 throw exception(); 122 } 123 124 f.probability = .0; 125 if(p->parse_name.size() > 2) 126 f.probability = float_of_substring(p->parse_name[2]); 127 128 if ( nanpattern(f.probability)) 129 { 130 cerr << "error NaN probability for action: "; 131 cerr.write(p->parse_name[0].begin, p->parse_name[0].end - p->parse_name[0].begin); 132 cerr << " terminating." << endl; 133 throw exception(); 134 } 135 136 if( f.probability > 1.0 ) 137 { 138 cerr << "invalid probability > 1 specified for an action, resetting to 1." << endl; 139 f.probability = 1.0; 140 } 141 if( f.probability < 0.0 ) 142 { 143 cerr << "invalid probability < 0 specified for an action, resetting to 0." << endl; 144 f.probability = .0; 145 } 146 147 ld->costs.push_back(f); 148 } 149 } 150 151 label_parser cb_label = {default_label, parse_label, 152 cache_label, read_cached_label, 153 delete_label, weight, 154 copy_label, 155 sizeof(label)}; 156 157 } 158 159 namespace CB_EVAL 160 { read_cached_label(shared_data * sd,void * v,io_buf & cache)161 size_t read_cached_label(shared_data*sd, void* v, io_buf& cache) 162 { 163 CB_EVAL::label* ld = (CB_EVAL::label*) v; 164 char* c; 165 size_t total = sizeof(uint32_t); 166 if (buf_read(cache, c, total) < total) 167 return 0; 168 ld->action = *(uint32_t*)c; 169 c += sizeof(uint32_t); 170 171 return total + CB::read_cached_label(sd, &(ld->event), cache); 172 } 173 cache_label(void * v,io_buf & cache)174 void cache_label(void* v, io_buf& cache) 175 { 176 char *c; 177 CB_EVAL::label* ld = (CB_EVAL::label*) v; 178 buf_write(cache, c, sizeof(uint32_t)); 179 *(uint32_t *)c = ld->action; 180 c+= sizeof(uint32_t); 181 182 CB::cache_label(&(ld->event), cache); 183 } 184 default_label(void * v)185 void default_label(void* v) 186 { 187 CB_EVAL::label* ld = (CB_EVAL::label*) v; 188 CB::default_label(&(ld->event)); 189 ld->action = 0; 190 } 191 delete_label(void * v)192 void delete_label(void* v) 193 { 194 CB_EVAL::label* ld = (CB_EVAL::label*)v; 195 CB::delete_label(&(ld->event)); 196 } 197 copy_label(void * dst,void * src)198 void copy_label(void*dst, void*src) 199 { 200 CB_EVAL::label* ldD = (CB_EVAL::label*)dst; 201 CB_EVAL::label* ldS = (CB_EVAL::label*)src; 202 CB::copy_label(&(ldD->event), &(ldS)->event); 203 ldD->action = ldS->action; 204 } 205 parse_label(parser * p,shared_data * sd,void * v,v_array<substring> & words)206 void parse_label(parser* p, shared_data* sd, void* v, v_array<substring>& words) 207 { 208 CB_EVAL::label* ld = (CB_EVAL::label*)v; 209 210 if (words.size() < 2) 211 { 212 cout << "Evaluation can not happen without an action and an exploration" << endl; 213 throw exception(); 214 } 215 216 ld->action = (uint32_t)hashstring(words[0], 0); 217 218 words.begin++; 219 220 CB::parse_label(p, sd, &(ld->event), words); 221 222 words.begin--; 223 } 224 225 label_parser cb_eval = {default_label, parse_label, 226 cache_label, read_cached_label, 227 delete_label, CB::weight, 228 copy_label, 229 sizeof(CB_EVAL::label)}; 230 } 231