/*
Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
 */
6 #ifndef LIBSEARCH_HOOKTASK_H
7 #define LIBSEARCH_HOOKTASK_H
8 
9 #include "../vowpalwabbit/parser.h"
10 #include "../vowpalwabbit/parse_example.h"
11 #include "../vowpalwabbit/vw.h"
12 #include "../vowpalwabbit/search.h"
13 #include "../vowpalwabbit/search_hooktask.h"
14 
15 using namespace std;
16 
17 template<class INPUT, class OUTPUT> class SearchTask {
18   public:
SearchTask(vw & vw_obj)19   SearchTask(vw& vw_obj) : vw_obj(vw_obj), sch(*(Search::search*)vw_obj.searchstr) {
20     bogus_example = alloc_examples(vw_obj.p->lp.label_size, 1);
21     read_line(vw_obj, bogus_example, (char*)"1 | x");
22     parse_atomic_example(vw_obj, bogus_example, false);
23     VW::setup_example(vw_obj, bogus_example);
24 
25     blank_line = alloc_examples(vw_obj.p->lp.label_size, 1);
26     read_line(vw_obj, blank_line, (char*)"");
27     parse_atomic_example(vw_obj, blank_line, false);
28     VW::setup_example(vw_obj, blank_line);
29 
30     HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
31     d->run_f = _search_run_fn;
32     d->run_setup_f = _search_setup_fn;
33     d->run_takedown_f = _search_takedown_fn;
34     d->run_object = this;
35     d->extra_data  = NULL;
36     d->extra_data2 = NULL;
37   }
~SearchTask()38   ~SearchTask() {
39     dealloc_example(vw_obj.p->lp.delete_label, *bogus_example); free(bogus_example);
40     dealloc_example(vw_obj.p->lp.delete_label, *blank_line);    free(blank_line);
41   }
42 
_run(Search::search & sch,INPUT & input_example,OUTPUT & output)43   virtual void _run(Search::search&sch, INPUT& input_example, OUTPUT& output) {}  // YOU MUST DEFINE THIS FUNCTION!
_setup(Search::search & sch,INPUT & input_example,OUTPUT & output)44   void       _setup(Search::search&sch, INPUT& input_example, OUTPUT& output) {}  // OPTIONAL
_takedown(Search::search & sch,INPUT & input_example,OUTPUT & output)45   void    _takedown(Search::search&sch, INPUT& input_example, OUTPUT& output) {}  // OPTIONAL
46 
learn(INPUT & input_example,OUTPUT & output)47   void   learn(INPUT& input_example, OUTPUT& output) { bogus_example->test_only = false; call_vw(input_example, output); }
predict(INPUT & input_example,OUTPUT & output)48   void predict(INPUT& input_example, OUTPUT& output) { bogus_example->test_only = true;  call_vw(input_example, output); }
49 
50   protected:
51   vw& vw_obj;
52   Search::search& sch;
53 
54   private:
55   example* bogus_example, *blank_line;
56 
call_vw(INPUT & input_example,OUTPUT & output)57   void call_vw(INPUT& input_example, OUTPUT& output) {
58     HookTask::task_data* d = sch.template get_task_data<HookTask::task_data> (); // ugly calling convention :(
59     d->extra_data  = (void*)&input_example;
60     d->extra_data2 = (void*)&output;
61     vw_obj.learn(bogus_example);
62     vw_obj.learn(blank_line);   // this will cause our search_run_fn hook to get called
63   }
64 
_search_run_fn(Search::search & sch)65   static void _search_run_fn(Search::search&sch) {
66     HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
67     if ((d->run_object == NULL) || (d->extra_data == NULL) || (d->extra_data2 == NULL)) {
68       cerr << "error: calling _search_run_fn without setting run object" << endl;
69       throw exception();
70     }
71     ((SearchTask*)d->run_object)->_run(sch, *(INPUT*)d->extra_data, *(OUTPUT*)d->extra_data2);
72   }
73 
_search_setup_fn(Search::search & sch)74   static void _search_setup_fn(Search::search&sch) {
75     HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
76     if ((d->run_object == NULL) || (d->extra_data == NULL) || (d->extra_data2 == NULL)) {
77       cerr << "error: calling _search_setup_fn without setting run object" << endl;
78       throw exception();
79     }
80     ((SearchTask*)d->run_object)->_setup(sch, *(INPUT*)d->extra_data, *(OUTPUT*)d->extra_data2);
81   }
82 
_search_takedown_fn(Search::search & sch)83   static void _search_takedown_fn(Search::search&sch) {
84     HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
85     if ((d->run_object == NULL) || (d->extra_data == NULL) || (d->extra_data2 == NULL)) {
86       cerr << "error: calling _search_takedown_fn without setting run object" << endl;
87       throw exception();
88     }
89     ((SearchTask*)d->run_object)->_takedown(sch, *(INPUT*)d->extra_data, *(OUTPUT*)d->extra_data2);
90   }
91 
92 };
93 
94 
95 class BuiltInTask : public SearchTask< vector<example*>, vector<uint32_t> > {
96   public:
BuiltInTask(vw & vw_obj,Search::search_task * task)97   BuiltInTask(vw& vw_obj, Search::search_task* task)
98       : SearchTask< vector<example*>, vector<uint32_t> >(vw_obj) {
99     HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
100     size_t num_actions = d->num_actions;
101     my_task = task;
102     if (my_task->initialize)
103       my_task->initialize(sch, num_actions, *d->var_map);
104   }
105 
~BuiltInTask()106   ~BuiltInTask() { if (my_task->finish) my_task->finish(sch); }
107 
_run(Search::search & sch,vector<example * > & input_example,vector<uint32_t> & output)108   void _run(Search::search& sch, vector<example*> & input_example, vector<uint32_t> & output) {
109     my_task->run(sch, input_example);
110     sch.get_test_action_sequence(output);
111   }
112 
113   protected:
114   Search::search_task* my_task;
115 };
116 
117 
118 #endif
119