1# Copyright 2011 Matt Chaput. All rights reserved.
2#
3# Redistribution and use in source and binary forms, with or without
4# modification, are permitted provided that the following conditions are met:
5#
6#    1. Redistributions of source code must retain the above copyright notice,
7#       this list of conditions and the following disclaimer.
8#
9#    2. Redistributions in binary form must reproduce the above copyright
10#       notice, this list of conditions and the following disclaimer in the
11#       documentation and/or other materials provided with the distribution.
12#
13# THIS SOFTWARE IS PROVIDED BY MATT CHAPUT ``AS IS'' AND ANY EXPRESS OR
14# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
15# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
16# EVENT SHALL MATT CHAPUT OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
17# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
18# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
19# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
20# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
21# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
22# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23#
24# The views and conclusions contained in the software and documentation are
25# those of the authors and should not be interpreted as representing official
26# policies, either expressed or implied, of Matt Chaput.
27
28import sys, weakref
29
30from whoosh import query
31from whoosh.qparser.common import get_single_text, QueryParserError, attach
32
33
class SyntaxNode(object):
    """Base class for nodes that make up the abstract syntax tree (AST) of a
    parsed user query string. The AST is an intermediate step, generated
    from the query string, then converted into a :class:`whoosh.query.Query`
    tree by calling the ``query()`` method on the nodes.

    Instances have the following required attributes:

    ``has_fieldname``
        True if this node has a ``fieldname`` attribute.
    ``has_text``
        True if this node has a ``text`` attribute
    ``has_boost``
        True if this node has a ``boost`` attribute.
    ``startchar``
        The character position in the original text at which this node started.
    ``endchar``
        The character position in the original text at which this node ended.

    The mutator methods (``set_fieldname``, ``set_boost``, ``set_range``)
    always return the node itself so that calls can be chained.
    """

    has_fieldname = False
    has_text = False
    has_boost = False
    # Weak reference to the parent node, installed by ``bake()``. ``None``
    # means the node has not been baked or has no parent.
    _parent = None

    def __repr__(self):
        r = "<"
        if self.has_fieldname:
            r += "%r:" % self.fieldname
        r += self.r()
        # A boost of exactly 1.0 is the neutral value, so it is not shown.
        if self.has_boost and self.boost != 1.0:
            r += " ^%s" % self.boost
        r += ">"
        return r

    def r(self):
        """Returns a basic representation of this node. The base class's
        ``__repr__`` method calls this, then does the extra busy work of adding
        fieldname and boost where appropriate.
        """

        return "%s %r" % (self.__class__.__name__, self.__dict__)

    def apply(self, fn):
        """Hook used by ``accept()``. The base implementation ignores ``fn``
        and returns this node unchanged; subclasses that contain other nodes
        override it.
        """

        return self

    def accept(self, fn):
        """Calls ``apply()`` on this node with a wrapper around ``fn``, then
        passes the result of ``apply()`` to ``fn`` and returns what ``fn``
        returns. The wrapper does the same for every node it is handed, so a
        subclass whose ``apply()`` calls the function on its sub-nodes has
        them processed before the node itself.
        """

        def fn_wrapper(n):
            return fn(n.apply(fn_wrapper))

        return fn_wrapper(self)

    def query(self, parser):
        """Returns a :class:`whoosh.query.Query` instance corresponding to this
        syntax tree node.

        The base implementation always raises ``NotImplementedError`` with the
        class name as the message.
        """

        raise NotImplementedError(self.__class__.__name__)

    def is_ws(self):
        """Returns True if this node is ignorable whitespace.
        """

        return False

    def is_text(self):
        """Returns True if this node is a text node. The base implementation
        returns False.
        """

        return False

    def set_fieldname(self, name, override=False):
        """Sets the fieldname associated with this node. If ``override`` is
        False (the default), the fieldname will only be replaced if this node
        does not already have a fieldname set.

        For nodes that don't have a fieldname, this is a no-op.

        :param name: the field name to assign.
        :param override: if True, replace an existing field name.
        :returns: this node, whether or not anything was changed.
        """

        if self.has_fieldname and (self.fieldname is None or override):
            self.fieldname = name
        return self

    def set_boost(self, boost):
        """Sets the boost associated with this node.

        For nodes that don't have a boost, this is a no-op.

        :param boost: the boost value to assign.
        :returns: this node, whether or not anything was changed.
        """

        if self.has_boost:
            self.boost = boost
        return self

    def set_range(self, startchar, endchar):
        """Sets the character range associated with this node.

        :returns: this node.
        """

        self.startchar = startchar
        self.endchar = endchar
        return self

    # Navigation methods

    def parent(self):
        """Returns the parent recorded by ``bake()``, or None if there is no
        recorded parent or the weakly referenced parent no longer exists.
        """

        ref = self._parent
        if ref is None:
            return None
        return ref()

    def next_sibling(self):
        """Returns ``parent().node_after(self)``, or None if this node has no
        parent.
        """

        p = self.parent()
        if p is None:
            return None
        return p.node_after(self)

    def prev_sibling(self):
        """Returns ``parent().node_before(self)``, or None if this node has no
        parent.
        """

        p = self.parent()
        if p is None:
            return None
        return p.node_before(self)

    def bake(self, parent):
        """Records ``parent`` as this node's parent using a weak reference.

        Passing ``None`` clears the link, because ``weakref.ref(None)`` would
        raise ``TypeError``.
        """

        if parent is None:
            self._parent = None
        else:
            self._parent = weakref.ref(parent)
