1import sys
2import pylibvw
3
class SearchTask():
    """Base class for a learning-to-search task driven from Python.

    Subclasses implement ``_run(example)`` (and optionally ``_setup`` and
    ``_takedown``, each called with the same example). ``learn`` and
    ``predict`` route every example through VW's search reduction, which
    calls back into those hooks."""

    def __init__(self, vw, sch, num_actions):
        """Store the vw instance and search pointer, and build the two
        helper examples used to trigger the structured-predict hook.

        num_actions is part of the constructor signature used by
        vw.init_search_task but is not stored by this base class."""
        self.vw = vw
        self.sch = sch
        # an empty, already-finished example; learning on it is what makes
        # VW invoke our structured-predict hook (see _call_vw)
        self.blank_line = self.vw.example("")
        self.blank_line.finish()
        # placeholder example whose test-only flag is toggled on every call
        self.bogus_example = self.vw.example("1 | x")

    def __del__(self):
        # __init__ may have raised before bogus_example was created; in that
        # case there is nothing to release
        bogus = getattr(self, 'bogus_example', None)
        if bogus is not None:
            bogus.finish()

    def _run(self, your_own_input_example):
        """Override in subclasses: run the search over one input example
        and return its output (predict() hands this value back)."""
        pass

    def _call_vw(self, my_example, isTest):
        """Run one example through VW's search reduction.

        Installs run/setup/takedown closures over my_example as the
        structured-predict hook, then learns on the placeholder example and
        the blank example; the latter is what triggers the call into _run.
        Whatever _run returns is left in self._output. isTest is forwarded
        to the placeholder example via set_test_only."""
        self._output = None
        self.bogus_example.set_test_only(isTest)
        def run(): self._output = self._run(my_example)
        setup = None
        takedown = None
        if callable(getattr(self, "_setup", None)): setup = lambda: self._setup(my_example)
        if callable(getattr(self, "_takedown", None)): takedown = lambda: self._takedown(my_example)
        self.sch.set_structured_predict_hook(run, setup, takedown)
        self.vw.learn(self.bogus_example)
        self.vw.learn(self.blank_line) # this will cause our ._run hook to get called

    def learn(self, data_iterator):
        """Train on every example produced by iterating data_iterator."""
        for my_example in data_iterator:
            self._call_vw(my_example, isTest=False)

    def example(self, initStringOrDict=None, labelType=pylibvw.vw.lDefault):
        """Create an example for use inside this search task.

        If the search predictor needs example features
        (sch.predict_needs_example()), the example is built from
        initStringOrDict. Otherwise an empty example is returned and
        initStringOrDict is not used at all (it is not evaluated, even if
        it is callable)."""
        if self.sch.predict_needs_example():
            return self.vw.example(initStringOrDict, labelType)
        else:
            return self.vw.example(None, labelType)

    def predict(self, my_example):
        """Run the search on my_example in test mode and return the value
        produced by _run."""
        self._call_vw(my_example, isTest=True)
        return self._output
