1 #/* 2 COpyright (c) by respective owners including Yahoo!, Microsoft, and 3 individual contributors. All rights reserved. Released under a BSD 4 license as described in the file LICENSE. 5 */ 6 #ifndef LIBSEARCH_HOOKTASK_H 7 #define LIBSEARCH_HOOKTASK_H 8 9 #include "../vowpalwabbit/parser.h" 10 #include "../vowpalwabbit/parse_example.h" 11 #include "../vowpalwabbit/vw.h" 12 #include "../vowpalwabbit/search.h" 13 #include "../vowpalwabbit/search_hooktask.h" 14 15 using namespace std; 16 17 template<class INPUT, class OUTPUT> class SearchTask { 18 public: SearchTask(vw & vw_obj)19 SearchTask(vw& vw_obj) : vw_obj(vw_obj), sch(*(Search::search*)vw_obj.searchstr) { 20 bogus_example = alloc_examples(vw_obj.p->lp.label_size, 1); 21 read_line(vw_obj, bogus_example, (char*)"1 | x"); 22 parse_atomic_example(vw_obj, bogus_example, false); 23 VW::setup_example(vw_obj, bogus_example); 24 25 blank_line = alloc_examples(vw_obj.p->lp.label_size, 1); 26 read_line(vw_obj, blank_line, (char*)""); 27 parse_atomic_example(vw_obj, blank_line, false); 28 VW::setup_example(vw_obj, blank_line); 29 30 HookTask::task_data* d = sch.get_task_data<HookTask::task_data>(); 31 d->run_f = _search_run_fn; 32 d->run_setup_f = _search_setup_fn; 33 d->run_takedown_f = _search_takedown_fn; 34 d->run_object = this; 35 d->extra_data = NULL; 36 d->extra_data2 = NULL; 37 } ~SearchTask()38 ~SearchTask() { 39 dealloc_example(vw_obj.p->lp.delete_label, *bogus_example); free(bogus_example); 40 dealloc_example(vw_obj.p->lp.delete_label, *blank_line); free(blank_line); 41 } 42 _run(Search::search & sch,INPUT & input_example,OUTPUT & output)43 virtual void _run(Search::search&sch, INPUT& input_example, OUTPUT& output) {} // YOU MUST DEFINE THIS FUNCTION! _setup(Search::search & sch,INPUT & input_example,OUTPUT & output)44 void _setup(Search::search&sch, INPUT& input_example, OUTPUT& output) {} // OPTIONAL _takedown(Search::search & sch,INPUT & input_example,OUTPUT & output)45 void _takedown(Search::search&sch, INPUT& input_example, OUTPUT& output) {} // OPTIONAL 46 learn(INPUT & input_example,OUTPUT & output)47 void learn(INPUT& input_example, OUTPUT& output) { bogus_example->test_only = false; call_vw(input_example, output); } predict(INPUT & input_example,OUTPUT & output)48 void predict(INPUT& input_example, OUTPUT& output) { bogus_example->test_only = true; call_vw(input_example, output); } 49 50 protected: 51 vw& vw_obj; 52 Search::search& sch; 53 54 private: 55 example* bogus_example, *blank_line; 56 call_vw(INPUT & input_example,OUTPUT & output)57 void call_vw(INPUT& input_example, OUTPUT& output) { 58 HookTask::task_data* d = sch.template get_task_data<HookTask::task_data> (); // ugly calling convention :( 59 d->extra_data = (void*)&input_example; 60 d->extra_data2 = (void*)&output; 61 vw_obj.learn(bogus_example); 62 vw_obj.learn(blank_line); // this will cause our search_run_fn hook to get called 63 } 64 _search_run_fn(Search::search & sch)65 static void _search_run_fn(Search::search&sch) { 66 HookTask::task_data* d = sch.get_task_data<HookTask::task_data>(); 67 if ((d->run_object == NULL) || (d->extra_data == NULL) || (d->extra_data2 == NULL)) { 68 cerr << "error: calling _search_run_fn without setting run object" << endl; 69 throw exception(); 70 } 71 ((SearchTask*)d->run_object)->_run(sch, *(INPUT*)d->extra_data, *(OUTPUT*)d->extra_data2); 72 } 73 _search_setup_fn(Search::search & sch)74 static void _search_setup_fn(Search::search&sch) { 75 HookTask::task_data* d = sch.get_task_data<HookTask::task_data>(); 76 if ((d->run_object == NULL) || (d->extra_data == NULL) || (d->extra_data2 == NULL)) { 77 cerr << "error: calling _search_setup_fn without setting run object" << endl; 78 throw exception(); 79 } 80 ((SearchTask*)d->run_object)->_setup(sch, *(INPUT*)d->extra_data, *(OUTPUT*)d->extra_data2); 81 } 82 _search_takedown_fn(Search::search & sch)83 static void _search_takedown_fn(Search::search&sch) { 84 HookTask::task_data* d = sch.get_task_data<HookTask::task_data>(); 85 if ((d->run_object == NULL) || (d->extra_data == NULL) || (d->extra_data2 == NULL)) { 86 cerr << "error: calling _search_takedown_fn without setting run object" << endl; 87 throw exception(); 88 } 89 ((SearchTask*)d->run_object)->_takedown(sch, *(INPUT*)d->extra_data, *(OUTPUT*)d->extra_data2); 90 } 91 92 }; 93 94 95 class BuiltInTask : public SearchTask< vector<example*>, vector<uint32_t> > { 96 public: BuiltInTask(vw & vw_obj,Search::search_task * task)97 BuiltInTask(vw& vw_obj, Search::search_task* task) 98 : SearchTask< vector<example*>, vector<uint32_t> >(vw_obj) { 99 HookTask::task_data* d = sch.get_task_data<HookTask::task_data>(); 100 size_t num_actions = d->num_actions; 101 my_task = task; 102 if (my_task->initialize) 103 my_task->initialize(sch, num_actions, *d->var_map); 104 } 105 ~BuiltInTask()106 ~BuiltInTask() { if (my_task->finish) my_task->finish(sch); } 107 _run(Search::search & sch,vector<example * > & input_example,vector<uint32_t> & output)108 void _run(Search::search& sch, vector<example*> & input_example, vector<uint32_t> & output) { 109 my_task->run(sch, input_example); 110 sch.get_test_action_sequence(output); 111 } 112 113 protected: 114 Search::search_task* my_task; 115 }; 116 117 118 #endif 119