154
155
class MarkerNode(SyntaxNode):
    """Base class for nodes whose only job is to mark a position in the
    tree.
    """

    def r(self):
        # A marker's basic representation is just its class name.
        return type(self).__name__
162
163
class Whitespace(MarkerNode):
    """Abstract syntax tree node for ignorable whitespace.
    """

    def r(self):
        # Represented as a single space character instead of the class name.
        return " "

    def is_ws(self):
        # This is the node type that reports itself as ignorable whitespace.
        return True
173
174
class FieldnameNode(SyntaxNode):
    """Abstract syntax tree node for field name assignments.

    The node stores the two constructor arguments unchanged as the
    ``fieldname`` and ``original`` attributes.
    """

    has_fieldname = True

    def __init__(self, fieldname, original):
        """
        :param fieldname: the field name carried by this node.
        :param original: stored as-is on the node. Presumably the text as it
            appeared in the query string -- TODO confirm against the code
            that creates these nodes.
        """

        self.fieldname = fieldname
        self.original = original

    def __repr__(self):
        # Overrides the base ``__repr__`` completely: only the field name is
        # shown, followed by a colon.
        return "<%r:>" % self.fieldname
187
188
class GroupNode(SyntaxNode):
    """Base class for abstract syntax tree node types that group together
    sub-nodes.

    Instances have the following attributes:

    ``merging``
        True if side-by-side instances of this group can be merged into a
        single group.
    ``qclass``
        If a subclass doesn't override ``query()``, the base class will simply
        wrap this class around the queries returned by the subnodes.

    This class implements a number of list methods for operating on the
    subnodes.
    """

    has_boost = True
    merging = True
    qclass = None

    def __init__(self, nodes=None, boost=1.0, **kwargs):
        """
        :param nodes: the list of sub-nodes. A non-empty list is stored as-is
            (not copied); ``None`` or an empty list gives a new empty list.
        :param boost: the boost for this group.
        :param kwargs: extra keyword arguments, remembered and passed to
            ``qclass`` by ``query()`` and to the constructor by ``apply()``
            and ``empty_copy()``.
        """

        self.nodes = nodes or []
        self.boost = boost
        self.kwargs = kwargs

    def r(self):
        return "%s %s" % (self.__class__.__name__,
                          ", ".join(repr(n) for n in self.nodes))

    @property
    def startchar(self):
        """The ``startchar`` of the first sub-node, or None if the group is
        empty.
        """

        if not self.nodes:
            return None
        return self.nodes[0].startchar

    @property
    def endchar(self):
        """The ``endchar`` of the last sub-node, or None if the group is
        empty.
        """

        if not self.nodes:
            return None
        return self.nodes[-1].endchar

    def apply(self, fn):
        """Returns a new group of the same class, with the same boost and
        keyword arguments, whose sub-nodes are the results of calling ``fn``
        on each of this group's sub-nodes.
        """

        # The node list is the first positional argument of ``__init__``.
        return self.__class__([fn(node) for node in self.nodes],
                              boost=self.boost, **self.kwargs)

    def query(self, parser):
        """Builds the queries of the sub-nodes, drops those that are None,
        and wraps the rest in ``qclass``.
        """

        subs = []
        for node in self.nodes:
            subq = node.query(parser)
            if subq is not None:
                subs.append(subq)

        q = self.qclass(subs, boost=self.boost, **self.kwargs)
        return attach(q, self)

    def empty_copy(self):
        """Returns an empty copy of this group.

        This is used in the common pattern where a filter creates a new
        group and then adds nodes from the input group to it if they meet
        certain criteria, then returns the new group::

            def remove_whitespace(parser, group):
                newgroup = group.empty_copy()
                for node in group:
                    if not node.is_ws():
                        newgroup.append(node)
                return newgroup
        """

        c = self.__class__(**self.kwargs)
        if self.has_boost:
            c.boost = self.boost
        if self.has_fieldname:
            c.fieldname = self.fieldname
        if self.has_text:
            c.text = self.text
        return c

    def set_fieldname(self, name, override=False):
        """Sets the field name on this group (if it has one) and on every
        sub-node, using the same ``override`` rule as
        ``SyntaxNode.set_fieldname``.

        :returns: this group.
        """

        SyntaxNode.set_fieldname(self, name, override=override)
        for node in self.nodes:
            node.set_fieldname(name, override=override)
        return self

    def set_range(self, startchar, endchar):
        """Sets the given character range on every sub-node. The group's own
        ``startchar``/``endchar`` are read-only properties derived from its
        sub-nodes, so nothing is stored on the group itself.

        :returns: this group.
        """

        for node in self.nodes:
            node.set_range(startchar, endchar)
        return self

    # List-like methods

    def __nonzero__(self):
        return bool(self.nodes)

    # Python 3 name for the truth-value hook.
    __bool__ = __nonzero__

    def __iter__(self):
        return iter(self.nodes)

    def __len__(self):
        return len(self.nodes)

    def __getitem__(self, n):
        return self.nodes.__getitem__(n)

    def __setitem__(self, n, v):
        self.nodes.__setitem__(n, v)

    def __delitem__(self, n):
        self.nodes.__delitem__(n)

    def insert(self, n, v):
        self.nodes.insert(n, v)

    def append(self, v):
        self.nodes.append(v)

    def extend(self, vs):
        self.nodes.extend(vs)

    def pop(self, *args, **kwargs):
        return self.nodes.pop(*args, **kwargs)

    def reverse(self):
        self.nodes.reverse()

    def index(self, v):
        return self.nodes.index(v)

    # Navigation methods

    def bake(self, parent):
        """Records ``parent`` as this group's parent, then records this group
        as the parent of each of its sub-nodes.
        """

        SyntaxNode.bake(self, parent)
        for node in self.nodes:
            node.bake(self)

    def node_before(self, n):
        """Returns the sub-node immediately before ``n``, or None if ``n`` is
        the first sub-node or is not in this group.
        """

        try:
            i = self.nodes.index(n)
        except ValueError:
            return None
        if i > 0:
            return self.nodes[i - 1]
        return None

    def node_after(self, n):
        """Returns the sub-node immediately after ``n``, or None if ``n`` is
        the last sub-node or is not in this group.
        """

        try:
            i = self.nodes.index(n)
        except ValueError:
            return None
        # ``i + 1`` is a valid index for every position except the last one.
        if i < len(self.nodes) - 1:
            return self.nodes[i + 1]
        return None