45
class vw(pylibvw.vw):
    """The pyvw.vw object is a (trivial) wrapper around the pylibvw.vw
    object; you're probably best off using this directly and ignoring
    the pylibvw.vw structure entirely."""

    def __init__(self, argString=None, **kw):
        """Initialize the vw object. The (optional) argString is the
        same as the command line arguments you'd use to run vw (eg, "--audit").
        You can also use key/value pairs as in:
          pyvw.vw(audit=True, b=24, k=True, c=True, l2=0.001)
        or a combination, for instance:
          pyvw.vw("--audit", b=26)
        Single-character keys become "-k", longer keys "--key". A True
        value emits the bare flag, a False value omits it entirely, and
        any other value is appended after the flag via str()."""
        def format_arg(key, val):
            if type(val) is bool and not val:
                return ''                      # False flags are omitted
            s = ('-' + key) if len(key) == 1 else ('--' + key)
            if type(val) is not bool:          # True flags take no value
                s += ' ' + str(val)
            return s
        # items() (not iteritems) so this also runs under Python 3
        args = [format_arg(k, v) for k, v in kw.items()]
        if argString is not None:
            args = [argString] + args
        # skip the empty strings produced for False flags
        pylibvw.vw.__init__(self, ' '.join(a for a in args if a))
        self.finished = False

    def get_weight(self, index, offset=0):
        """Given an (integer) index (and an optional offset), return
        the weight for that position in the (learned) weight vector."""
        return pylibvw.vw.get_weight(self, index, offset)

    def learn(self, ec):
        """Perform an online update; ec can either be an example
        object or a string (in which case it is parsed and then
        learned on). An example that has not been set up yet is set up
        first."""
        if isinstance(ec, str):
            self.learn_string(ec)
        else:
            if hasattr(ec, 'setup_done') and not ec.setup_done:
                ec.setup_example()
            pylibvw.vw.learn(self, ec)

    def finish(self):
        """stop VW by calling finish (and, eg, write weights to disk).
        Safe to call more than once; only the first call has an effect."""
        if not self.finished:
            pylibvw.vw.finish(self)
            self.finished = True

    def example(self, stringOrDict=None, labelType=pylibvw.vw.lDefault):
        """Create a pyvw.example attached to this vw instance; see
        example.__init__ for the accepted forms of stringOrDict."""
        return example(self, stringOrDict, labelType)

    def __del__(self):
        # if pylibvw.vw.__init__ raised, 'finished' was never set and there
        # is no initialized VW instance to finish
        if hasattr(self, 'finished'):
            self.finish()

    def init_search_task(self, search_task, task_data=None):
        """Wire a Python search task into this vw instance.

        Attaches the `predict` function defined below to the search
        pointer (as sch.predict) and returns
        search_task(self, sch, num_actions), or
        search_task(self, sch, num_actions, task_data) when task_data is
        given. search_task is presumably a SearchTask subclass."""
        sch = self.get_search_ptr()

        def predict(examples, my_tag, oracle, condition=None, allowed=None, learner_id=0):
            """The basic (via-reduction) prediction mechanism. Several
            variants are supported through this overloaded function:

              'examples' can be a single example (interpreted as
                 non-LDF mode) or a list of examples (interpreted as
                 LDF mode).  it can also be a lambda function that
                 returns a single example or list of examples, and in
                 that list, each element can also be a lambda function
                 that returns an example. this is done for lazy
                 example construction (aka speed).

              'my_tag' should be an integer id, specifying this prediction

              'oracle' can be a single label (or in LDF mode a single
                 array index in 'examples') or a list of such labels if
                 the oracle policy is indecisive; if it is None, then
                 the oracle doesn't care

              'condition' should be either: (1) a (tag,char) pair, indicating
                 to condition on the given tag with identifier from the char;
                 or (2) a (tag,len,char) triple, indicating to condition on
                 tag, tag-1, tag-2, ..., tag-len with identifiers char,
                 char+1, char+2, ..., char+len. or it can be a (heterogeneous)
                 list of such things.

              'allowed' can be None, in which case all actions are allowed;
                 or it can be list of valid actions (in LDF mode, this should
                 be None and you should encode the valid actions in 'examples')

              'learner_id' specifies the underlying learner id

            Returns a single prediction.

            Raises TypeError for malformed examples, oracle, condition or
            allowed arguments.
            """

            P = sch.get_predictor(my_tag)
            if sch.is_ldf():
                # we need to know how many actions there are, even if we don't know their identities
                while callable(examples): examples = examples()
                if not isinstance(examples, list): raise TypeError('expected example _list_ in LDF mode for SearchTask.predict()')
                P.set_input_length(len(examples))
                if sch.predict_needs_example():
                    for n, ec in enumerate(examples):
                        while callable(ec): ec = ec()   # unfold the lambdas
                        if not isinstance(ec, (example, pylibvw.example)): raise TypeError('non-example in LDF example list in SearchTask.predict()')
                        P.set_input_at(n, ec)
                # TODO: when predict_needs_example() is False the examples are
                # not handed to the predictor; confirm that is never needed
            elif sch.predict_needs_example():
                while callable(examples): examples = examples()
                P.set_input(examples)

            if oracle is None: pass
            elif isinstance(oracle, list):
                if len(oracle) > 0: P.set_oracles(oracle)
            elif isinstance(oracle, int): P.set_oracle(oracle)
            else: raise TypeError('expecting oracle to be a list or an integer')

            if condition is not None:
                if not isinstance(condition, list): condition = [condition]
                for c in condition:
                    if not isinstance(c, tuple): raise TypeError('item ' + str(c) + ' in condition list is malformed')
                    if   len(c) == 2 and isinstance(c[0], int) and isinstance(c[1], str) and len(c[1]) == 1:
                        P.add_condition(max(0, c[0]), c[1])
                    elif len(c) == 3 and isinstance(c[0], int) and isinstance(c[1], int) and isinstance(c[2], str) and len(c[2]) == 1:
                        P.add_condition_range(max(0,c[0]), max(0,c[1]), c[2])
                    else:
                        raise TypeError('item ' + str(c) + ' in condition list malformed')

            if allowed is None: pass
            elif isinstance(allowed, list):
                P.set_alloweds(allowed)
            else: raise TypeError('allowed argument wrong type')

            if learner_id != 0: P.set_learner_id(learner_id)

            return P.predict()

        sch.predict = predict
        num_actions = sch.get_num_actions()
        return search_task(self, sch, num_actions) if task_data is None else search_task(self, sch, num_actions, task_data)
