#include "../vowpalwabbit/vw.h"
#include "../vowpalwabbit/multiclass.h"
#include "../vowpalwabbit/cost_sensitive.h"
#include "../vowpalwabbit/cb.h"
#include "../vowpalwabbit/search.h"
#include "../vowpalwabbit/search_hooktask.h"
#include "../vowpalwabbit/parse_example.h"
#include "../vowpalwabbit/gd.h"

#include <boost/make_shared.hpp>
#include <boost/python.hpp>
#include <boost/python/suite/indexing/vector_indexing_suite.hpp>

#include <cstring>
#include <stdexcept>

using namespace std;
namespace py=boost::python;
16
17 typedef boost::shared_ptr<vw> vw_ptr;
18 typedef boost::shared_ptr<example> example_ptr;
19 typedef boost::shared_ptr<Search::search> search_ptr;
20 typedef boost::shared_ptr<Search::predictor> predictor_ptr;
21
// Label-type selectors exposed to Python (vw.lDefault, vw.lBinary, ...).
// The constructors in this file also hide the selector in an example's tag so
// that my_delete_example can pick the matching label parser.
const size_t lDEFAULT           = 0;  // whatever label type vw was initialized with
const size_t lBINARY            = 1;  // simple_label
const size_t lMULTICLASS        = 2;  // MULTICLASS::mc_label
const size_t lCOST_SENSITIVE    = 3;  // COST_SENSITIVE::cs_label
const size_t lCONTEXTUAL_BANDIT = 4;  // CB::cb_label
const size_t lMAX               = lCONTEXTUAL_BANDIT + 1;  // one past the last valid selector
28
29
// No-op deleter for boost::shared_ptr holders whose pointee must not be freed
// when the Python wrapper goes away; the pointer is deliberately ignored.
void dont_delete_me(void* /*unused*/) {
  return;
}
31
my_initialize(string args)32 vw_ptr my_initialize(string args) {
33 vw*foo = VW::initialize(args);
34 return boost::shared_ptr<vw>(foo, dont_delete_me);
35 }
36
my_finish(vw_ptr all)37 void my_finish(vw_ptr all) {
38 VW::finish(*all, false); // don't delete all because python will do that for us!
39 }
40
get_search_ptr(vw_ptr all)41 search_ptr get_search_ptr(vw_ptr all) {
42 return boost::shared_ptr<Search::search>((Search::search*)(all->searchstr), dont_delete_me);
43 }
44
my_audit_example(vw_ptr all,example_ptr ec)45 void my_audit_example(vw_ptr all, example_ptr ec) { GD::print_audit_features(*all, *ec); }
46
get_predictor(search_ptr sch,ptag my_tag)47 predictor_ptr get_predictor(search_ptr sch, ptag my_tag) {
48 Search::predictor* P = new Search::predictor(*sch, my_tag);
49 return boost::shared_ptr<Search::predictor>(P);
50 }
51
get_label_parser(vw * all,size_t labelType)52 label_parser* get_label_parser(vw*all, size_t labelType) {
53 switch (labelType) {
54 case lDEFAULT: return all ? &all->p->lp : NULL;
55 case lBINARY: return &simple_label;
56 case lMULTICLASS: return &MULTICLASS::mc_label;
57 case lCOST_SENSITIVE: return &COST_SENSITIVE::cs_label;
58 case lCONTEXTUAL_BANDIT: return &CB::cb_label;
59 default: cerr << "get_label_parser called on invalid label type" << endl; throw exception();
60 }
61 }
62
my_delete_example(void * voidec)63 void my_delete_example(void*voidec) {
64 example* ec = (example*) voidec;
65 size_t labelType = (ec->tag.size() == 0) ? lDEFAULT : ec->tag[0];
66 label_parser* lp = get_label_parser(NULL, labelType);
67 dealloc_example(lp ? lp->delete_label : NULL, *ec);
68 free(ec);
69 }
70
71
my_empty_example0(vw_ptr vw,size_t labelType)72 example* my_empty_example0(vw_ptr vw, size_t labelType) {
73 label_parser* lp = get_label_parser(&*vw, labelType);
74 example* ec = alloc_examples(lp->label_size, 1);
75 lp->default_label(&ec->l);
76 if (labelType == lCOST_SENSITIVE) {
77 COST_SENSITIVE::wclass zero = { 0., 1, 0., 0. };
78 ec->l.cs.costs.push_back(zero);
79 }
80 ec->tag.erase();
81 if (labelType != lDEFAULT)
82 ec->tag.push_back((char)labelType); // hide the label type in the tag
83 return ec;
84 }
85
my_empty_example(vw_ptr vw,size_t labelType)86 example_ptr my_empty_example(vw_ptr vw, size_t labelType) {
87 example* ec = my_empty_example0(vw, labelType);
88 return boost::shared_ptr<example>(ec, my_delete_example);
89 }
90
my_read_example(vw_ptr all,size_t labelType,char * str)91 example_ptr my_read_example(vw_ptr all, size_t labelType, char*str) {
92 example*ec = my_empty_example0(all, labelType);
93 read_line(*all, ec, str);
94 parse_atomic_example(*all, ec, false);
95 VW::setup_example(*all, ec);
96 ec->example_counter = labelType;
97 ec->tag.erase();
98 if (labelType != lDEFAULT)
99 ec->tag.push_back((char)labelType); // hide the label type in the tag
100 return boost::shared_ptr<example>(ec, my_delete_example);
101 }
102
my_finish_example(vw_ptr all,example_ptr ec)103 void my_finish_example(vw_ptr all, example_ptr ec) {
104 // TODO
105 }
106
my_learn(vw_ptr all,example_ptr ec)107 void my_learn(vw_ptr all, example_ptr ec) {
108 all->learn(ec.get());
109 }
110
my_learn_string(vw_ptr all,char * str)111 float my_learn_string(vw_ptr all, char*str) {
112 example*ec = VW::read_example(*all, str);
113 all->learn(ec);
114 float pp = ec->partial_prediction;
115 VW::finish_example(*all, ec);
116 return pp;
117 }
118
varray_char_to_string(v_array<char> & a)119 string varray_char_to_string(v_array<char> &a) {
120 string ret = "";
121 for (char*c = a.begin; c != a.end; ++c)
122 ret += *c;
123 return ret;
124 }
125
my_get_tag(example_ptr ec)126 string my_get_tag(example_ptr ec) {
127 return varray_char_to_string(ec->tag);
128 }
129
ex_num_namespaces(example_ptr ec)130 uint32_t ex_num_namespaces(example_ptr ec) {
131 return ec->indices.size();
132 }
133
ex_namespace(example_ptr ec,uint32_t ns)134 unsigned char ex_namespace(example_ptr ec, uint32_t ns) {
135 return ec->indices[ns];
136 }
137
ex_num_features(example_ptr ec,unsigned char ns)138 uint32_t ex_num_features(example_ptr ec, unsigned char ns) {
139 return ec->atomics[ns].size();
140 }
141
ex_feature(example_ptr ec,unsigned char ns,uint32_t i)142 uint32_t ex_feature(example_ptr ec, unsigned char ns, uint32_t i) {
143 return ec->atomics[ns][i].weight_index;
144 }
145
ex_feature_weight(example_ptr ec,unsigned char ns,uint32_t i)146 float ex_feature_weight(example_ptr ec, unsigned char ns, uint32_t i) {
147 return ec->atomics[ns][i].x;
148 }
149
ex_sum_feat_sq(example_ptr ec,unsigned char ns)150 float ex_sum_feat_sq(example_ptr ec, unsigned char ns) {
151 return ec->sum_feat_sq[ns];
152 }
153
ex_push_feature(example_ptr ec,unsigned char ns,uint32_t fid,float v)154 void ex_push_feature(example_ptr ec, unsigned char ns, uint32_t fid, float v) {
155 // warning: assumes namespace exists!
156 feature f = { v, fid };
157 ec->atomics[ns].push_back(f);
158 ec->num_features++;
159 ec->sum_feat_sq[ns] += v * v;
160 ec->total_sum_feat_sq += v * v;
161 }
162
ex_push_feature_list(example_ptr ec,vw_ptr vw,unsigned char ns,py::list & a)163 void ex_push_feature_list(example_ptr ec, vw_ptr vw, unsigned char ns, py::list& a) {
164 // warning: assumes namespace exists!
165 char ns_str[2] = { (char)ns, 0 };
166 uint32_t ns_hash = VW::hash_space(*vw, ns_str);
167 size_t count = 0; float sum_sq = 0.;
168 for (size_t i=0; i<len(a); i++) {
169 feature f = { 1., 0 };
170 py::object ai = a[i];
171 py::extract<py::tuple> get_tup(ai);
172 if (get_tup.check()) {
173 py::tuple fv = get_tup();
174 if (len(fv) != 2) { cerr << "warning: malformed feature in list" << endl; continue; } // TODO str(ai)
175 py::extract<float> get_val(fv[1]);
176 if (get_val.check())
177 f.x = get_val();
178 else { cerr << "warning: malformed feature in list" << endl; continue; }
179 ai = fv[0];
180 }
181
182 bool got = false;
183 py::extract<uint32_t> get_int(ai);
184 if (get_int.check()) { f.weight_index = get_int(); got = true; }
185 else {
186 py::extract<string> get_str(ai);
187 if (get_str.check()) {
188 f.weight_index = VW::hash_feature(*vw, get_str(), ns_hash);
189 got = true;
190 } else { cerr << "warning: malformed feature in list" << endl; continue; }
191 }
192 if (got && (f.x != 0.)) {
193 ec->atomics[ns].push_back(f);
194 count++;
195 sum_sq += f.x * f.x;
196 }
197 }
198 ec->num_features += count;
199 ec->sum_feat_sq[ns] += sum_sq;
200 ec->total_sum_feat_sq += sum_sq;
201 }
202
ex_pop_feature(example_ptr ec,unsigned char ns)203 bool ex_pop_feature(example_ptr ec, unsigned char ns) {
204 if (ec->atomics[ns].size() == 0) return false;
205 feature f = ec->atomics[ns].pop();
206 ec->num_features--;
207 ec->sum_feat_sq[ns] -= f.x * f.x;
208 ec->total_sum_feat_sq -= f.x * f.x;
209 return true;
210 }
211
ex_push_namespace(example_ptr ec,unsigned char ns)212 void ex_push_namespace(example_ptr ec, unsigned char ns) {
213 ec->indices.push_back(ns);
214 }
215
ex_ensure_namespace_exists(example_ptr ec,unsigned char ns)216 void ex_ensure_namespace_exists(example_ptr ec, unsigned char ns) {
217 for (unsigned char* nss = ec->indices.begin; nss != ec->indices.end; ++nss)
218 if (ns == *nss) return;
219 ex_push_namespace(ec, ns);
220 }
221
ex_pop_namespace(example_ptr ec)222 bool ex_pop_namespace(example_ptr ec) {
223 if (ec->indices.size() == 0) return false;
224 unsigned char ns = ec->indices.pop();
225 ec->num_features -= ec->atomics[ns].size();
226 ec->total_sum_feat_sq -= ec->sum_feat_sq[ns];
227 ec->sum_feat_sq[ns] = 0.;
228 ec->atomics[ns].erase();
229 return true;
230 }
231
my_setup_example(vw_ptr vw,example_ptr ec)232 void my_setup_example(vw_ptr vw, example_ptr ec) {
233 VW::setup_example(*vw, ec.get());
234 }
235
ex_set_label_string(example_ptr ec,vw_ptr vw,string label,size_t labelType)236 void ex_set_label_string(example_ptr ec, vw_ptr vw, string label, size_t labelType) {
237 // SPEEDUP: if it's already set properly, don't modify
238 label_parser& old_lp = vw->p->lp;
239 vw->p->lp = *get_label_parser(&*vw, labelType);
240 VW::parse_example_label(*vw, *ec, label);
241 vw->p->lp = old_lp;
242 }
243
ex_get_simplelabel_label(example_ptr ec)244 float ex_get_simplelabel_label(example_ptr ec) { return ec->l.simple.label; }
ex_get_simplelabel_weight(example_ptr ec)245 float ex_get_simplelabel_weight(example_ptr ec) { return ec->l.simple.weight; }
ex_get_simplelabel_initial(example_ptr ec)246 float ex_get_simplelabel_initial(example_ptr ec) { return ec->l.simple.initial; }
ex_get_simplelabel_prediction(example_ptr ec)247 float ex_get_simplelabel_prediction(example_ptr ec) { return ec->pred.scalar; }
248
ex_get_multiclass_label(example_ptr ec)249 uint32_t ex_get_multiclass_label(example_ptr ec) { return ec->l.multi.label; }
ex_get_multiclass_weight(example_ptr ec)250 float ex_get_multiclass_weight(example_ptr ec) { return ec->l.multi.weight; }
ex_get_multiclass_prediction(example_ptr ec)251 uint32_t ex_get_multiclass_prediction(example_ptr ec) { return ec->pred.multiclass; }
252
ex_get_costsensitive_prediction(example_ptr ec)253 uint32_t ex_get_costsensitive_prediction(example_ptr ec) { return ec->pred.multiclass; }
ex_get_costsensitive_num_costs(example_ptr ec)254 uint32_t ex_get_costsensitive_num_costs(example_ptr ec) { return ec->l.cs.costs.size(); }
ex_get_costsensitive_cost(example_ptr ec,uint32_t i)255 float ex_get_costsensitive_cost(example_ptr ec, uint32_t i) { return ec->l.cs.costs[i].x; }
ex_get_costsensitive_class(example_ptr ec,uint32_t i)256 uint32_t ex_get_costsensitive_class(example_ptr ec, uint32_t i) { return ec->l.cs.costs[i].class_index; }
ex_get_costsensitive_partial_prediction(example_ptr ec,uint32_t i)257 float ex_get_costsensitive_partial_prediction(example_ptr ec, uint32_t i) { return ec->l.cs.costs[i].partial_prediction; }
ex_get_costsensitive_wap_value(example_ptr ec,uint32_t i)258 float ex_get_costsensitive_wap_value(example_ptr ec, uint32_t i) { return ec->l.cs.costs[i].wap_value; }
259
ex_get_cbandits_prediction(example_ptr ec)260 uint32_t ex_get_cbandits_prediction(example_ptr ec) { return ec->pred.multiclass; }
ex_get_cbandits_num_costs(example_ptr ec)261 uint32_t ex_get_cbandits_num_costs(example_ptr ec) { return ec->l.cb.costs.size(); }
ex_get_cbandits_cost(example_ptr ec,uint32_t i)262 float ex_get_cbandits_cost(example_ptr ec, uint32_t i) { return ec->l.cb.costs[i].cost; }
ex_get_cbandits_class(example_ptr ec,uint32_t i)263 uint32_t ex_get_cbandits_class(example_ptr ec, uint32_t i) { return ec->l.cb.costs[i].action; }
ex_get_cbandits_probability(example_ptr ec,uint32_t i)264 float ex_get_cbandits_probability(example_ptr ec, uint32_t i) { return ec->l.cb.costs[i].probability; }
ex_get_cbandits_partial_prediction(example_ptr ec,uint32_t i)265 float ex_get_cbandits_partial_prediction(example_ptr ec, uint32_t i) { return ec->l.cb.costs[i].partial_prediction; }
266
get_example_counter(example_ptr ec)267 size_t get_example_counter(example_ptr ec) { return ec->example_counter; }
get_ft_offset(example_ptr ec)268 uint32_t get_ft_offset(example_ptr ec) { return ec->ft_offset; }
get_num_features(example_ptr ec)269 size_t get_num_features(example_ptr ec) { return ec->num_features; }
get_partial_prediction(example_ptr ec)270 float get_partial_prediction(example_ptr ec) { return ec->partial_prediction; }
get_updated_prediction(example_ptr ec)271 float get_updated_prediction(example_ptr ec) { return ec->updated_prediction; }
get_loss(example_ptr ec)272 float get_loss(example_ptr ec) { return ec->loss; }
get_example_t(example_ptr ec)273 float get_example_t(example_ptr ec) { return ec->example_t; }
get_total_sum_feat_sq(example_ptr ec)274 float get_total_sum_feat_sq(example_ptr ec) { return ec->total_sum_feat_sq; }
275
get_sum_loss(vw_ptr vw)276 double get_sum_loss(vw_ptr vw) { return vw->sd->sum_loss; }
get_weighted_examples(vw_ptr vw)277 double get_weighted_examples(vw_ptr vw) { return vw->sd->weighted_examples; }
278
search_should_output(search_ptr sch)279 bool search_should_output(search_ptr sch) { return sch->output().good(); }
search_output(search_ptr sch,string s)280 void search_output(search_ptr sch, string s) { sch->output() << s; }
281
282 /*
283 uint32_t search_predict_one_all(search_ptr sch, example_ptr ec, uint32_t one_ystar) {
284 return sch->predict(ec.get(), one_ystar, NULL);
285 }
286
287 uint32_t search_predict_one_some(search_ptr sch, example_ptr ec, uint32_t one_ystar, vector<uint32_t>& yallowed) {
288 v_array<uint32_t> yallowed_va;
289 yallowed_va.begin = yallowed.data();
290 yallowed_va.end = yallowed_va.begin + yallowed.size();
291 yallowed_va.end_array = yallowed_va.end;
292 yallowed_va.erase_count = 0;
293 return sch->predict(ec.get(), one_ystar, &yallowed_va);
294 }
295
296 uint32_t search_predict_many_all(search_ptr sch, example_ptr ec, vector<uint32_t>& ystar) {
297 v_array<uint32_t> ystar_va;
298 ystar_va.begin = ystar.data();
299 ystar_va.end = ystar_va.begin + ystar.size();
300 ystar_va.end_array = ystar_va.end;
301 ystar_va.erase_count = 0;
302 return sch->predict(ec.get(), &ystar_va, NULL);
303 }
304
305 uint32_t search_predict_many_some(search_ptr sch, example_ptr ec, vector<uint32_t>& ystar, vector<uint32_t>& yallowed) {
306 v_array<uint32_t> ystar_va;
307 ystar_va.begin = ystar.data();
308 ystar_va.end = ystar_va.begin + ystar.size();
309 ystar_va.end_array = ystar_va.end;
310 ystar_va.erase_count = 0;
311 v_array<uint32_t> yallowed_va;
312 yallowed_va.begin = yallowed.data();
313 yallowed_va.end = yallowed_va.begin + yallowed.size();
314 yallowed_va.end_array = yallowed_va.end;
315 yallowed_va.erase_count = 0;
316 return sch->predict(ec.get(), &ystar_va, &yallowed_va);
317 }
318 */
319
verify_search_set_properly(search_ptr sch)320 void verify_search_set_properly(search_ptr sch) {
321 if (sch->task_name == NULL) {
322 cerr << "set_structured_predict_hook: search task not initialized properly" << endl;
323 throw exception();
324 }
325 if (strcmp(sch->task_name, "hook") != 0) {
326 cerr << "set_structured_predict_hook: trying to set hook when search task is not 'hook'!" << endl;
327 throw exception();
328 }
329 }
330
search_get_num_actions(search_ptr sch)331 uint32_t search_get_num_actions(search_ptr sch) {
332 verify_search_set_properly(sch);
333 HookTask::task_data* d = sch->get_task_data<HookTask::task_data>();
334 return d->num_actions;
335 }
336
search_run_fn(Search::search & sch)337 void search_run_fn(Search::search&sch) {
338 try {
339 HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
340 py::object run = *(py::object*)d->run_object;
341 run.attr("__call__")();
342 } catch(...) {
343 PyErr_Print();
344 PyErr_Clear();
345 throw exception();
346 }
347 }
348
search_setup_fn(Search::search & sch)349 void search_setup_fn(Search::search&sch) {
350 try {
351 HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
352 py::object run = *(py::object*)d->setup_object;
353 run.attr("__call__")();
354 } catch(...) {
355 PyErr_Print();
356 PyErr_Clear();
357 throw exception();
358 }
359 }
360
search_takedown_fn(Search::search & sch)361 void search_takedown_fn(Search::search&sch) {
362 try {
363 HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
364 py::object run = *(py::object*)d->takedown_object;
365 run.attr("__call__")();
366 } catch(...) {
367 PyErr_Print();
368 PyErr_Clear();
369 throw exception();
370 }
371 }
372
py_delete_run_object(void * pyobj)373 void py_delete_run_object(void* pyobj) {
374 py::object* o = (py::object*)pyobj;
375 delete o;
376 }
377
set_structured_predict_hook(search_ptr sch,py::object run_object,py::object setup_object,py::object takedown_object)378 void set_structured_predict_hook(search_ptr sch, py::object run_object, py::object setup_object, py::object takedown_object) {
379 verify_search_set_properly(sch);
380 HookTask::task_data* d = sch->get_task_data<HookTask::task_data>();
381 d->run_f = &search_run_fn;
382 delete (py::object*)d->run_object; d->run_object = NULL;
383 delete (py::object*)d->setup_object; d->setup_object = NULL;
384 delete (py::object*)d->takedown_object; d->takedown_object = NULL;
385 d->run_object = new py::object(run_object);
386 if (setup_object.ptr() != Py_None) {
387 d->setup_object = new py::object(setup_object);
388 d->run_setup_f = &search_setup_fn;
389 }
390 if (takedown_object.ptr() != Py_None) {
391 d->takedown_object = new py::object(takedown_object);
392 d->run_takedown_f = &search_takedown_fn;
393 }
394 d->delete_run_object = &py_delete_run_object;
395 }
396
my_set_test_only(example_ptr ec,bool val)397 void my_set_test_only(example_ptr ec, bool val) { ec->test_only = val; }
398
po_exists(search_ptr sch,string arg)399 bool po_exists(search_ptr sch, string arg) {
400 HookTask::task_data* d = sch->get_task_data<HookTask::task_data>();
401 return (*d->var_map).count(arg) > 0;
402 }
403
po_get_string(search_ptr sch,string arg)404 string po_get_string(search_ptr sch, string arg) {
405 HookTask::task_data* d = sch->get_task_data<HookTask::task_data>();
406 return (*d->var_map)[arg].as<string>();
407 }
408
po_get_int(search_ptr sch,string arg)409 int32_t po_get_int(search_ptr sch, string arg) {
410 HookTask::task_data* d = sch->get_task_data<HookTask::task_data>();
411 try { return (*d->var_map)[arg].as<int>(); } catch (...) {}
412 try { return (*d->var_map)[arg].as<size_t>(); } catch (...) {}
413 try { return (*d->var_map)[arg].as<uint32_t>(); } catch (...) {}
414 try { return (*d->var_map)[arg].as<uint64_t>(); } catch (...) {}
415 try { return (*d->var_map)[arg].as<uint16_t>(); } catch (...) {}
416 try { return (*d->var_map)[arg].as<int32_t>(); } catch (...) {}
417 try { return (*d->var_map)[arg].as<int64_t>(); } catch (...) {}
418 try { return (*d->var_map)[arg].as<int16_t>(); } catch (...) {}
419 // we know this'll fail but do it anyway to get the exception
420 return (*d->var_map)[arg].as<int>();
421 }
422
po_get(search_ptr sch,string arg)423 PyObject* po_get(search_ptr sch, string arg) {
424 try {
425 return py::incref(py::object(po_get_string(sch, arg)).ptr());
426 } catch (...) {}
427 try {
428 return py::incref(py::object(po_get_int(sch, arg)).ptr());
429 } catch (...) {}
430 // return None
431 return py::incref(py::object().ptr());
432 }
433
my_set_input(predictor_ptr P,example_ptr ec)434 void my_set_input(predictor_ptr P, example_ptr ec) { P->set_input(*ec); }
my_set_input_at(predictor_ptr P,size_t posn,example_ptr ec)435 void my_set_input_at(predictor_ptr P, size_t posn, example_ptr ec) { P->set_input_at(posn, *ec); }
436
my_add_oracle(predictor_ptr P,action a)437 void my_add_oracle(predictor_ptr P, action a) { P->add_oracle(a); }
my_add_oracles(predictor_ptr P,py::list & a)438 void my_add_oracles(predictor_ptr P, py::list& a) { for (size_t i=0; i<len(a); i++) P->add_oracle(py::extract<action>(a[i])); }
my_add_allowed(predictor_ptr P,action a)439 void my_add_allowed(predictor_ptr P, action a) { P->add_allowed(a); }
my_add_alloweds(predictor_ptr P,py::list & a)440 void my_add_alloweds(predictor_ptr P, py::list& a) { for (size_t i=0; i<len(a); i++) P->add_allowed(py::extract<action>(a[i])); }
my_add_condition(predictor_ptr P,ptag t,char c)441 void my_add_condition(predictor_ptr P, ptag t, char c) { P->add_condition(t, c); }
my_add_condition_range(predictor_ptr P,ptag hi,ptag count,char name0)442 void my_add_condition_range(predictor_ptr P, ptag hi, ptag count, char name0) { P->add_condition_range(hi, count, name0); }
my_set_oracle(predictor_ptr P,action a)443 void my_set_oracle(predictor_ptr P, action a) { P->set_oracle(a); }
my_set_oracles(predictor_ptr P,py::list & a)444 void my_set_oracles(predictor_ptr P, py::list& a) { if (len(a) > 0) P->set_oracle(py::extract<action>(a[0])); else P->erase_oracles(); for (size_t i=1; i<len(a); i++) P->add_oracle(py::extract<action>(a[i])); }
my_set_allowed(predictor_ptr P,action a)445 void my_set_allowed(predictor_ptr P, action a) { P->set_allowed(a); }
my_set_alloweds(predictor_ptr P,py::list & a)446 void my_set_alloweds(predictor_ptr P, py::list& a) { if (len(a) > 0) P->set_allowed(py::extract<action>(a[0])); else P->erase_alloweds(); for (size_t i=1; i<len(a); i++) P->add_allowed(py::extract<action>(a[i])); }
my_set_condition(predictor_ptr P,ptag t,char c)447 void my_set_condition(predictor_ptr P, ptag t, char c) { P->set_condition(t, c); }
my_set_condition_range(predictor_ptr P,ptag hi,ptag count,char name0)448 void my_set_condition_range(predictor_ptr P, ptag hi, ptag count, char name0) { P->set_condition_range(hi, count, name0); }
my_set_learner_id(predictor_ptr P,size_t id)449 void my_set_learner_id(predictor_ptr P, size_t id) { P->set_learner_id(id); }
my_set_tag(predictor_ptr P,ptag t)450 void my_set_tag(predictor_ptr P, ptag t) { P->set_tag(t); }
451
452
// Python extension module `pylibvw`: exposes the vw, example, predictor and
// search classes with the wrapper functions defined above.
// Docstring fixes relative to the previous version: "holds with weight
// vector" -> "holds the weight vector"; the example string constructor no
// longer claims a multiclass default (a label type argument is required);
// "The the" -> "The"; "coutput" -> "output"; closed an unbalanced parenthesis.
BOOST_PYTHON_MODULE(pylibvw) {
  // This will enable user-defined docstrings and python signatures,
  // while disabling the C++ signatures
  py::docstring_options local_docstring_options(true, true, false);

  // define the vw class
  py::class_<vw, vw_ptr>("vw", "the basic VW object that holds the weight vector, parser, etc.", py::no_init)
      .def("__init__", py::make_constructor(my_initialize))
      // .def("__del__", &my_finish, "deconstruct the VW object by calling finish")
      .def("finish", &my_finish, "stop VW by calling finish (and, eg, write weights to disk)")
      .def("learn", &my_learn, "given a pyvw example, learn (and predict) on that example")
      .def("learn_string", &my_learn_string, "given an example specified as a string (as in a VW data file), learn on that example")
      .def("hash_space", &VW::hash_space, "given a namespace (as a string), compute the hash of that namespace")
      .def("hash_feature", &VW::hash_feature, "given a feature string (arg2) and a hashed namespace (arg3), hash that feature")
      .def("finish_example", &my_finish_example, "tell VW that you're done with a given example")
      .def("setup_example", &my_setup_example, "given an example that you've created by hand, prepare it for learning (eg, compute quadratic feature)")

      .def("num_weights", &VW::num_weights, "how many weights are we learning?")
      .def("get_weight", &VW::get_weight, "get the weight for a particular index")
      .def("set_weight", &VW::set_weight, "set the weight for a particular index")
      .def("get_stride", &VW::get_stride, "return the internal stride")

      .def("get_sum_loss", &get_sum_loss, "return the total cumulative loss suffered so far")
      .def("get_weighted_examples", &get_weighted_examples, "return the total weight of examples so far")

      .def("get_search_ptr", &get_search_ptr, "return a pointer to the search data structure")
      .def("audit_example", &my_audit_example, "print example audit information")

      .def_readonly("lDefault", lDEFAULT, "Default label type (whatever vw was initialized with) -- used as input to the example() initializer")
      .def_readonly("lBinary", lBINARY, "Binary label type -- used as input to the example() initializer")
      .def_readonly("lMulticlass", lMULTICLASS, "Multiclass label type -- used as input to the example() initializer")
      .def_readonly("lCostSensitive", lCOST_SENSITIVE, "Cost sensitive label type (for LDF!) -- used as input to the example() initializer")
      .def_readonly("lContextualBandit", lCONTEXTUAL_BANDIT, "Contextual bandit label type -- used as input to the example() initializer")
      ;

  // define the example class
  py::class_<example, example_ptr>("example", py::no_init)
      .def("__init__", py::make_constructor(my_read_example), "Given a label type and a string, parse the string into a VW example (and run setup on it)")
      .def("__init__", py::make_constructor(my_empty_example), "Construct an empty (non setup) example; you must provide a label type (vw.lBinary, vw.lMulticlass, etc.)")

      .def("set_test_only", &my_set_test_only, "Change the test-only bit on an example")

      .def("get_tag", &my_get_tag, "Returns the tag associated with this example")
      .def("get_topic_prediction", &VW::get_topic_prediction, "For LDA models, returns the topic prediction for the topic id given")
      .def("get_feature_number", &VW::get_feature_number, "Returns the total number of features for this example")

      .def("get_example_counter", &get_example_counter, "Returns the counter of total number of examples seen up to and including this one")
      .def("get_ft_offset", &get_ft_offset, "Returns the feature offset for this example (used, eg, by multiclass classification to bulk offset all features)")
      .def("get_partial_prediction", &get_partial_prediction, "Returns the partial prediction associated with this example")
      .def("get_updated_prediction", &get_updated_prediction, "Returns the partial prediction as if we had updated it after learning")
      .def("get_loss", &get_loss, "Returns the loss associated with this example")
      .def("get_example_t", &get_example_t, "The total sum of importance weights up to and including this example")
      .def("get_total_sum_feat_sq", &get_total_sum_feat_sq, "The total sum of feature-value squared for this example")

      .def("num_namespaces", &ex_num_namespaces, "The total number of namespaces associated with this example")
      .def("namespace", &ex_namespace, "Get the namespace id for namespace i (for i = 0.. num_namespaces); specifically returns the ord() of the corresponding character id")
      .def("sum_feat_sq", &ex_sum_feat_sq, "Get the sum of feature-values squared for a given namespace id (id=character-ord)")
      .def("num_features_in", &ex_num_features, "Get the number of features in a given namespace id (id=character-ord)")
      .def("feature", &ex_feature, "Get the feature id for the ith feature in a given namespace id (id=character-ord)")
      .def("feature_weight", &ex_feature_weight, "The feature value (weight) per .feature(...)")

      .def("push_hashed_feature", &ex_push_feature, "Add a hashed feature to a given namespace (id=character-ord)")
      .def("push_feature_list", &ex_push_feature_list, "Add a (Python) list of features to a given namespace")
      .def("pop_feature", &ex_pop_feature, "Remove the top feature from a given namespace; returns True iff the list was non-empty")
      .def("push_namespace", &ex_push_namespace, "Add a new namespace")
      .def("ensure_namespace_exists", &ex_ensure_namespace_exists, "Add a new namespace if it doesn't already exist")
      .def("pop_namespace", &ex_pop_namespace, "Remove the top namespace off; returns True iff the list was non-empty")

      .def("set_label_string", &ex_set_label_string, "(Re)assign the label of this example to this string")

      .def("get_simplelabel_label", &ex_get_simplelabel_label, "Assuming a simple_label label type, return the corresponding label (class/regression target/etc.)")
      .def("get_simplelabel_weight", &ex_get_simplelabel_weight, "Assuming a simple_label label type, return the importance weight")
      .def("get_simplelabel_initial", &ex_get_simplelabel_initial, "Assuming a simple_label label type, return the initial (baseline) prediction")
      .def("get_simplelabel_prediction", &ex_get_simplelabel_prediction, "Assuming a simple_label label type, return the final prediction")
      .def("get_multiclass_label", &ex_get_multiclass_label, "Assuming a multiclass label type, get the true label")
      .def("get_multiclass_weight", &ex_get_multiclass_weight, "Assuming a multiclass label type, get the importance weight")
      .def("get_multiclass_prediction", &ex_get_multiclass_prediction, "Assuming a multiclass label type, get the prediction")
      .def("get_costsensitive_prediction", &ex_get_costsensitive_prediction, "Assuming a cost_sensitive label type, get the prediction")
      .def("get_costsensitive_num_costs", &ex_get_costsensitive_num_costs, "Assuming a cost_sensitive label type, get the total number of label/cost pairs")
      .def("get_costsensitive_cost", &ex_get_costsensitive_cost, "Assuming a cost_sensitive label type, get the cost for a given pair (i=0.. get_costsensitive_num_costs)")
      .def("get_costsensitive_class", &ex_get_costsensitive_class, "Assuming a cost_sensitive label type, get the label for a given pair (i=0.. get_costsensitive_num_costs)")
      .def("get_costsensitive_partial_prediction", &ex_get_costsensitive_partial_prediction, "Assuming a cost_sensitive label type, get the partial prediction for a given pair (i=0.. get_costsensitive_num_costs)")
      .def("get_costsensitive_wap_value", &ex_get_costsensitive_wap_value, "Assuming a cost_sensitive label type, get the weighted-all-pairs recomputed cost for a given pair (i=0.. get_costsensitive_num_costs)")
      .def("get_cbandits_prediction", &ex_get_cbandits_prediction, "Assuming a contextual_bandits label type, get the prediction")
      .def("get_cbandits_num_costs", &ex_get_cbandits_num_costs, "Assuming a contextual_bandits label type, get the total number of label/cost pairs")
      .def("get_cbandits_cost", &ex_get_cbandits_cost, "Assuming a contextual_bandits label type, get the cost for a given pair (i=0.. get_cbandits_num_costs)")
      .def("get_cbandits_class", &ex_get_cbandits_class, "Assuming a contextual_bandits label type, get the label for a given pair (i=0.. get_cbandits_num_costs)")
      .def("get_cbandits_probability", &ex_get_cbandits_probability, "Assuming a contextual_bandits label type, get the bandits probability for a given pair (i=0.. get_cbandits_num_costs)")
      .def("get_cbandits_partial_prediction", &ex_get_cbandits_partial_prediction, "Assuming a contextual_bandits label type, get the partial prediction for a given pair (i=0.. get_cbandits_num_costs)")
      ;

  py::class_<Search::predictor, predictor_ptr>("predictor", py::no_init)
      .def("set_input", &my_set_input, "set the input (an example) for this predictor (non-LDF mode only)")
      //.def("set_input_ldf", &my_set_input_ldf, "set the inputs (a list of examples) for this predictor (LDF mode only)")
      .def("set_input_length", &Search::predictor::set_input_length, "declare the length of an LDF-sequence of examples")
      .def("set_input_at", &my_set_input_at, "put a given example at position in the LDF sequence (call after set_input_length)")
      .def("add_oracle", &my_add_oracle, "add an action to the current list of oracle actions")
      .def("add_oracles", &my_add_oracles, "add a list of actions to the current list of oracle actions")
      .def("add_allowed", &my_add_allowed, "add an action to the current list of allowed actions")
      .def("add_alloweds", &my_add_alloweds, "add a list of actions to the current list of allowed actions")
      .def("add_condition", &my_add_condition, "add a (tag,char) pair to the list of variables on which to condition")
      .def("add_condition_range", &my_add_condition_range, "given (tag,len,char), add (tag,char), (tag-1,char+1), ..., (tag-len,char+len) to the list of conditionings")
      .def("set_oracle", &my_set_oracle, "set an action as the current list of oracle actions")
      .def("set_oracles", &my_set_oracles, "set a list of actions as the current list of oracle actions")
      .def("set_allowed", &my_set_allowed, "set an action as the current list of allowed actions")
      .def("set_alloweds", &my_set_alloweds, "set a list of actions as the current list of allowed actions")
      .def("set_condition", &my_set_condition, "set a (tag,char) pair as the list of variables on which to condition")
      .def("set_condition_range", &my_set_condition_range, "given (tag,len,char), set (tag,char), (tag-1,char+1), ..., (tag-len,char+len) as the list of conditionings")
      .def("set_learner_id", &my_set_learner_id, "select the learner with which to make this prediction")
      .def("set_tag", &my_set_tag, "change the tag of this prediction")
      .def("predict", &Search::predictor::predict, "make a prediction")
      ;

  py::class_<Search::search, search_ptr>("search")
      .def("set_options", &Search::search::set_options, "Set global search options (auto conditioning, etc.)")
      .def("set_num_learners", &Search::search::set_num_learners, "Set the total number of learners you want to train")
      .def("get_history_length", &Search::search::get_history_length, "Get the value specified by --search_history_length")
      .def("loss", &Search::search::loss, "Declare a (possibly incremental) loss")
      .def("should_output", &search_should_output, "Check whether search wants us to output (only happens if you have -p running)")
      .def("predict_needs_example", &Search::search::predictNeedsExample, "Check whether a subsequent call to predict is actually going to use the example you pass---i.e., can you skip feature computation?")
      .def("output", &search_output, "Add a string to the output (should only do if should_output returns True)")
      .def("get_num_actions", &search_get_num_actions, "Return the total number of actions search was initialized with")
      .def("set_structured_predict_hook", &set_structured_predict_hook, "Set the hook (function pointer) that search should use for structured prediction (you don't want to call this yourself!)")
      .def("is_ldf", &Search::search::is_ldf, "check whether this search task is running in LDF mode")

      .def("po_exists", &po_exists, "For program (cmd line) options, check to see if a given option was specified; eg sch.po_exists(\"search\") should be True")
      .def("po_get", &po_get, "For program (cmd line) options, if an option was specified, get its value; eg sch.po_get(\"search\") should return the # of actions (returns either int or string)")
      .def("po_get_str", &po_get_string, "Same as po_get, but specialized for string return values.")
      .def("po_get_int", &po_get_int, "Same as po_get, but specialized for integer return values.")

      .def("get_predictor", &get_predictor, "Get a predictor object that can be used for making predictions; requires a tag argument to tag the prediction.")

      .def_readonly("AUTO_CONDITION_FEATURES", Search::AUTO_CONDITION_FEATURES, "Tell search to automatically add features based on conditioned-on variables")
      .def_readonly("AUTO_HAMMING_LOSS", Search::AUTO_HAMMING_LOSS, "Tell search to automatically compute hamming loss over predictions")
      .def_readonly("EXAMPLES_DONT_CHANGE", Search::EXAMPLES_DONT_CHANGE, "Tell search that on a single structured 'run', you don't change the examples you pass to predict")
      .def_readonly("IS_LDF", Search::IS_LDF, "Tell search that this is an LDF task")
      ;
}
591