341
342
class BinaryGroup(GroupNode):
    """Intermediate base class for group nodes that have two subnodes and
    whose ``qclass`` initializer takes two arguments instead of a list.
    """

    merging = False
    has_boost = False

    def query(self, parser):
        """Builds the queries for the two sub-nodes and combines them.

        If both sub-queries are None the result is ``query.NullQuery``; if
        exactly one is None the other is used on its own; otherwise the two
        are passed to ``qclass``.
        """

        assert len(self.nodes) == 2

        qa = self.nodes[0].query(parser)
        qb = self.nodes[1].query(parser)
        if qa is None and qb is None:
            q = query.NullQuery
        elif qa is None:
            q = qb
        elif qb is None:
            q = qa
        else:
            # Reuse the sub-queries built above instead of building each
            # one a second time.
            q = self.qclass(qa, qb)

        return attach(q, self)
367
368
class Wrapper(GroupNode):
    """Intermediate base class for group nodes that wrap exactly one
    sub-node.
    """

    merging = False

    def query(self, parser):
        """Returns ``qclass`` wrapped around the first sub-node's query, or
        None when that query is false in a boolean context.
        """

        inner = self.nodes[0].query(parser)
        if not inner:
            return None

        wrapped = self.qclass(inner)
        return attach(wrapped, self)