199
class namespace_id():
    """Wrapper converting between the two ways a namespace can be named:
    by its character (eg 'x') or by its index within a particular example.
    Mostly used internally; user code rarely needs to touch it."""

    def __init__(self, ex, id):
        """Build a namespace_id from an example and an identifier.

        If id is an int it is an index into the example's namespaces and
        must satisfy 0 <= id < ex.num_namespaces(). If id is a str, its
        first character is the namespace (an empty string means ' ').

        Sets self.id (the index, or None when given a string), self.ns
        (the namespace character) and self.ord_ns (its ordinal).
        Raises Exception for an out-of-range index or any other id type."""
        if isinstance(id, str):
            # the index stays unknown: finding it would require a linear search
            first_char = id[0] if id else ' '
            self.id = None
            self.ns = first_char
            self.ord_ns = ord(first_char)
            return

        if not isinstance(id, int):
            raise Exception("ns_to_characterord failed because id type is unknown: " + str(type(id)))

        if not 0 <= id < ex.num_namespaces():
            raise Exception('namespace ' + str(id) + ' out of bounds')

        self.id = id
        self.ord_ns = ex.namespace(id)
        self.ns = chr(self.ord_ns)
225
class example_namespace():
    """The example_namespace class is a helper class that allows you
    to extract namespaces from examples and operate at a namespace
    level rather than an example level. Mainly this is done to enable
    indexing like ex['x'][0] to get the 0th feature in namespace 'x'
    in example ex."""

    def __init__(self, ex, ns, ns_hash=None):
        """Construct an example_namespace given an example and a
        target namespace (ns must be a namespace_id). ns_hash, if given,
        is the already-computed hash of this namespace and is reused by
        push_feature instead of hashing it again.

        Raises TypeError if ns is not a namespace_id."""
        if not isinstance(ns, namespace_id):
            raise TypeError('example_namespace expects a namespace_id, got ' + str(type(ns)))
        self.ex = ex
        self.ns = ns
        # cached namespace hash; push_feature computes it lazily when None
        self.ns_hash = ns_hash

    def num_features_in(self):
        """Return the total number of features in this namespace."""
        return self.ex.num_features_in(self.ns)

    def __getitem__(self, i):
        """Get the (feature, value) pair for the ith feature in this
        namespace."""
        f = self.ex.feature(self.ns, i)
        v = self.ex.feature_weight(self.ns, i)
        return (f, v)

    def iter_features(self):
        """Iterate over all (feature, value) pairs in this namespace."""
        for i in range(self.num_features_in()):
            yield self[i]

    def push_feature(self, feature, v=1.):
        """Add an unhashed feature to the current namespace (fails if
        setup has already run on this example)."""
        if self.ns_hash is None:
            # hash the namespace *character*, as example.get_feature_id does;
            # self.ns is a namespace_id wrapper, not a string
            self.ns_hash = self.ex.vw.hash_space(self.ns.ns)
        self.ex.push_feature(self.ns, feature, v, self.ns_hash)

    def pop_feature(self):
        """Remove the top feature from the current namespace; returns True
        if a feature was removed, returns False if there were no
        features to pop. Fails if setup has run."""
        return self.ex.pop_feature(self.ns)

    def push_features(self, ns, featureList):
        """Push a list of features to this namespace. Each feature
        in the list can either be an integer (already hashed) or a
        string (to be hashed) and may be paired with a value or not
        (if not, the value is assumed to be 1.0). See example.push_features
        for examples.

        Note: the ns argument is ignored; features always go to this
        object's own namespace."""
        self.ex.push_features(self.ns, featureList)
