1 #include "../vowpalwabbit/vw.h"
2 #include "../vowpalwabbit/multiclass.h"
3 #include "../vowpalwabbit/cost_sensitive.h"
4 #include "../vowpalwabbit/cb.h"
5 #include "../vowpalwabbit/search.h"
6 #include "../vowpalwabbit/search_hooktask.h"
7 #include "../vowpalwabbit/parse_example.h"
8 #include "../vowpalwabbit/gd.h"
9 
10 #include <boost/make_shared.hpp>
11 #include <boost/python.hpp>
12 #include <boost/python/suite/indexing/vector_indexing_suite.hpp>
13 
14 using namespace std;
15 namespace py=boost::python;
16 
// boost::shared_ptr handle types for the VW objects passed to and returned
// from the wrapper functions in this file.
typedef boost::shared_ptr<vw> vw_ptr;
typedef boost::shared_ptr<example> example_ptr;
typedef boost::shared_ptr<Search::search> search_ptr;
typedef boost::shared_ptr<Search::predictor> predictor_ptr;
21 
// Label-type identifiers understood by get_label_parser(). Non-default values
// are also stored as the first tag character of examples built in this file.
const size_t lDEFAULT = 0;            // the label parser held by the vw instance
const size_t lBINARY = 1;             // simple_label
const size_t lMULTICLASS = 2;         // MULTICLASS::mc_label
const size_t lCOST_SENSITIVE = 3;     // COST_SENSITIVE::cs_label
const size_t lCONTEXTUAL_BANDIT = 4;  // CB::cb_label
const size_t lMAX = 5;                // presumably one past the last valid id; unused in the visible code
28 
29 
// No-op deleter: handed to boost::shared_ptr so that releasing the last
// reference does not free the pointed-to object.
void dont_delete_me(void* arg) {
  (void)arg;  // deliberately ignored
}
31 
my_initialize(string args)32 vw_ptr my_initialize(string args) {
33   vw*foo = VW::initialize(args);
34   return boost::shared_ptr<vw>(foo, dont_delete_me);
35 }
36 
my_finish(vw_ptr all)37 void my_finish(vw_ptr all) {
38   VW::finish(*all, false);  // don't delete all because python will do that for us!
39 }
40 
get_search_ptr(vw_ptr all)41 search_ptr get_search_ptr(vw_ptr all) {
42   return boost::shared_ptr<Search::search>((Search::search*)(all->searchstr), dont_delete_me);
43 }
44 
my_audit_example(vw_ptr all,example_ptr ec)45 void my_audit_example(vw_ptr all, example_ptr ec) { GD::print_audit_features(*all, *ec); }
46 
get_predictor(search_ptr sch,ptag my_tag)47 predictor_ptr get_predictor(search_ptr sch, ptag my_tag) {
48   Search::predictor* P = new Search::predictor(*sch, my_tag);
49   return boost::shared_ptr<Search::predictor>(P);
50 }
51 
get_label_parser(vw * all,size_t labelType)52 label_parser* get_label_parser(vw*all, size_t labelType) {
53   switch (labelType) {
54     case lDEFAULT:           return all ? &all->p->lp : NULL;
55     case lBINARY:            return &simple_label;
56     case lMULTICLASS:        return &MULTICLASS::mc_label;
57     case lCOST_SENSITIVE:    return &COST_SENSITIVE::cs_label;
58     case lCONTEXTUAL_BANDIT: return &CB::cb_label;
59     default: cerr << "get_label_parser called on invalid label type" << endl; throw exception();
60   }
61 }
62 
my_delete_example(void * voidec)63 void my_delete_example(void*voidec) {
64   example* ec = (example*) voidec;
65   size_t labelType = (ec->tag.size() == 0) ? lDEFAULT : ec->tag[0];
66   label_parser* lp = get_label_parser(NULL, labelType);
67   dealloc_example(lp ? lp->delete_label : NULL, *ec);
68   free(ec);
69 }
70 
71 
my_empty_example0(vw_ptr vw,size_t labelType)72 example* my_empty_example0(vw_ptr vw, size_t labelType) {
73   label_parser* lp = get_label_parser(&*vw, labelType);
74   example* ec = alloc_examples(lp->label_size, 1);
75   lp->default_label(&ec->l);
76   if (labelType == lCOST_SENSITIVE) {
77     COST_SENSITIVE::wclass zero = { 0., 1, 0., 0. };
78     ec->l.cs.costs.push_back(zero);
79   }
80   ec->tag.erase();
81   if (labelType != lDEFAULT)
82     ec->tag.push_back((char)labelType);  // hide the label type in the tag
83   return ec;
84 }
85 
my_empty_example(vw_ptr vw,size_t labelType)86 example_ptr my_empty_example(vw_ptr vw, size_t labelType) {
87   example* ec = my_empty_example0(vw, labelType);
88   return boost::shared_ptr<example>(ec, my_delete_example);
89 }
90 
my_read_example(vw_ptr all,size_t labelType,char * str)91 example_ptr my_read_example(vw_ptr all, size_t labelType, char*str) {
92   example*ec = my_empty_example0(all, labelType);
93   read_line(*all, ec, str);
94   parse_atomic_example(*all, ec, false);
95   VW::setup_example(*all, ec);
96   ec->example_counter = labelType;
97   ec->tag.erase();
98   if (labelType != lDEFAULT)
99     ec->tag.push_back((char)labelType);  // hide the label type in the tag
100   return boost::shared_ptr<example>(ec, my_delete_example);
101 }
102 
my_finish_example(vw_ptr all,example_ptr ec)103 void my_finish_example(vw_ptr all, example_ptr ec) {
104   // TODO
105 }
106 
my_learn(vw_ptr all,example_ptr ec)107 void my_learn(vw_ptr all, example_ptr ec) {
108   all->learn(ec.get());
109 }
110 
my_learn_string(vw_ptr all,char * str)111 float my_learn_string(vw_ptr all, char*str) {
112   example*ec = VW::read_example(*all, str);
113   all->learn(ec);
114   float pp = ec->partial_prediction;
115   VW::finish_example(*all, ec);
116   return pp;
117 }
118 
varray_char_to_string(v_array<char> & a)119 string varray_char_to_string(v_array<char> &a) {
120   string ret = "";
121   for (char*c = a.begin; c != a.end; ++c)
122     ret += *c;
123   return ret;
124 }
125 
my_get_tag(example_ptr ec)126 string my_get_tag(example_ptr ec) {
127   return varray_char_to_string(ec->tag);
128 }
129 
ex_num_namespaces(example_ptr ec)130 uint32_t ex_num_namespaces(example_ptr ec) {
131   return ec->indices.size();
132 }
133 
ex_namespace(example_ptr ec,uint32_t ns)134 unsigned char ex_namespace(example_ptr ec, uint32_t ns) {
135   return ec->indices[ns];
136 }
137 
ex_num_features(example_ptr ec,unsigned char ns)138 uint32_t ex_num_features(example_ptr ec, unsigned char ns) {
139   return ec->atomics[ns].size();
140 }
141 
ex_feature(example_ptr ec,unsigned char ns,uint32_t i)142 uint32_t ex_feature(example_ptr ec, unsigned char ns, uint32_t i) {
143   return ec->atomics[ns][i].weight_index;
144 }
145 
ex_feature_weight(example_ptr ec,unsigned char ns,uint32_t i)146 float ex_feature_weight(example_ptr ec, unsigned char ns, uint32_t i) {
147   return ec->atomics[ns][i].x;
148 }
149 
ex_sum_feat_sq(example_ptr ec,unsigned char ns)150 float ex_sum_feat_sq(example_ptr ec, unsigned char ns) {
151   return ec->sum_feat_sq[ns];
152 }
153 
ex_push_feature(example_ptr ec,unsigned char ns,uint32_t fid,float v)154 void ex_push_feature(example_ptr ec, unsigned char ns, uint32_t fid, float v) {
155   // warning: assumes namespace exists!
156   feature f = { v, fid };
157   ec->atomics[ns].push_back(f);
158   ec->num_features++;
159   ec->sum_feat_sq[ns] += v * v;
160   ec->total_sum_feat_sq += v * v;
161 }
162 
ex_push_feature_list(example_ptr ec,vw_ptr vw,unsigned char ns,py::list & a)163 void ex_push_feature_list(example_ptr ec, vw_ptr vw, unsigned char ns, py::list& a) {
164   // warning: assumes namespace exists!
165   char ns_str[2] = { (char)ns, 0 };
166   uint32_t ns_hash = VW::hash_space(*vw, ns_str);
167   size_t count = 0; float sum_sq = 0.;
168   for (size_t i=0; i<len(a); i++) {
169     feature f = { 1., 0 };
170     py::object ai = a[i];
171     py::extract<py::tuple> get_tup(ai);
172     if (get_tup.check()) {
173       py::tuple fv = get_tup();
174       if (len(fv) != 2) { cerr << "warning: malformed feature in list" << endl; continue; } // TODO str(ai)
175       py::extract<float> get_val(fv[1]);
176       if (get_val.check())
177         f.x = get_val();
178       else { cerr << "warning: malformed feature in list" << endl; continue; }
179       ai = fv[0];
180     }
181 
182     bool got = false;
183     py::extract<uint32_t> get_int(ai);
184     if (get_int.check()) { f.weight_index = get_int(); got = true; }
185     else {
186       py::extract<string> get_str(ai);
187       if (get_str.check()) {
188         f.weight_index = VW::hash_feature(*vw, get_str(), ns_hash);
189         got = true;
190       } else { cerr << "warning: malformed feature in list" << endl; continue; }
191     }
192     if (got && (f.x != 0.)) {
193       ec->atomics[ns].push_back(f);
194       count++;
195       sum_sq += f.x * f.x;
196     }
197   }
198   ec->num_features += count;
199   ec->sum_feat_sq[ns] += sum_sq;
200   ec->total_sum_feat_sq += sum_sq;
201 }
202 
ex_pop_feature(example_ptr ec,unsigned char ns)203 bool ex_pop_feature(example_ptr ec, unsigned char ns) {
204   if (ec->atomics[ns].size() == 0) return false;
205   feature f = ec->atomics[ns].pop();
206   ec->num_features--;
207   ec->sum_feat_sq[ns] -= f.x * f.x;
208   ec->total_sum_feat_sq -= f.x * f.x;
209   return true;
210 }
211 
ex_push_namespace(example_ptr ec,unsigned char ns)212 void ex_push_namespace(example_ptr ec, unsigned char ns) {
213   ec->indices.push_back(ns);
214 }
215 
ex_ensure_namespace_exists(example_ptr ec,unsigned char ns)216 void ex_ensure_namespace_exists(example_ptr ec, unsigned char ns) {
217   for (unsigned char* nss = ec->indices.begin; nss != ec->indices.end; ++nss)
218     if (ns == *nss) return;
219   ex_push_namespace(ec, ns);
220 }
221 
ex_pop_namespace(example_ptr ec)222 bool ex_pop_namespace(example_ptr ec) {
223   if (ec->indices.size() == 0) return false;
224   unsigned char ns = ec->indices.pop();
225   ec->num_features -= ec->atomics[ns].size();
226   ec->total_sum_feat_sq -= ec->sum_feat_sq[ns];
227   ec->sum_feat_sq[ns] = 0.;
228   ec->atomics[ns].erase();
229   return true;
230 }
231 
my_setup_example(vw_ptr vw,example_ptr ec)232 void my_setup_example(vw_ptr vw, example_ptr ec) {
233   VW::setup_example(*vw, ec.get());
234 }
235 
ex_set_label_string(example_ptr ec,vw_ptr vw,string label,size_t labelType)236 void ex_set_label_string(example_ptr ec, vw_ptr vw, string label, size_t labelType) {
237   // SPEEDUP: if it's already set properly, don't modify
238   label_parser& old_lp = vw->p->lp;
239   vw->p->lp = *get_label_parser(&*vw, labelType);
240   VW::parse_example_label(*vw, *ec, label);
241   vw->p->lp = old_lp;
242 }
243 
ex_get_simplelabel_label(example_ptr ec)244 float ex_get_simplelabel_label(example_ptr ec) { return ec->l.simple.label; }
ex_get_simplelabel_weight(example_ptr ec)245 float ex_get_simplelabel_weight(example_ptr ec) { return ec->l.simple.weight; }
ex_get_simplelabel_initial(example_ptr ec)246 float ex_get_simplelabel_initial(example_ptr ec) { return ec->l.simple.initial; }
ex_get_simplelabel_prediction(example_ptr ec)247 float ex_get_simplelabel_prediction(example_ptr ec) { return ec->pred.scalar; }
248 
ex_get_multiclass_label(example_ptr ec)249 uint32_t ex_get_multiclass_label(example_ptr ec) { return ec->l.multi.label; }
ex_get_multiclass_weight(example_ptr ec)250 float ex_get_multiclass_weight(example_ptr ec) { return ec->l.multi.weight; }
ex_get_multiclass_prediction(example_ptr ec)251 uint32_t ex_get_multiclass_prediction(example_ptr ec) { return ec->pred.multiclass; }
252 
ex_get_costsensitive_prediction(example_ptr ec)253 uint32_t ex_get_costsensitive_prediction(example_ptr ec) { return ec->pred.multiclass; }
ex_get_costsensitive_num_costs(example_ptr ec)254 uint32_t ex_get_costsensitive_num_costs(example_ptr ec) { return ec->l.cs.costs.size(); }
ex_get_costsensitive_cost(example_ptr ec,uint32_t i)255 float ex_get_costsensitive_cost(example_ptr ec, uint32_t i) { return ec->l.cs.costs[i].x; }
ex_get_costsensitive_class(example_ptr ec,uint32_t i)256 uint32_t ex_get_costsensitive_class(example_ptr ec, uint32_t i) { return ec->l.cs.costs[i].class_index; }
ex_get_costsensitive_partial_prediction(example_ptr ec,uint32_t i)257 float ex_get_costsensitive_partial_prediction(example_ptr ec, uint32_t i) { return ec->l.cs.costs[i].partial_prediction; }
ex_get_costsensitive_wap_value(example_ptr ec,uint32_t i)258 float ex_get_costsensitive_wap_value(example_ptr ec, uint32_t i) { return ec->l.cs.costs[i].wap_value; }
259 
ex_get_cbandits_prediction(example_ptr ec)260 uint32_t ex_get_cbandits_prediction(example_ptr ec) { return ec->pred.multiclass; }
ex_get_cbandits_num_costs(example_ptr ec)261 uint32_t ex_get_cbandits_num_costs(example_ptr ec) { return ec->l.cb.costs.size(); }
ex_get_cbandits_cost(example_ptr ec,uint32_t i)262 float ex_get_cbandits_cost(example_ptr ec, uint32_t i) { return ec->l.cb.costs[i].cost; }
ex_get_cbandits_class(example_ptr ec,uint32_t i)263 uint32_t ex_get_cbandits_class(example_ptr ec, uint32_t i) { return ec->l.cb.costs[i].action; }
ex_get_cbandits_probability(example_ptr ec,uint32_t i)264 float ex_get_cbandits_probability(example_ptr ec, uint32_t i) { return ec->l.cb.costs[i].probability; }
ex_get_cbandits_partial_prediction(example_ptr ec,uint32_t i)265 float ex_get_cbandits_partial_prediction(example_ptr ec, uint32_t i) { return ec->l.cb.costs[i].partial_prediction; }
266 
get_example_counter(example_ptr ec)267 size_t   get_example_counter(example_ptr ec) { return ec->example_counter; }
get_ft_offset(example_ptr ec)268 uint32_t get_ft_offset(example_ptr ec) { return ec->ft_offset; }
get_num_features(example_ptr ec)269 size_t   get_num_features(example_ptr ec) { return ec->num_features; }
get_partial_prediction(example_ptr ec)270 float    get_partial_prediction(example_ptr ec) { return ec->partial_prediction; }
get_updated_prediction(example_ptr ec)271 float    get_updated_prediction(example_ptr ec) { return ec->updated_prediction; }
get_loss(example_ptr ec)272 float    get_loss(example_ptr ec) { return ec->loss; }
get_example_t(example_ptr ec)273 float    get_example_t(example_ptr ec) { return ec->example_t; }
get_total_sum_feat_sq(example_ptr ec)274 float    get_total_sum_feat_sq(example_ptr ec) { return ec->total_sum_feat_sq; }
275 
get_sum_loss(vw_ptr vw)276 double get_sum_loss(vw_ptr vw) { return vw->sd->sum_loss; }
get_weighted_examples(vw_ptr vw)277 double get_weighted_examples(vw_ptr vw) { return vw->sd->weighted_examples; }
278 
search_should_output(search_ptr sch)279 bool search_should_output(search_ptr sch) { return sch->output().good(); }
search_output(search_ptr sch,string s)280 void search_output(search_ptr sch, string s) { sch->output() << s; }
281 
282 /*
283 uint32_t search_predict_one_all(search_ptr sch, example_ptr ec, uint32_t one_ystar) {
284   return sch->predict(ec.get(), one_ystar, NULL);
285 }
286 
287 uint32_t search_predict_one_some(search_ptr sch, example_ptr ec, uint32_t one_ystar, vector<uint32_t>& yallowed) {
288   v_array<uint32_t> yallowed_va;
289   yallowed_va.begin       = yallowed.data();
290   yallowed_va.end         = yallowed_va.begin + yallowed.size();
291   yallowed_va.end_array   = yallowed_va.end;
292   yallowed_va.erase_count = 0;
293   return sch->predict(ec.get(), one_ystar, &yallowed_va);
294 }
295 
296 uint32_t search_predict_many_all(search_ptr sch, example_ptr ec, vector<uint32_t>& ystar) {
297   v_array<uint32_t> ystar_va;
298   ystar_va.begin       = ystar.data();
299   ystar_va.end         = ystar_va.begin + ystar.size();
300   ystar_va.end_array   = ystar_va.end;
301   ystar_va.erase_count = 0;
302   return sch->predict(ec.get(), &ystar_va, NULL);
303 }
304 
305 uint32_t search_predict_many_some(search_ptr sch, example_ptr ec, vector<uint32_t>& ystar, vector<uint32_t>& yallowed) {
306   v_array<uint32_t> ystar_va;
307   ystar_va.begin       = ystar.data();
308   ystar_va.end         = ystar_va.begin + ystar.size();
309   ystar_va.end_array   = ystar_va.end;
310   ystar_va.erase_count = 0;
311   v_array<uint32_t> yallowed_va;
312   yallowed_va.begin       = yallowed.data();
313   yallowed_va.end         = yallowed_va.begin + yallowed.size();
314   yallowed_va.end_array   = yallowed_va.end;
315   yallowed_va.erase_count = 0;
316   return sch->predict(ec.get(), &ystar_va, &yallowed_va);
317 }
318 */
319 
verify_search_set_properly(search_ptr sch)320 void verify_search_set_properly(search_ptr sch) {
321   if (sch->task_name == NULL) {
322     cerr << "set_structured_predict_hook: search task not initialized properly" << endl;
323     throw exception();
324   }
325   if (strcmp(sch->task_name, "hook") != 0) {
326     cerr << "set_structured_predict_hook: trying to set hook when search task is not 'hook'!" << endl;
327     throw exception();
328   }
329 }
330 
search_get_num_actions(search_ptr sch)331 uint32_t search_get_num_actions(search_ptr sch) {
332   verify_search_set_properly(sch);
333   HookTask::task_data* d = sch->get_task_data<HookTask::task_data>();
334   return d->num_actions;
335 }
336 
search_run_fn(Search::search & sch)337 void search_run_fn(Search::search&sch) {
338   try {
339     HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
340     py::object run = *(py::object*)d->run_object;
341     run.attr("__call__")();
342   } catch(...) {
343     PyErr_Print();
344     PyErr_Clear();
345     throw exception();
346   }
347 }
348 
search_setup_fn(Search::search & sch)349 void search_setup_fn(Search::search&sch) {
350   try {
351     HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
352     py::object run = *(py::object*)d->setup_object;
353     run.attr("__call__")();
354   } catch(...) {
355     PyErr_Print();
356     PyErr_Clear();
357     throw exception();
358   }
359 }
360 
search_takedown_fn(Search::search & sch)361 void search_takedown_fn(Search::search&sch) {
362   try {
363     HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
364     py::object run = *(py::object*)d->takedown_object;
365     run.attr("__call__")();
366   } catch(...) {
367     PyErr_Print();
368     PyErr_Clear();
369     throw exception();
370   }
371 }
372 
py_delete_run_object(void * pyobj)373 void py_delete_run_object(void* pyobj) {
374   py::object* o = (py::object*)pyobj;
375   delete o;
376 }
377 
set_structured_predict_hook(search_ptr sch,py::object run_object,py::object setup_object,py::object takedown_object)378 void set_structured_predict_hook(search_ptr sch, py::object run_object, py::object setup_object, py::object takedown_object) {
379   verify_search_set_properly(sch);
380   HookTask::task_data* d = sch->get_task_data<HookTask::task_data>();
381   d->run_f = &search_run_fn;
382   delete (py::object*)d->run_object; d->run_object = NULL;
383   delete (py::object*)d->setup_object; d->setup_object = NULL;
384   delete (py::object*)d->takedown_object; d->takedown_object = NULL;
385   d->run_object = new py::object(run_object);
386   if (setup_object.ptr() != Py_None) {
387     d->setup_object = new py::object(setup_object);
388     d->run_setup_f = &search_setup_fn;
389   }
390   if (takedown_object.ptr() != Py_None) {
391     d->takedown_object = new py::object(takedown_object);
392     d->run_takedown_f = &search_takedown_fn;
393   }
394   d->delete_run_object = &py_delete_run_object;
395 }
396 
my_set_test_only(example_ptr ec,bool val)397 void my_set_test_only(example_ptr ec, bool val) { ec->test_only = val; }
398 
po_exists(search_ptr sch,string arg)399 bool po_exists(search_ptr sch, string arg) {
400   HookTask::task_data* d = sch->get_task_data<HookTask::task_data>();
401   return (*d->var_map).count(arg) > 0;
402 }
403 
po_get_string(search_ptr sch,string arg)404 string po_get_string(search_ptr sch, string arg) {
405   HookTask::task_data* d = sch->get_task_data<HookTask::task_data>();
406   return (*d->var_map)[arg].as<string>();
407 }
408 
po_get_int(search_ptr sch,string arg)409 int32_t po_get_int(search_ptr sch, string arg) {
410   HookTask::task_data* d = sch->get_task_data<HookTask::task_data>();
411   try { return (*d->var_map)[arg].as<int>(); } catch (...) {}
412   try { return (*d->var_map)[arg].as<size_t>(); } catch (...) {}
413   try { return (*d->var_map)[arg].as<uint32_t>(); } catch (...) {}
414   try { return (*d->var_map)[arg].as<uint64_t>(); } catch (...) {}
415   try { return (*d->var_map)[arg].as<uint16_t>(); } catch (...) {}
416   try { return (*d->var_map)[arg].as<int32_t>(); } catch (...) {}
417   try { return (*d->var_map)[arg].as<int64_t>(); } catch (...) {}
418   try { return (*d->var_map)[arg].as<int16_t>(); } catch (...) {}
419   // we know this'll fail but do it anyway to get the exception
420   return (*d->var_map)[arg].as<int>();
421 }
422 
po_get(search_ptr sch,string arg)423 PyObject* po_get(search_ptr sch, string arg) {
424   try {
425     return py::incref(py::object(po_get_string(sch, arg)).ptr());
426   } catch (...) {}
427   try {
428     return py::incref(py::object(po_get_int(sch, arg)).ptr());
429   } catch (...) {}
430   // return None
431   return py::incref(py::object().ptr());
432 }
433 
my_set_input(predictor_ptr P,example_ptr ec)434 void my_set_input(predictor_ptr P, example_ptr ec) { P->set_input(*ec); }
my_set_input_at(predictor_ptr P,size_t posn,example_ptr ec)435 void my_set_input_at(predictor_ptr P, size_t posn, example_ptr ec) { P->set_input_at(posn, *ec); }
436 
my_add_oracle(predictor_ptr P,action a)437 void my_add_oracle(predictor_ptr P, action a) { P->add_oracle(a); }
my_add_oracles(predictor_ptr P,py::list & a)438 void my_add_oracles(predictor_ptr P, py::list& a) { for (size_t i=0; i<len(a); i++) P->add_oracle(py::extract<action>(a[i])); }
my_add_allowed(predictor_ptr P,action a)439 void my_add_allowed(predictor_ptr P, action a) { P->add_allowed(a); }
my_add_alloweds(predictor_ptr P,py::list & a)440 void my_add_alloweds(predictor_ptr P, py::list& a) { for (size_t i=0; i<len(a); i++) P->add_allowed(py::extract<action>(a[i])); }
my_add_condition(predictor_ptr P,ptag t,char c)441 void my_add_condition(predictor_ptr P, ptag t, char c) { P->add_condition(t, c); }
my_add_condition_range(predictor_ptr P,ptag hi,ptag count,char name0)442 void my_add_condition_range(predictor_ptr P, ptag hi, ptag count, char name0) { P->add_condition_range(hi, count, name0); }
my_set_oracle(predictor_ptr P,action a)443 void my_set_oracle(predictor_ptr P, action a) { P->set_oracle(a); }
my_set_oracles(predictor_ptr P,py::list & a)444 void my_set_oracles(predictor_ptr P, py::list& a) { if (len(a) > 0) P->set_oracle(py::extract<action>(a[0])); else P->erase_oracles(); for (size_t i=1; i<len(a); i++) P->add_oracle(py::extract<action>(a[i])); }
my_set_allowed(predictor_ptr P,action a)445 void my_set_allowed(predictor_ptr P, action a) { P->set_allowed(a); }
my_set_alloweds(predictor_ptr P,py::list & a)446 void my_set_alloweds(predictor_ptr P, py::list& a) { if (len(a) > 0) P->set_allowed(py::extract<action>(a[0])); else P->erase_alloweds(); for (size_t i=1; i<len(a); i++) P->add_allowed(py::extract<action>(a[i])); }
my_set_condition(predictor_ptr P,ptag t,char c)447 void my_set_condition(predictor_ptr P, ptag t, char c) { P->set_condition(t, c); }
my_set_condition_range(predictor_ptr P,ptag hi,ptag count,char name0)448 void my_set_condition_range(predictor_ptr P, ptag hi, ptag count, char name0) { P->set_condition_range(hi, count, name0); }
my_set_learner_id(predictor_ptr P,size_t id)449 void my_set_learner_id(predictor_ptr P, size_t id) { P->set_learner_id(id); }
my_set_tag(predictor_ptr P,ptag t)450 void my_set_tag(predictor_ptr P, ptag t) { P->set_tag(t); }
451 
452 
BOOST_PYTHON_MODULE(pylibvw)453 BOOST_PYTHON_MODULE(pylibvw) {
454   // This will enable user-defined docstrings and python signatures,
455   // while disabling the C++ signatures
456   py::docstring_options local_docstring_options(true, true, false);
457 
458   // define the vw class
459   py::class_<vw, vw_ptr>("vw", "the basic VW object that holds with weight vector, parser, etc.", py::no_init)
460       .def("__init__", py::make_constructor(my_initialize))
461       //      .def("__del__", &my_finish, "deconstruct the VW object by calling finish")
462       .def("finish", &my_finish, "stop VW by calling finish (and, eg, write weights to disk)")
463       .def("learn", &my_learn, "given a pyvw example, learn (and predict) on that example")
464       .def("learn_string", &my_learn_string, "given an example specified as a string (as in a VW data file), learn on that example")
465       .def("hash_space", &VW::hash_space, "given a namespace (as a string), compute the hash of that namespace")
466       .def("hash_feature", &VW::hash_feature, "given a feature string (arg2) and a hashed namespace (arg3), hash that feature")
467       .def("finish_example", &my_finish_example, "tell VW that you're done with a given example")
468       .def("setup_example", &my_setup_example, "given an example that you've created by hand, prepare it for learning (eg, compute quadratic feature)")
469 
470       .def("num_weights", &VW::num_weights, "how many weights are we learning?")
471       .def("get_weight", &VW::get_weight, "get the weight for a particular index")
472       .def("set_weight", &VW::set_weight, "set the weight for a particular index")
473       .def("get_stride", &VW::get_stride, "return the internal stride")
474 
475       .def("get_sum_loss", &get_sum_loss, "return the total cumulative loss suffered so far")
476       .def("get_weighted_examples", &get_weighted_examples, "return the total weight of examples so far")
477 
478       .def("get_search_ptr", &get_search_ptr, "return a pointer to the search data structure")
479       .def("audit_example", &my_audit_example, "print example audit information")
480 
481       .def_readonly("lDefault", lDEFAULT, "Default label type (whatever vw was initialized with) -- used as input to the example() initializer")
482       .def_readonly("lBinary", lBINARY, "Binary label type -- used as input to the example() initializer")
483       .def_readonly("lMulticlass", lMULTICLASS, "Multiclass label type -- used as input to the example() initializer")
484       .def_readonly("lCostSensitive", lCOST_SENSITIVE, "Cost sensitive label type (for LDF!) -- used as input to the example() initializer")
485       .def_readonly("lContextualBandit", lCONTEXTUAL_BANDIT, "Contextual bandit label type -- used as input to the example() initializer")
486       ;
487 
488   // define the example class
489   py::class_<example, example_ptr>("example", py::no_init)
490       .def("__init__", py::make_constructor(my_read_example), "Given a string as an argument parse that into a VW example (and run setup on it) -- default to multiclass label type")
491       .def("__init__", py::make_constructor(my_empty_example), "Construct an empty (non setup) example; you must provide a label type (vw.lBinary, vw.lMulticlass, etc.)")
492 
493       .def("set_test_only", &my_set_test_only, "Change the test-only bit on an example")
494 
495       .def("get_tag", &my_get_tag, "Returns the tag associated with this example")
496       .def("get_topic_prediction", &VW::get_topic_prediction, "For LDA models, returns the topic prediction for the topic id given")
497       .def("get_feature_number", &VW::get_feature_number, "Returns the total number of features for this example")
498 
499       .def("get_example_counter", &get_example_counter, "Returns the counter of total number of examples seen up to and including this one")
500       .def("get_ft_offset", &get_ft_offset, "Returns the feature offset for this example (used, eg, by multiclass classification to bulk offset all features)")
501       .def("get_partial_prediction", &get_partial_prediction, "Returns the partial prediction associated with this example")
502       .def("get_updated_prediction", &get_updated_prediction, "Returns the partial prediction as if we had updated it after learning")
503       .def("get_loss", &get_loss, "Returns the loss associated with this example")
504       .def("get_example_t", &get_example_t, "The total sum of importance weights up to and including this example")
505       .def("get_total_sum_feat_sq", &get_total_sum_feat_sq, "The total sum of feature-value squared for this example")
506 
507       .def("num_namespaces", &ex_num_namespaces, "The total number of namespaces associated with this example")
508       .def("namespace", &ex_namespace, "Get the namespace id for namespace i (for i = 0.. num_namespaces); specifically returns the ord() of the corresponding character id")
509       .def("sum_feat_sq", &ex_sum_feat_sq, "Get the sum of feature-values squared for a given namespace id (id=character-ord)")
510       .def("num_features_in", &ex_num_features, "Get the number of features in a given namespace id (id=character-ord)")
511       .def("feature", &ex_feature, "Get the feature id for the ith feature in a given namespace id (id=character-ord)")
512       .def("feature_weight", &ex_feature_weight, "The the feature value (weight) per .feature(...)")
513 
514       .def("push_hashed_feature", &ex_push_feature, "Add a hashed feature to a given namespace (id=character-ord)")
515       .def("push_feature_list", &ex_push_feature_list, "Add a (Python) list of features to a given namespace")
516       .def("pop_feature", &ex_pop_feature, "Remove the top feature from a given namespace; returns True iff the list was non-empty")
517       .def("push_namespace", &ex_push_namespace, "Add a new namespace")
518       .def("ensure_namespace_exists", &ex_ensure_namespace_exists, "Add a new namespace if it doesn't already exist")
519       .def("pop_namespace", &ex_pop_namespace, "Remove the top namespace off; returns True iff the list was non-empty")
520 
521       .def("set_label_string", &ex_set_label_string, "(Re)assign the label of this example to this string")
522 
523       .def("get_simplelabel_label", &ex_get_simplelabel_label, "Assuming a simple_label label type, return the corresponding label (class/regression target/etc.)")
524       .def("get_simplelabel_weight", &ex_get_simplelabel_weight, "Assuming a simple_label label type, return the importance weight")
525       .def("get_simplelabel_initial", &ex_get_simplelabel_initial, "Assuming a simple_label label type, return the initial (baseline) prediction")
526       .def("get_simplelabel_prediction", &ex_get_simplelabel_prediction, "Assuming a simple_label label type, return the final prediction")
527       .def("get_multiclass_label", &ex_get_multiclass_label, "Assuming a multiclass label type, get the true label")
528       .def("get_multiclass_weight", &ex_get_multiclass_weight, "Assuming a multiclass label type, get the importance weight")
529       .def("get_multiclass_prediction", &ex_get_multiclass_prediction, "Assuming a multiclass label type, get the prediction")
530       .def("get_costsensitive_prediction", &ex_get_costsensitive_prediction, "Assuming a cost_sensitive label type, get the prediction")
531       .def("get_costsensitive_num_costs", &ex_get_costsensitive_num_costs, "Assuming a cost_sensitive label type, get the total number of label/cost pairs")
532       .def("get_costsensitive_cost", &ex_get_costsensitive_cost, "Assuming a cost_sensitive label type, get the cost for a given pair (i=0.. get_costsensitive_num_costs)")
533       .def("get_costsensitive_class", &ex_get_costsensitive_class, "Assuming a cost_sensitive label type, get the label for a given pair (i=0.. get_costsensitive_num_costs)")
534       .def("get_costsensitive_partial_prediction", &ex_get_costsensitive_partial_prediction, "Assuming a cost_sensitive label type, get the partial prediction for a given pair (i=0.. get_costsensitive_num_costs)")
535       .def("get_costsensitive_wap_value", &ex_get_costsensitive_wap_value, "Assuming a cost_sensitive label type, get the weighted-all-pairs recomputed cost for a given pair (i=0.. get_costsensitive_num_costs)")
536       .def("get_cbandits_prediction", &ex_get_cbandits_prediction, "Assuming a contextual_bandits label type, get the prediction")
537       .def("get_cbandits_num_costs", &ex_get_cbandits_num_costs, "Assuming a contextual_bandits label type, get the total number of label/cost pairs")
538       .def("get_cbandits_cost", &ex_get_cbandits_cost, "Assuming a contextual_bandits label type, get the cost for a given pair (i=0.. get_cbandits_num_costs)")
539       .def("get_cbandits_class", &ex_get_cbandits_class, "Assuming a contextual_bandits label type, get the label for a given pair (i=0.. get_cbandits_num_costs)")
540       .def("get_cbandits_probability", &ex_get_cbandits_probability, "Assuming a contextual_bandits label type, get the bandits probability for a given pair (i=0.. get_cbandits_num_costs)")
541       .def("get_cbandits_partial_prediction", &ex_get_cbandits_partial_prediction, "Assuming a contextual_bandits label type, get the partial prediction for a given pair (i=0.. get_cbandits_num_costs)")
542       ;
543 
544   py::class_<Search::predictor, predictor_ptr>("predictor", py::no_init)
545       .def("set_input", &my_set_input, "set the input (an example) for this predictor (non-LDF mode only)")
546       //.def("set_input_ldf", &my_set_input_ldf, "set the inputs (a list of examples) for this predictor (LDF mode only)")
547       .def("set_input_length", &Search::predictor::set_input_length, "declare the length of an LDF-sequence of examples")
548       .def("set_input_at", &my_set_input_at, "put a given example at position in the LDF sequence (call after set_input_length)")
549       .def("add_oracle", &my_add_oracle, "add an action to the current list of oracle actions")
550       .def("add_oracles", &my_add_oracles, "add a list of actions to the current list of oracle actions")
551       .def("add_allowed", &my_add_allowed, "add an action to the current list of allowed actions")
552       .def("add_alloweds", &my_add_alloweds, "add a list of actions to the current list of allowed actions")
553       .def("add_condition", &my_add_condition, "add a (tag,char) pair to the list of variables on which to condition")
554       .def("add_condition_range", &my_add_condition_range, "given (tag,len,char), add (tag,char), (tag-1,char+1), ..., (tag-len,char+len) to the list of conditionings")
555       .def("set_oracle", &my_set_oracle, "set an action as the current list of oracle actions")
556       .def("set_oracles", &my_set_oracles, "set a list of actions as the current list of oracle actions")
557       .def("set_allowed", &my_set_allowed, "set an action as the current list of allowed actions")
558       .def("set_alloweds", &my_set_alloweds, "set a list of actions as the current list of allowed actions")
559       .def("set_condition", &my_set_condition, "set a (tag,char) pair as the list of variables on which to condition")
560       .def("set_condition_range", &my_set_condition_range, "given (tag,len,char), set (tag,char), (tag-1,char+1), ..., (tag-len,char+len) as the list of conditionings")
561       .def("set_learner_id", &my_set_learner_id, "select the learner with which to make this prediction")
562       .def("set_tag", &my_set_tag, "change the tag of this prediction")
563       .def("predict", &Search::predictor::predict, "make a prediction")
564       ;
565 
566   py::class_<Search::search, search_ptr>("search")
567       .def("set_options", &Search::search::set_options, "Set global search options (auto conditioning, etc.)")
568       .def("set_num_learners", &Search::search::set_num_learners, "Set the total number of learners you want to train")
569       .def("get_history_length", &Search::search::get_history_length, "Get the value specified by --search_history_length")
570       .def("loss", &Search::search::loss, "Declare a (possibly incremental) loss")
571       .def("should_output", &search_should_output, "Check whether search wants us to output (only happens if you have -p running)")
572       .def("predict_needs_example", &Search::search::predictNeedsExample, "Check whether a subsequent call to predict is actually going to use the example you pass---i.e., can you skip feature computation?")
573       .def("output", &search_output, "Add a string to the coutput (should only do if should_output returns True)")
574       .def("get_num_actions", &search_get_num_actions, "Return the total number of actions search was initialized with")
575       .def("set_structured_predict_hook", &set_structured_predict_hook, "Set the hook (function pointer) that search should use for structured prediction (you don't want to call this yourself!")
576       .def("is_ldf", &Search::search::is_ldf, "check whether this search task is running in LDF mode")
577 
578       .def("po_exists", &po_exists, "For program (cmd line) options, check to see if a given option was specified; eg sch.po_exists(\"search\") should be True")
579       .def("po_get", &po_get, "For program (cmd line) options, if an option was specified, get its value; eg sch.po_get(\"search\") should return the # of actions (returns either int or string)")
580       .def("po_get_str", &po_get_string, "Same as po_get, but specialized for string return values.")
581       .def("po_get_int", &po_get_int, "Same as po_get, but specialized for integer return values.")
582 
583       .def("get_predictor", &get_predictor, "Get a predictor object that can be used for making predictions; requires a tag argument to tag the prediction.")
584 
585       .def_readonly("AUTO_CONDITION_FEATURES", Search::AUTO_CONDITION_FEATURES, "Tell search to automatically add features based on conditioned-on variables")
586       .def_readonly("AUTO_HAMMING_LOSS", Search::AUTO_HAMMING_LOSS, "Tell search to automatically compute hamming loss over predictions")
587       .def_readonly("EXAMPLES_DONT_CHANGE", Search::EXAMPLES_DONT_CHANGE, "Tell search that on a single structured 'run', you don't change the examples you pass to predict")
588       .def_readonly("IS_LDF", Search::IS_LDF, "Tell search that this is an LDF task")
589       ;
590 }
591