379
380
class ErrorNode(SyntaxNode):
    """Syntax node that carries an error message and, optionally, the node
    the error relates to.
    """

    def __init__(self, message, node=None):
        """
        :param message: the error message.
        :param node: the wrapped node, or None if there is none.
        """

        self.message = message
        self.node = node

    def r(self):
        return "ERR %r %r" % (self.node, self.message)

    @property
    def startchar(self):
        """The wrapped node's ``startchar``, or None if there is no wrapped
        node.
        """

        if self.node is None:
            return None
        return self.node.startchar

    @property
    def endchar(self):
        """The wrapped node's ``endchar``, or None if there is no wrapped
        node.
        """

        if self.node is None:
            return None
        return self.node.endchar

    def set_range(self, startchar, endchar):
        """Sets the character range on the wrapped node, if there is one.

        ``startchar`` and ``endchar`` are read-only properties on this class,
        so the inherited implementation (which assigns to them) cannot be
        used here.

        :returns: this node.
        """

        if self.node is not None:
            self.node.set_range(startchar, endchar)
        return self

    def query(self, parser):
        """Returns ``query.error_query`` built from the message and the
        wrapped node's query. ``query.NullQuery`` stands in for the wrapped
        query when the wrapped node is missing or false in a boolean context.
        """

        if self.node:
            q = self.node.query(parser)
        else:
            q = query.NullQuery

        return attach(query.error_query(self.message, q), self)
404
405
class AndGroup(GroupNode):
    """Group node whose ``qclass`` is :class:`whoosh.query.And`; the inherited
    ``query()`` wraps that class around the sub-node queries.
    """

    qclass = query.And
408
409
class OrGroup(GroupNode):
    """Group node whose ``qclass`` is :class:`whoosh.query.Or`.
    """

    qclass = query.Or

    @classmethod
    def factory(cls, scale=1.0):
        """Returns a new ``OrGroup`` subclass whose instances always pass
        ``scale=<the value given here>`` to the group initializer.

        :param scale: the value forwarded as the ``scale`` keyword argument.
        """

        class ScaledOrGroup(OrGroup):
            def __init__(self, nodes=None, **kwargs):
                # A caller-supplied "scale" is discarded in favour of the
                # value captured by the factory.
                kwargs.pop("scale", None)
                super(ScaledOrGroup, self).__init__(nodes=nodes, scale=scale,
                                                    **kwargs)

        return ScaledOrGroup
422
423
class DisMaxGroup(GroupNode):
    """Group node whose ``qclass`` is :class:`whoosh.query.DisjunctionMax`;
    the inherited ``query()`` wraps that class around the sub-node queries.
    """

    qclass = query.DisjunctionMax
426
427
class OrderedGroup(GroupNode):
    """Group node whose ``qclass`` is :class:`whoosh.query.Ordered`; the
    inherited ``query()`` wraps that class around the sub-node queries.
    """

    qclass = query.Ordered
430
431
class AndNotGroup(BinaryGroup):
    """Two-node group whose ``qclass`` is :class:`whoosh.query.AndNot`.
    """

    qclass = query.AndNot