278
class abstract_label:
    """An abstract class for a VW label. Concrete label classes override
    from_example to populate themselves from a VW example."""
    def __init__(self):
        pass

    def from_example(self, ex):
        """Grab a label from a given VW example.

        The base implementation raises NotImplementedError. That is a
        subclass of Exception, so existing ``except Exception`` handlers
        still catch it."""
        raise NotImplementedError("from_example not yet implemented")
287
class simple_label(abstract_label):
    """A scalar VW label with weight, initial and prediction fields.

    May be constructed from explicit values, or from an example (as
    example.get_label does), in which case the fields are read from it."""
    def __init__(self, label=0., weight=1., initial=0., prediction=0.):
        abstract_label.__init__(self)
        if isinstance(label, example):
            self.from_example(label)
        else:
            self.label      = label
            self.weight     = weight
            self.initial    = initial
            self.prediction = prediction

    def from_example(self, ex):
        """Populate this label from the simple-label fields of ex."""
        self.label      = ex.get_simplelabel_label()
        self.weight     = ex.get_simplelabel_weight()
        self.initial    = ex.get_simplelabel_initial()
        self.prediction = ex.get_simplelabel_prediction()

    def __str__(self):
        """Render as 'label', or 'label:weight' when the weight is not 1."""
        s = str(self.label)
        if self.weight != 1.:
            # weight is numeric; it must be converted before concatenation
            s += ':' + str(self.weight)
        return s
310
class multiclass_label(abstract_label):
    """A multiclass VW label: a class id with weight and prediction.

    Like simple_label, it may be constructed from explicit values or from
    an example (as example.get_label(multiclass_label) does)."""
    def __init__(self, label=1, weight=1., prediction=1):
        abstract_label.__init__(self)
        if isinstance(label, example):
            # example.get_label passes the example itself as the first argument
            self.from_example(label)
        else:
            self.label      = label
            self.weight     = weight
            self.prediction = prediction

    def from_example(self, ex):
        """Populate this label from the multiclass fields of ex."""
        self.label      = ex.get_multiclass_label()
        self.weight     = ex.get_multiclass_weight()
        self.prediction = ex.get_multiclass_prediction()

    def __str__(self):
        """Render as 'label', or 'label:weight' when the weight is not 1."""
        s = str(self.label)
        if self.weight != 1.:
            # weight is numeric; it must be converted before concatenation
            s += ':' + str(self.weight)
        return s
328
class cost_sensitive_label(abstract_label):
    """A cost-sensitive VW label: a list of per-class costs plus a
    prediction. May be constructed from explicit values or from an
    example (as example.get_label(cost_sensitive_label) does)."""
    class wclass:
        """One (class, cost) entry of a cost-sensitive label."""
        def __init__(self, label, cost=0., partial_prediction=0., wap_value=0.):
            self.label = label
            self.cost = cost
            self.partial_prediction = partial_prediction
            self.wap_value = wap_value

    def __init__(self, costs=None, prediction=0):
        abstract_label.__init__(self)
        if isinstance(costs, example):
            # example.get_label passes the example itself as the first argument
            self.from_example(costs)
        else:
            # a fresh list per instance rather than one shared mutable default
            self.costs = [] if costs is None else costs
            self.prediction = prediction

    def from_example(self, ex):
        """Populate prediction and costs from the cost-sensitive fields of ex."""
        self.prediction = ex.get_costsensitive_prediction()
        self.costs = []
        # num_costs is a method and must be called; each getter is indexed by
        # the cost position i (assumes the pylibvw getters take that index)
        for i in range(ex.get_costsensitive_num_costs()):
            wc = self.wclass(ex.get_costsensitive_class(i),
                             ex.get_costsensitive_cost(i),
                             ex.get_costsensitive_partial_prediction(i),
                             ex.get_costsensitive_wap_value(i))
            self.costs.append(wc)

    def __str__(self):
        """Render as '[label:cost label:cost ...]'."""
        return '[' + ' '.join([str(c.label) + ':' + str(c.cost) for c in self.costs]) + ']'
354
class cbandits_label(abstract_label):
    """A contextual-bandit VW label: a list of (class, cost, probability)
    entries plus a prediction. May be constructed from explicit values or
    from an example (as example.get_label(cbandits_label) does)."""
    class wclass:
        """One (class, cost, probability) entry of a contextual-bandit label."""
        def __init__(self, label, cost=0., partial_prediction=0., probability=0.):
            self.label = label
            self.cost = cost
            self.partial_prediction = partial_prediction
            self.probability = probability

    def __init__(self, costs=None, prediction=0):
        abstract_label.__init__(self)
        if isinstance(costs, example):
            # example.get_label passes the example itself as the first argument
            self.from_example(costs)
        else:
            # a fresh list per instance rather than one shared mutable default
            self.costs = [] if costs is None else costs
            self.prediction = prediction

    def from_example(self, ex):
        """Populate prediction and costs from the contextual-bandit fields of ex."""
        self.prediction = ex.get_cbandits_prediction()
        self.costs = []
        # num_costs is a method and must be called; each getter is indexed by
        # the cost position i (assumes the pylibvw getters take that index)
        for i in range(ex.get_cbandits_num_costs()):
            wc = self.wclass(ex.get_cbandits_class(i),
                             ex.get_cbandits_cost(i),
                             ex.get_cbandits_partial_prediction(i),
                             ex.get_cbandits_probability(i))
            self.costs.append(wc)

    def __str__(self):
        """Render as '[label:cost label:cost ...]'."""
        return '[' + ' '.join([str(c.label) + ':' + str(c.cost) for c in self.costs]) + ']'