434
435
class AndMaybeGroup(BinaryGroup):
    """Two-node group whose ``qclass`` is :class:`whoosh.query.AndMaybe`.
    """

    qclass = query.AndMaybe
438
439
class RequireGroup(BinaryGroup):
    """Two-node group whose ``qclass`` is :class:`whoosh.query.Require`.
    """

    qclass = query.Require
442
443
class NotGroup(Wrapper):
    """Single-node wrapper whose ``qclass`` is :class:`whoosh.query.Not`.
    """

    qclass = query.Not
446
447
class RangeNode(SyntaxNode):
    """Syntax node for range queries.

    ``start`` and ``end`` are the range endpoints; ``startexcl`` and
    ``endexcl`` say whether the corresponding endpoint is exclusive.
    """

    has_fieldname = True

    def __init__(self, start, end, startexcl, endexcl):
        self.fieldname = None
        self.start, self.end = start, end
        self.startexcl, self.endexcl = startexcl, endexcl
        # NOTE(review): a boost is stored here, but ``has_boost`` is not set
        # on this class, so ``set_boost()`` leaves it alone -- confirm this
        # is intended.
        self.boost = 1.0
        self.kwargs = {}

    def r(self):
        # Curly brackets mark an exclusive endpoint, square brackets an
        # inclusive one.
        opener = "{" if self.startexcl else "["
        closer = "}" if self.endexcl else "]"
        return "%s%r %r%s" % (opener, self.start, self.end, closer)

    def query(self, parser):
        """Builds the query for this range.

        If the parser's schema has the field and the field reports
        ``self_parsing()``, the field's ``parse_range()`` is tried first; a
        non-None result is returned, and a ``QueryParserError`` is turned into
        an error query. Otherwise a ``query.TermRange`` is built, after
        passing each truthy endpoint through ``get_single_text`` when the
        field is in the schema.
        """

        fieldname = self.fieldname or parser.fieldname
        lower, upper = self.start, self.end

        schema = parser.schema
        if schema and fieldname in schema:
            field = schema[fieldname]

            if field.self_parsing():
                try:
                    parsed = field.parse_range(fieldname, lower, upper,
                                               self.startexcl, self.endexcl,
                                               boost=self.boost)
                    if parsed is not None:
                        return attach(parsed, self)
                except QueryParserError:
                    err = sys.exc_info()[1]
                    return attach(query.error_query(err), self)

            # The field did not produce a query itself, so just normalize
            # whichever endpoints were given.
            if lower:
                lower = get_single_text(field, lower, tokenize=False,
                                        removestops=False)
            if upper:
                upper = get_single_text(field, upper, tokenize=False,
                                        removestops=False)

        termrange = query.TermRange(fieldname, lower, upper, self.startexcl,
                                    self.endexcl, boost=self.boost)
        return attach(termrange, self)
496
497
class TextNode(SyntaxNode):
    """Intermediate base class for basic nodes that search for text, such as
    term queries, wildcards, prefixes, etc.

    Instances have the following attributes:

    ``qclass``
        The query class the base ``query()`` uses. When it is None, the
        parser's ``termclass`` is used instead.
    ``tokenize``
        Passed to the parser's ``term_query()`` by the base ``query()``; if
        True, the node's text will be tokenized before constructing the
        query.
    ``removestops``
        Passed to the parser's ``term_query()`` by the base ``query()``; if
        True and the field's analyzer has a stop word filter, stop words will
        be removed from the text before constructing the query.
    """

    has_fieldname = True
    has_text = True
    has_boost = True
    qclass = None
    tokenize = False
    removestops = False

    def __init__(self, text):
        self.text = text
        self.fieldname = None
        self.boost = 1.0

    def r(self):
        return "%s %r" % (self.__class__.__name__, self.text)

    def is_text(self):
        return True

    def query(self, parser):
        """Asks the parser's ``term_query()`` for a query over this node's
        text, falling back to the parser's default field name and term class
        when the node has none of its own.
        """

        field = self.fieldname or parser.fieldname
        qcls = self.qclass or parser.termclass
        options = dict(boost=self.boost, tokenize=self.tokenize,
                       removestops=self.removestops)
        return attach(parser.term_query(field, self.text, qcls, **options),
                      self)