380
class example(pylibvw.example):
    """The example class is a (non-trivial) wrapper around
    pylibvw.example. Most of the wrapping is to make the interface
    easier to use (by making the types safer via namespace_id) and
    also with added python-specific functionality."""

    def __init__(self, vw, initStringOrDict=None, labelType=pylibvw.vw.lDefault):
        """Construct a new example from vw.

        initStringOrDict may be:
          * None: an "empty" example you can construct by hand (see, eg,
            example.push_features); setup_example has not run yet.
          * a string: parsed as it would be from a VW data file, and the
            resulting example is already set up.
          * a dict: every (namespace, feature list) item is added with
            push_features, and then setup_example is run.
          * a callable: executed repeatedly until the result is no longer
            callable (for lazy feature computation), then handled as above.

        Raises TypeError for any other argument type."""
        while callable(initStringOrDict):
            initStringOrDict = initStringOrDict()

        if initStringOrDict is None:
            pylibvw.example.__init__(self, vw, labelType)
            self.setup_done = False
        elif isinstance(initStringOrDict, str):
            pylibvw.example.__init__(self, vw, labelType, initStringOrDict)
            self.setup_done = True
        elif isinstance(initStringOrDict, dict):
            pylibvw.example.__init__(self, vw, labelType)
            # push_features / setup_example need these before the common tail below
            self.vw = vw
            self.stride = vw.get_stride()
            self.finished = False
            self.setup_done = False
            # items() (not iteritems) so this also runs under Python 3
            for ns_char, feats in initStringOrDict.items():
                self.push_features(ns_char, feats)
            self.setup_example()
        else:
            raise TypeError('expecting string or dict as argument for example construction')

        self.vw = vw
        self.stride = vw.get_stride()
        self.finished = False
        self.labelType = labelType

    def __del__(self):
        # if __init__ raised (eg, the TypeError above) before 'finished' was
        # set, there is nothing to hand back to VW
        if hasattr(self, 'finished'):
            self.finish()

    def __enter__(self):
        return self

    def __exit__(self,typ,value,traceback):
        self.finish()
        return typ is None

    def get_ns(self, id):
        """Construct a namespace_id from either an integer or string
        (or, if a namespace_id is fed it, just return it directly)."""
        if isinstance(id, namespace_id):
            return id
        else:
            return namespace_id(self, id)

    def __getitem__(self, id):
        """Get an example_namespace object associated with the given
        namespace id."""
        return example_namespace(self, self.get_ns(id))

    def feature(self, ns, i):
        """Get the i-th hashed feature id in a given namespace (i can
        range from 0 to self.num_features_in(ns)-1). Once the example has
        been set up, the feature offset is subtracted and the result is
        divided by the stride before returning."""
        ns = self.get_ns(ns)  # guaranteed to be a single character
        f = pylibvw.example.feature(self, ns.ord_ns, i)
        if self.setup_done:
            # floor division keeps the id an integer under Python 3 as well
            f = (f - self.get_ft_offset()) // self.stride
        return f

    def feature_weight(self, ns, i):
        """Get the value(weight) associated with a given feature id in
        a given namespace (i can range from 0 to
        self.num_features_in(ns)-1)"""
        return pylibvw.example.feature_weight(self, self.get_ns(ns).ord_ns, i)

    def set_label_string(self, string):
        """Give this example a new label, formatted as a string (ala
        the VW data file format)."""
        pylibvw.example.set_label_string(self, self.vw, string, self.labelType)

    def setup_example(self):
        """If this example hasn't already been setup (ie, quadratic
        features constructed, etc.), do so. Raises Exception if it has."""
        if self.setup_done:
            raise Exception('trying to setup_example on an example that is already setup')
        self.vw.setup_example(self)
        self.setup_done = True

    def learn(self):
        """Learn on this example (and before learning, automatically
        call setup_example if the example hasn't yet been setup)."""
        if not self.setup_done:
            self.setup_example()
        self.vw.learn(self)

    def sum_feat_sq(self, ns):
        """Return the total sum feature-value squared for a given
        namespace."""
        return pylibvw.example.sum_feat_sq(self, self.get_ns(ns).ord_ns)

    def num_features_in(self, ns):
        """Return the total number of features in a given namespace."""
        return pylibvw.example.num_features_in(self, self.get_ns(ns).ord_ns)

    def get_feature_id(self, ns, feature, ns_hash=None):
        """Return the hashed feature id for a given feature in a given
        namespace. feature can either be an integer (already a feature
        id) or a string, in which case it is hashed. Note that if
        --hash all is on, then get_feature_id(ns,"5") !=
        get_feature_id(ns, 5). If you've already hashed the namespace,
        you can optionally provide that value to avoid re-hashing it.
        Raises Exception for any other feature type."""
        if isinstance(feature, int):
            return feature
        if isinstance(feature, str):
            if ns_hash is None:
                ns_hash = self.vw.hash_space( self.get_ns(ns).ns )
            return self.vw.hash_feature(feature, ns_hash)
        raise Exception("cannot extract feature of type: " + str(type(feature)))

    def push_hashed_feature(self, ns, f, v=1.):
        """Add a hashed feature to a given namespace. Fails if setup
        has already run on this example."""
        if self.setup_done: raise Exception("error: modification to example after setup")
        pylibvw.example.push_hashed_feature(self, self.get_ns(ns).ord_ns, f, v)

    def push_feature(self, ns, feature, v=1., ns_hash=None):
        """Add an unhashed feature to a given namespace (fails if
        setup has already run on this example)."""
        f = self.get_feature_id(ns, feature, ns_hash)
        self.push_hashed_feature(ns, f, v)

    def pop_feature(self, ns):
        """Remove the top feature from a given namespace; returns True
        if a feature was removed, returns False if there were no
        features to pop. Fails if setup has run."""
        if self.setup_done: raise Exception("error: modification to example after setup")
        return pylibvw.example.pop_feature(self, self.get_ns(ns).ord_ns)

    def push_namespace(self, ns):
        """Push a new namespace onto this example. You should only do
        this if you're sure that this example doesn't already have the
        given namespace. Fails if setup has run."""
        if self.setup_done: raise Exception("error: modification to example after setup")
        pylibvw.example.push_namespace(self, self.get_ns(ns).ord_ns)

    def pop_namespace(self):
        """Remove the top namespace from an example; returns True if a
        namespace was removed, or False if there were no namespaces
        left. Fails if setup has run."""
        if self.setup_done: raise Exception("error: modification to example after setup")
        return pylibvw.example.pop_namespace(self)

    def ensure_namespace_exists(self, ns):
        """Check to see if a namespace already exists. If it does, do
        nothing. If it doesn't, add it. Fails if setup has run."""
        if self.setup_done: raise Exception("error: modification to example after setup")
        return pylibvw.example.ensure_namespace_exists(self, self.get_ns(ns).ord_ns)

    def push_features(self, ns, featureList):
        """Push a list of features to a given namespace. Each feature
        in the list can either be an integer (already hashed) or a
        string (to be hashed) and may be paired with a value or not
        (if not, the value is assumed to be 1.0).

        Examples:
           ex.push_features('x', ['a', 'b'])
           ex.push_features('y', [('c', 1.), 'd'])

           space_hash = vw.hash_space( 'x' )
           feat_hash  = vw.hash_feature( 'a', space_hash )
           ex.push_features('x', [feat_hash])    # note: 'x' should match the space_hash!

        Fails if setup has run."""
        ns = self.get_ns(ns)
        self.ensure_namespace_exists(ns)
        self.push_feature_list(self.vw, ns.ord_ns, featureList)   # much faster just to do it in C++

    def finish(self):
        """Tell VW that you're done with this example and it can
        recycle it for later use. Only the first call has an effect."""
        if not self.finished:
            self.vw.finish_example(self)
            self.finished = True

    def iter_features(self):
        """Iterate over all feature/value pairs in this example (all
        namespaces included)."""
        for ns_id in range( self.num_namespaces() ):  # iterate over every namespace
            ns = self.get_ns(ns_id)
            for i in range(self.num_features_in(ns)):
                f = self.feature(ns, i)
                v = self.feature_weight(ns, i)
                yield f,v

    def get_label(self, label_class=simple_label):
        """Given a known label class (default is simple_label), get
        the corresponding label structure for this example. The label
        class is constructed with this example as its first argument."""
        return label_class(self)
594
595#help(example)
596