541
542
class WordNode(TextNode):
    """Syntax node for term queries.

    Turns on both ``tokenize`` and ``removestops``, which the inherited
    ``query()`` passes on to the parser's ``term_query()``.
    """

    tokenize = True
    removestops = True

    def r(self):
        # Unlike other text nodes, a word is shown without its class name.
        return repr(self.text)
552
553
554# Operators
555
class Operator(SyntaxNode):
    """Base class for PrefixOperator, PostfixOperator, and InfixOperator.

    Operators work by moving the nodes they apply to (e.g. for prefix operator,
    the following node, for postfix operator, the previous node, for infix
    operator, the nodes on either side) into a group node. The group provides
    the code for what to do with the nodes.
    """

    def __init__(self, text, grouptype, leftassoc=True):
        """
        :param text: the text of the operator in the query string.
        :param grouptype: the type of group to create in place of the operator
            and the node(s) it operates on.
        :param leftassoc: for infix operators, whether the operator is left
            associative. use ``leftassoc=False`` for right-associative infix
            operators.
        """

        self.text = text
        self.grouptype = grouptype
        self.leftassoc = leftassoc

    def r(self):
        return "OP %r" % self.text

    def replace_self(self, parser, group, position):
        """Called with the parser, a group, and the position at which the
        operator occurs in that group. Implementations replace the operator
        with whatever effect it has (e.g. for an infix op, replace the op and
        the nodes on either side with a sub-group).

        The implementations in this module modify ``group`` in place and
        return an integer position (either ``position`` or ``position - 1``).
        Presumably this is where the caller continues scanning the group --
        TODO confirm against the caller.
        """

        raise NotImplementedError
589
590
class PrefixOperator(Operator):
    """Operator that applies to the node that follows it.
    """

    def replace_self(self, parser, group, position):
        """Removes the operator from ``group`` and, if a node followed it,
        replaces that node with a ``grouptype`` group containing it. Returns
        ``position``.
        """

        del group[position]
        # With the operator gone, ``position`` is the index of the node that
        # used to follow it, provided there was one.
        if position < len(group):
            operand = group[position]
            group[position] = self.grouptype([operand])
        return position
598
599
class PostfixOperator(Operator):
    """Operator that applies to the node that precedes it.
    """

    def replace_self(self, parser, group, position):
        """Removes the operator from ``group`` and, if a node preceded it,
        replaces that node with a ``grouptype`` group containing it. Returns
        ``position``.
        """

        del group[position]
        if position > 0:
            target = position - 1
            group[target] = self.grouptype([group[target]])
        return position
606
607
class InfixOperator(Operator):
    """Operator that applies to the nodes on either side of it.
    """

    def replace_self(self, parser, group, position):
        """Combines the operator's two neighbours into a ``grouptype`` group.

        If the operator is first or last in ``group`` it has no pair of
        neighbours and is simply removed. Returns ``position - 1`` when the
        left neighbour is merged into an existing right-hand group, and
        ``position`` in every other case.
        """

        left_assoc = self.leftassoc
        gtype = self.grouptype
        can_merge = gtype.merging

        if not (0 < position < len(group) - 1):
            del group[position]
            return position

        before = group[position - 1]
        after = group[position + 1]

        # The first two cases check whether the "strong" side is already a
        # group of the type we are going to create. If it is, the "weak" side
        # is added to it instead of nesting a new group inside the existing
        # one. This is necessary because we can quickly run into Python's
        # recursion limit otherwise.
        if can_merge and left_assoc and isinstance(before, gtype):
            before.append(after)
            del group[position:position + 2]
            return position

        if can_merge and not left_assoc and isinstance(after, gtype):
            after.insert(0, before)
            del group[position - 1:position + 1]
            return position - 1

        # Replace the operator and the two surrounding nodes with one group.
        group[position - 1:position + 2] = [gtype([before, after])]
        return position
637
638
639# Functions
640
def to_word(n):
    """Returns a new :class:`WordNode` whose text is ``n.original`` and whose
    character range is copied from ``n``.
    """

    return WordNode(n.original).set_range(n.startchar, n.endchar)
646