1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4# Natural Language Toolkit: TGrep search
5#
6# Copyright (C) 2001-2019 NLTK Project
7# Author: Will Roberts <wildwilhelm@gmail.com>
8# URL: <http://nltk.org/>
9# For license information, see LICENSE.TXT
10
11'''
12============================================
13 TGrep search implementation for NLTK trees
14============================================
15
16This module supports TGrep2 syntax for matching parts of NLTK Trees.
17Note that many tgrep operators require the tree passed to be a
18``ParentedTree``.
19
20External links:
21
22- `Tgrep tutorial <http://www.stanford.edu/dept/linguistics/corpora/cas-tut-tgrep.html>`_
23- `Tgrep2 manual <http://tedlab.mit.edu/~dr/Tgrep2/tgrep2.pdf>`_
24- `Tgrep2 source <http://tedlab.mit.edu/~dr/Tgrep2/>`_
25
26Usage
27=====
28
29>>> from nltk.tree import ParentedTree
30>>> from nltk.tgrep import tgrep_nodes, tgrep_positions
31>>> tree = ParentedTree.fromstring('(S (NP (DT the) (JJ big) (NN dog)) (VP bit) (NP (DT a) (NN cat)))')
32>>> list(tgrep_nodes('NN', [tree]))
33[[ParentedTree('NN', ['dog']), ParentedTree('NN', ['cat'])]]
34>>> list(tgrep_positions('NN', [tree]))
35[[(0, 2), (2, 1)]]
36>>> list(tgrep_nodes('DT', [tree]))
37[[ParentedTree('DT', ['the']), ParentedTree('DT', ['a'])]]
38>>> list(tgrep_nodes('DT $ JJ', [tree]))
39[[ParentedTree('DT', ['the'])]]
40
41This implementation adds syntax to select nodes based on their NLTK
42tree position.  This syntax is ``N`` plus a Python tuple representing
43the tree position.  For instance, ``N()``, ``N(0,)``, ``N(0,0)`` are
44valid node selectors.  Example:
45
46>>> tree = ParentedTree.fromstring('(S (NP (DT the) (JJ big) (NN dog)) (VP bit) (NP (DT a) (NN cat)))')
47>>> tree[0,0]
48ParentedTree('DT', ['the'])
49>>> tree[0,0].treeposition()
50(0, 0)
51>>> list(tgrep_nodes('N(0,0)', [tree]))
52[[ParentedTree('DT', ['the'])]]
53
54Caveats:
55========
56
57- Link modifiers: "?" and "=" are not implemented.
58- Tgrep compatibility: Using "@" for "!", "{" for "<", "}" for ">" are
59  not implemented.
60- The "=" and "~" links are not implemented.
61
62Known Issues:
63=============
64
65- There are some issues with link relations involving leaf nodes
66  (which are represented as bare strings in NLTK trees).  For
67  instance, consider the tree::
68
69      (S (A x))
70
71  The search string ``* !>> S`` should select all nodes which are not
72  dominated in some way by an ``S`` node (i.e., all nodes which are
73  not descendants of an ``S``).  Clearly, in this tree, the only node
74  which fulfills this criterion is the top node (since it is not
75  dominated by anything).  However, the code here will find both the
76  top node and the leaf node ``x``.  This is because we cannot recover
77  the parent of the leaf, since it is stored as a bare string.
78
79  A possible workaround, when performing this kind of search, would be
80  to filter out all leaf nodes.
81
82Implementation notes
83====================
84
85This implementation is (somewhat awkwardly) based on lambda functions
86which are predicates on a node.  A predicate is a function which is
87either True or False; using a predicate function, we can identify sets
88of nodes with particular properties.  A predicate function, could, for
89instance, return True only if a particular node has a label matching a
90particular regular expression, and has a daughter node which has no
91sisters.  Because tgrep2 search strings can do things statefully (such
92as substituting in macros, and binding nodes with node labels), the
93actual predicate function is declared with three arguments::
94
    pred = lambda n, m=None, l=None: True  # some logic here
96
97``n``
98    is a node in a tree; this argument must always be given
99
100``m``
101    contains a dictionary, mapping macro names onto predicate functions
102
103``l``
104    is a dictionary to map node labels onto nodes in the tree
105
106``m`` and ``l`` are declared to default to ``None``, and so need not be
107specified in a call to a predicate.  Predicates which call other
108predicates must always pass the value of these arguments on.  The
109top-level predicate (constructed by ``_tgrep_exprs_action``) binds the
110macro definitions to ``m`` and initialises ``l`` to an empty dictionary.
111'''
112
113from __future__ import absolute_import, print_function, unicode_literals
114
115import functools
116import re
117
118from six import binary_type, text_type
119
120try:
121    import pyparsing
122except ImportError:
123    print('Warning: nltk.tgrep will not work without the `pyparsing` package')
124    print('installed.')
125
126import nltk.tree
127
128
class TgrepException(Exception):
    '''
    Error type raised by this module, for example when a macro is used
    without being defined or when a tgrep operator cannot be
    interpreted.
    '''
133
134
def ancestors(node):
    '''
    Collects every node dominating ``node``, starting with its
    immediate parent and walking upwards.

    Leaf nodes have no ``parent`` method, so their parent cannot be
    recovered; an empty list is returned for them.
    '''
    try:
        parent = node.parent()
    except AttributeError:
        # leaves do not know their parent
        return []
    chain = []
    while parent:
        chain.append(parent)
        parent = parent.parent()
    return chain
151
152
def unique_ancestors(node):
    '''
    Collects the nodes dominating ``node`` for as long as each one has
    exactly one child, i.e. while there is a single path of descent.

    The walk stops at the first ancestor with more than one child.
    Leaf nodes (which have no ``parent`` method) yield an empty list.
    '''
    try:
        candidate = node.parent()
    except AttributeError:
        # leaves do not know their parent
        return []
    lineage = []
    while True:
        if not candidate or len(candidate) != 1:
            break
        lineage.append(candidate)
        candidate = candidate.parent()
    return lineage
168
169
def _descendants(node):
    '''
    Returns the list of all nodes below ``node``, in the order given
    by ``node.treepositions()`` with the first position (the node
    itself) skipped.

    Objects without a ``treepositions`` method (leaves) have no
    descendants and yield an empty list.
    '''
    try:
        positions = node.treepositions()
    except AttributeError:
        return []
    found = []
    for position in positions[1:]:
        found.append(node[position])
    return found
180
181
def _leftmost_descendants(node):
    '''
    Returns the list of nodes reached from ``node`` by following only
    first-child (index 0) branches.

    Objects without a ``treepositions`` method (leaves) yield an empty
    list.
    '''
    try:
        positions = node.treepositions()
    except AttributeError:
        return []
    # a position lies on the left spine when none of its steps is non-zero
    return [
        node[path]
        for path in positions[1:]
        if not any(step != 0 for step in path)
    ]
192
193
def _rightmost_descendants(node):
    '''
    Returns the list of nodes on the path from ``node`` down to its
    right-most leaf, shallowest first (``node`` itself excluded).

    The right-most leaf is taken to be the maximum of
    ``node.treepositions()``.  Objects without that method (leaves)
    yield an empty list.
    '''
    try:
        deepest = max(node.treepositions())
    except AttributeError:
        return []
    spine = []
    for depth in range(len(deepest)):
        # every non-empty prefix of the deepest position is on the spine
        spine.append(node[deepest[: depth + 1]])
    return spine
204
205
def _istree(obj):
    '''Reports whether `obj` is an instance of `nltk.tree.Tree`.'''
    tree_class = nltk.tree.Tree
    return isinstance(obj, tree_class)
209
210
def _unique_descendants(node):
    '''
    Returns the list of nodes below ``node`` that lie on a single path
    of descent: the walk continues downwards for as long as the
    current node is a non-empty tree with exactly one child.
    '''
    chain = []
    cursor = node
    while True:
        if not (cursor and _istree(cursor) and len(cursor) == 1):
            return chain
        cursor = cursor[0]
        chain.append(cursor)
222
223
def _before(node):
    '''
    Returns the list of all nodes in the tree that come before
    ``node``.

    Tree positions are compared after truncating both to their common
    length, so nodes on the path between the root and ``node`` (and
    ``node``'s own descendants) are not included.  Objects lacking
    ``treeposition``/``root`` (leaves) yield an empty list.
    '''
    try:
        pos = node.treeposition()
        tree = node.root()
    except AttributeError:
        return []
    depth = len(pos)
    preceding = []
    for candidate in tree.treepositions():
        if candidate[:depth] < pos[: len(candidate)]:
            preceding.append(tree[candidate])
    return preceding
234
235
def _immediately_before(node):
    '''
    Returns the list of all nodes that are immediately before the
    given node.

    Tree node A immediately precedes node B if the last terminal
    symbol (word) produced by A immediately precedes the first
    terminal symbol produced by B.

    The result is the nearest left neighbour found while climbing from
    ``node`` towards the root, followed by that neighbour's right-most
    descendants.  An empty list is returned when ``node`` has no tree
    position (a leaf) or when there is nothing to its left.
    '''
    try:
        pos = node.treeposition()
        tree = node.root()
    except AttributeError:
        return []
    # indices along the path where a step to the left is possible
    movable = [i for i, step in enumerate(pos) if step != 0]
    if not movable:
        return []
    pivot = movable[-1]
    # truncate at the deepest such index and move one sibling to the left
    left_pos = list(pos[:pivot]) + [pos[pivot] - 1]
    neighbour = tree[left_pos]
    return [neighbour] + _rightmost_descendants(neighbour)
260
261
def _after(node):
    '''
    Returns the list of all nodes in the tree that come after
    ``node``.

    Tree positions are compared after truncating both to their common
    length, so nodes on the path between the root and ``node`` (and
    ``node``'s own descendants) are not included.  Objects lacking
    ``treeposition``/``root`` (leaves) yield an empty list.
    '''
    try:
        pos = node.treeposition()
        tree = node.root()
    except AttributeError:
        return []
    depth = len(pos)
    following = []
    for candidate in tree.treepositions():
        if candidate[:depth] > pos[: len(candidate)]:
            following.append(tree[candidate])
    return following
272
273
def _immediately_after(node):
    '''
    Returns the list of all nodes that are immediately after the given
    node.

    Tree node A immediately follows node B if the first terminal
    symbol (word) produced by A immediately follows the last
    terminal symbol produced by B.

    The result is the nearest right neighbour found while climbing
    from ``node`` towards the root, followed by that neighbour's
    left-most descendants.  An empty list is returned when ``node``
    is a leaf (no tree position) or when there is nothing to its
    right.
    '''
    try:
        pos = node.treeposition()
        tree = node.root()
        ancestor = node.parent()
    except AttributeError:
        return []
    # climb while the current step is the last child of its parent
    for idx in reversed(range(len(pos))):
        if pos[idx] != len(ancestor) - 1:
            break
        ancestor = ancestor.parent()
    else:
        # no level of the path has a sibling to the right
        return []
    # truncate at that level and move one sibling to the right
    right_pos = list(pos[:idx]) + [pos[idx] + 1]
    neighbour = tree[right_pos]
    return [neighbour] + _leftmost_descendants(neighbour)
301
302
def _tgrep_node_literal_value(node):
    '''
    Returns the string that tgrep node literal predicates compare
    against: the label for a tree node, otherwise the text form of
    the object itself.
    '''
    if _istree(node):
        return node.label()
    return text_type(node)
309
310
def _tgrep_macro_use_action(_s, _l, tokens):
    '''
    Builds a predicate which defers to the macro named in ``tokens``.

    ``tokens`` must hold exactly one string of the form ``@NAME``.
    The returned predicate looks ``NAME`` up in the macro dictionary
    ``m`` each time it is called and raises ``TgrepException`` if the
    macro is not defined there.
    '''
    assert len(tokens) == 1
    reference = tokens[0]
    assert reference[0] == '@'
    macro_name = reference[1:]

    def macro_use(n, m=None, l=None):
        if m is not None and macro_name in m:
            return m[macro_name](n, m, l)
        raise TgrepException('macro {0} not defined'.format(macro_name))

    return macro_use
325
326
def _tgrep_node_action(_s, _l, tokens):
    '''
    Builds a lambda function representing a predicate on a tree node
    depending on the name of its node.

    ``tokens`` is one of:

    - several node names separated by ``'|'`` tokens (a disjunction);
    - a single callable (an already-interpreted parenthetical
      expression), which is returned unchanged;
    - ``*`` or ``__``, matching any node;
    - a double-quoted string, compared for equality after unescaping
      ``\\"`` and ``\\\\``;
    - a ``/regex/``, searched for in the node's literal value;
    - ``i@"string"`` or ``i@/regex/``, the case-insensitive variants;
    - any other string, compared for equality.

    A leading ``'`` token (the tgrep2 print command) is ignored.  The
    returned predicate takes ``(n, m=None, l=None)``.
    '''
    if tokens[0] == "'":
        # strip initial apostrophe (tgrep2 print command)
        tokens = tokens[1:]
    if len(tokens) > 1:
        # disjunctive definition of a node name
        assert list(set(tokens[1::2])) == ['|']
        # recursively call self to interpret each node name definition
        alternatives = [
            _tgrep_node_action(None, None, [node]) for node in tokens[::2]
        ]
        return lambda n, m=None, l=None: any(f(n, m, l) for f in alternatives)

    token = tokens[0]
    if hasattr(token, '__call__'):
        # this is a previously interpreted parenthetical node
        # definition (lambda function)
        return token
    if token == '*' or token == '__':
        return lambda n, m=None, l=None: True
    if token.startswith('"'):
        assert token.endswith('"')
        node_lit = token[1:-1].replace('\\"', '"').replace('\\\\', '\\')
        return lambda n, m=None, l=None: _tgrep_node_literal_value(n) == node_lit
    if token.startswith('/'):
        assert token.endswith('/')
        regex = re.compile(token[1:-1])
        return lambda n, m=None, l=None: regex.search(_tgrep_node_literal_value(n))
    if token.startswith('i@'):
        inner = token[2:]
        if inner.startswith('/'):
            assert inner.endswith('/')
            # Compile the pattern as written with IGNORECASE rather than
            # lower-casing its text: lower-casing would silently turn
            # escapes such as \S, \W, \D and \B into \s, \w, \d and \b.
            # UNICODE keeps non-ASCII case folding working on Python 2.
            icase_regex = re.compile(inner[1:-1], re.IGNORECASE | re.UNICODE)
            return lambda n, m=None, l=None: icase_regex.search(
                _tgrep_node_literal_value(n)
            )
        # quoted string: compare lower-cased literal with lower-cased value
        node_func = _tgrep_node_action(_s, _l, [inner.lower()])
        return lambda n, m=None, l=None: node_func(
            _tgrep_node_literal_value(n).lower()
        )
    # bare node name: exact string comparison
    return lambda n, m=None, l=None: _tgrep_node_literal_value(n) == token
375
376
def _tgrep_parens_action(_s, _l, tokens):
    '''
    Unwraps a parenthesised expression: ``tokens`` must be ``'('``,
    an already-built predicate, and ``')'``; the predicate is returned
    as is.
    '''
    assert len(tokens) == 3
    opening, inner, closing = tokens
    assert opening == '('
    assert closing == ')'
    return inner
387
388
def _tgrep_nltk_tree_pos_action(_s, _l, tokens):
    '''
    Builds a predicate on a tree node which holds only for the node
    located at a specific tree position.

    The position is rebuilt from the all-digit tokens in ``tokens``;
    every other token (brackets, commas) is ignored.  Nodes without a
    ``treeposition`` method never match.
    '''
    # recover the tuple from the parsed string
    digit_tokens = [tok for tok in tokens if tok.isdigit()]
    node_tree_position = tuple(map(int, digit_tokens))

    def at_tree_position(n, m=None, l=None):
        return hasattr(n, 'treeposition') and n.treeposition() == node_tree_position

    return at_tree_position
403
404
def _tgrep_relation_action(_s, _l, tokens):
    '''
    Builds a lambda function representing a predicate on a tree node
    depending on its relation to other nodes in the tree.

    ``tokens`` holds an optional leading ``'!'`` (negation) followed by
    either a square-bracketed relation expression (``'['``, an
    already-built predicate, ``']'``) or an operator string and the
    predicate to apply to the related node(s).

    The returned predicate takes ``(n, m=None, l=None)`` and passes
    ``m`` and ``l`` on to the inner predicate.  When the relation is
    negated, the result of the built predicate is inverted with
    ``not``.

    Raises ``TgrepException`` if the operator is not recognised.
    '''
    # process negation first if needed
    negated = False
    if tokens[0] == '!':
        negated = True
        tokens = tokens[1:]
    if tokens[0] == '[':
        # process square-bracketed relation expressions
        assert len(tokens) == 3
        assert tokens[2] == ']'
        retval = tokens[1]
    else:
        # process operator-node relation expressions
        assert len(tokens) == 2
        operator, predicate = tokens
        # A < B       A is the parent of (immediately dominates) B.
        if operator == '<':
            retval = lambda n, m=None, l=None: (
                _istree(n) and any(predicate(x, m, l) for x in n)
            )
        # A > B       A is the child of B.
        elif operator == '>':
            retval = lambda n, m=None, l=None: (
                hasattr(n, 'parent')
                and bool(n.parent())
                and predicate(n.parent(), m, l)
            )
        # A <, B      Synonymous with A <1 B.
        elif operator == '<,' or operator == '<1':
            retval = lambda n, m=None, l=None: (
                _istree(n) and bool(list(n)) and predicate(n[0], m, l)
            )
        # A >, B      Synonymous with A >1 B.
        elif operator == '>,' or operator == '>1':
            retval = lambda n, m=None, l=None: (
                hasattr(n, 'parent')
                and bool(n.parent())
                and (n is n.parent()[0])
                and predicate(n.parent(), m, l)
            )
        # A <N B      B is the Nth child of A (the first child is <1).
        elif operator[0] == '<' and operator[1:].isdigit():
            idx = int(operator[1:])
            # capture the index parameter (converted to 0-based)
            retval = (
                lambda i: lambda n, m=None, l=None: (
                    _istree(n)
                    and bool(list(n))
                    and 0 <= i < len(n)
                    and predicate(n[i], m, l)
                )
            )(idx - 1)
        # A >N B      A is the Nth child of B (the first child is >1).
        elif operator[0] == '>' and operator[1:].isdigit():
            idx = int(operator[1:])
            # capture the index parameter (converted to 0-based)
            retval = (
                lambda i: lambda n, m=None, l=None: (
                    hasattr(n, 'parent')
                    and bool(n.parent())
                    and 0 <= i < len(n.parent())
                    and (n is n.parent()[i])
                    and predicate(n.parent(), m, l)
                )
            )(idx - 1)
        # A <' B      B is the last child of A (also synonymous with A <-1 B).
        # A <- B      B is the last child of A (synonymous with A <-1 B).
        elif operator == '<\'' or operator == '<-' or operator == '<-1':
            retval = lambda n, m=None, l=None: (
                _istree(n) and bool(list(n)) and predicate(n[-1], m, l)
            )
        # A >' B      A is the last child of B (also synonymous with A >-1 B).
        # A >- B      A is the last child of B (synonymous with A >-1 B).
        elif operator == '>\'' or operator == '>-' or operator == '>-1':
            retval = lambda n, m=None, l=None: (
                hasattr(n, 'parent')
                and bool(n.parent())
                and (n is n.parent()[-1])
                and predicate(n.parent(), m, l)
            )
        # A <-N B     B is the Nth-to-last child of A (the last child is <-1).
        elif operator[:2] == '<-' and operator[2:].isdigit():
            idx = -int(operator[2:])
            # capture the (negative) index parameter; it is turned into a
            # non-negative index by adding the number of children
            retval = (
                lambda i: lambda n, m=None, l=None: (
                    _istree(n)
                    and bool(list(n))
                    and 0 <= (i + len(n)) < len(n)
                    and predicate(n[i + len(n)], m, l)
                )
            )(idx)
        # A >-N B     A is the Nth-to-last child of B (the last child is >-1).
        elif operator[:2] == '>-' and operator[2:].isdigit():
            idx = -int(operator[2:])
            # capture the (negative) index parameter; it is turned into a
            # non-negative index by adding the number of siblings
            retval = (
                lambda i: lambda n, m=None, l=None: (
                    hasattr(n, 'parent')
                    and bool(n.parent())
                    and 0 <= (i + len(n.parent())) < len(n.parent())
                    and (n is n.parent()[i + len(n.parent())])
                    and predicate(n.parent(), m, l)
                )
            )(idx)
        # A <: B      B is the only child of A
        elif operator == '<:':
            retval = lambda n, m=None, l=None: (
                _istree(n) and len(n) == 1 and predicate(n[0], m, l)
            )
        # A >: B      A is the only child of B.
        elif operator == '>:':
            retval = lambda n, m=None, l=None: (
                hasattr(n, 'parent')
                and bool(n.parent())
                and len(n.parent()) == 1
                and predicate(n.parent(), m, l)
            )
        # A << B      A dominates B (A is an ancestor of B).
        elif operator == '<<':
            retval = lambda n, m=None, l=None: (
                _istree(n) and any(predicate(x, m, l) for x in _descendants(n))
            )
        # A >> B      A is dominated by B (A is a descendant of B).
        elif operator == '>>':
            retval = lambda n, m=None, l=None: any(
                predicate(x, m, l) for x in ancestors(n)
            )
        # A <<, B     B is a left-most descendant of A.
        elif operator == '<<,' or operator == '<<1':
            retval = lambda n, m=None, l=None: (
                _istree(n) and any(predicate(x, m, l) for x in _leftmost_descendants(n))
            )
        # A >>, B     A is a left-most descendant of B.
        elif operator == '>>,':
            retval = lambda n, m=None, l=None: any(
                (predicate(x, m, l) and n in _leftmost_descendants(x))
                for x in ancestors(n)
            )
        # A <<' B     B is a right-most descendant of A.
        elif operator == '<<\'':
            retval = lambda n, m=None, l=None: (
                _istree(n)
                and any(predicate(x, m, l) for x in _rightmost_descendants(n))
            )
        # A >>' B     A is a right-most descendant of B.
        elif operator == '>>\'':
            retval = lambda n, m=None, l=None: any(
                (predicate(x, m, l) and n in _rightmost_descendants(x))
                for x in ancestors(n)
            )
        # A <<: B     There is a single path of descent from A and B is on it.
        elif operator == '<<:':
            retval = lambda n, m=None, l=None: (
                _istree(n) and any(predicate(x, m, l) for x in _unique_descendants(n))
            )
        # A >>: B     There is a single path of descent from B and A is on it.
        elif operator == '>>:':
            retval = lambda n, m=None, l=None: any(
                predicate(x, m, l) for x in unique_ancestors(n)
            )
        # A . B       A immediately precedes B.
        elif operator == '.':
            retval = lambda n, m=None, l=None: any(
                predicate(x, m, l) for x in _immediately_after(n)
            )
        # A , B       A immediately follows B.
        elif operator == ',':
            retval = lambda n, m=None, l=None: any(
                predicate(x, m, l) for x in _immediately_before(n)
            )
        # A .. B      A precedes B.
        elif operator == '..':
            retval = lambda n, m=None, l=None: any(
                predicate(x, m, l) for x in _after(n)
            )
        # A ,, B      A follows B.
        elif operator == ',,':
            retval = lambda n, m=None, l=None: any(
                predicate(x, m, l) for x in _before(n)
            )
        # A $ B       A is a sister of B (and A != B).
        elif operator == '$' or operator == '%':
            retval = lambda n, m=None, l=None: (
                hasattr(n, 'parent')
                and bool(n.parent())
                and any(predicate(x, m, l) for x in n.parent() if x is not n)
            )
        # A $. B      A is a sister of and immediately precedes B.
        elif operator == '$.' or operator == '%.':
            retval = lambda n, m=None, l=None: (
                hasattr(n, 'right_sibling')
                and bool(n.right_sibling())
                and predicate(n.right_sibling(), m, l)
            )
        # A $, B      A is a sister of and immediately follows B.
        elif operator == '$,' or operator == '%,':
            retval = lambda n, m=None, l=None: (
                hasattr(n, 'left_sibling')
                and bool(n.left_sibling())
                and predicate(n.left_sibling(), m, l)
            )
        # A $.. B     A is a sister of and precedes B.
        elif operator == '$..' or operator == '%..':
            retval = lambda n, m=None, l=None: (
                hasattr(n, 'parent')
                and hasattr(n, 'parent_index')
                and bool(n.parent())
                and any(predicate(x, m, l) for x in n.parent()[n.parent_index() + 1 :])
            )
        # A $,, B     A is a sister of and follows B.
        elif operator == '$,,' or operator == '%,,':
            retval = lambda n, m=None, l=None: (
                hasattr(n, 'parent')
                and hasattr(n, 'parent_index')
                and bool(n.parent())
                and any(predicate(x, m, l) for x in n.parent()[: n.parent_index()])
            )
        else:
            raise TgrepException(
                'cannot interpret tgrep operator "{0}"'.format(operator)
            )
    # now return the built function, wrapped in a negation if requested
    if negated:
        return (lambda r: (lambda n, m=None, l=None: not r(n, m, l)))(retval)
    else:
        return retval
637
638
def _tgrep_conjunction_action(_s, _l, tokens, join_char='&'):
    '''
    Builds a predicate on a tree node which holds when every predicate
    in ``tokens`` holds.

    This is prototypically called for expressions like
    (`tgrep_rel_conjunction`)::

        < NP & < AP < VP

    where tokens is a list of predicates representing the relations
    (`< NP`, `< AP`, and `< VP`), possibly interleaved with the
    ``join_char`` token (`&` in this example), which is discarded.

    This is also called for expressions like (`tgrep_node_expr2`)::

        NP < NN
        S=s < /NP/=n : s < /VP/=v : n .. v

    tokens[0] is a tgrep_expr predicate; tokens[1:] are an (optional)
    list of segmented patterns (`tgrep_expr_labeled`, processed by
    `_tgrep_segmented_pattern_action`).

    A single remaining predicate is returned unchanged rather than
    wrapped.
    '''
    # drop the joining character, keeping only the predicates
    predicates = [tok for tok in tokens if tok != join_char]
    if len(predicates) == 1:
        return predicates[0]

    def conjunction_pred(n, m=None, l=None):
        for predicate in predicates:
            if not predicate(n, m, l):
                return False
        return True

    return conjunction_pred
673
674
def _tgrep_segmented_pattern_action(_s, _l, tokens):
    '''
    Builds a predicate representing a segmented pattern.

    Called for expressions like (`tgrep_expr_labeled`)::

        =s .. =v < =n

    A segmented pattern is a tgrep2 expression which begins with a
    node label.  Its first element (``=s`` above) picks out a node
    that was bound earlier, so here we need the label itself in order
    to look that node up, not a predicate.  Node labels used inside a
    tgrep_node_expr instead need a predicate which is true when the
    visited node is the bound one.

    The grammar therefore has two copies of node_label_use with
    separate parse actions; see `_tgrep_node_label_use_action` and
    `_tgrep_node_label_pred_use_action`.

    ``tokens[0]`` is the node label string and ``tokens[1:]`` an
    (optional) list of predicates which must all hold of the bound
    node.  The returned predicate raises ``TgrepException`` if the
    label is not bound in ``l``.
    '''
    node_label = tokens[0]
    reln_preds = tokens[1:]

    def pattern_segment_pred(n, m=None, l=None):
        '''This predicate function ignores its node argument.'''
        if l is not None and node_label in l:
            # evaluate every relation against the node bound to the label
            bound_node = l[node_label]
            for pred in reln_preds:
                if not pred(bound_node, m, l):
                    return False
            return True
        raise TgrepException(
            'node_label ={0} not bound in pattern'.format(node_label)
        )

    return pattern_segment_pred
717
718
def _tgrep_node_label_use_action(_s, _l, tokens):
    '''
    Extracts the node label that begins a tgrep_expr_labeled.

    Called for expressions like (`tgrep_node_label_use`)::

        =s

    when they appear as the first element of a `tgrep_expr_labeled`
    expression (see `_tgrep_segmented_pattern_action`).

    ``tokens`` must hold exactly one string starting with ``=``; the
    text after the ``=`` is returned.
    '''
    assert len(tokens) == 1
    label_use = tokens[0]
    assert label_use.startswith('=')
    return label_use[1:]
736
737
def _tgrep_node_label_pred_use_action(_s, _l, tokens):
    '''
    Builds a predicate on a tree node which describes the use of a
    previously bound node label.

    Called for expressions like (`tgrep_node_label_use_pred`)::

        =s

    when they appear inside a tgrep_node_expr (for example, inside a
    relation).  The predicate returns true if and only if its node
    argument is identical to the node stored under this label in the
    node label dictionary ``l``; it raises ``TgrepException`` if the
    label is not bound there.
    '''
    assert len(tokens) == 1
    label_use = tokens[0]
    assert label_use.startswith('=')
    node_label = label_use[1:]

    def node_label_use_pred(n, m=None, l=None):
        if l is not None and node_label in l:
            # identity, not equality: it must be the very same node
            return n is l[node_label]
        raise TgrepException(
            'node_label ={0} not bound in pattern'.format(node_label)
        )

    return node_label_use_pred
767
768
def _tgrep_bind_node_label_action(_s, _l, tokens):
    '''
    Builds a predicate on a tree node which can optionally bind a
    matching node into the tgrep2 string's label_dict.

    Called for expressions like (`tgrep_node_expr2`)::

        /NP/
        @NP=n

    ``tokens[0]`` is a tgrep_node_expr predicate.  With no label it is
    returned unchanged.  Otherwise ``tokens[1]`` is the character
    ``=`` and ``tokens[2]`` the node label string; the returned
    predicate stores each matching node in ``l`` under that label and
    raises ``TgrepException`` if a node matches while ``l`` is None.
    '''
    if len(tokens) == 1:
        return tokens[0]

    assert len(tokens) == 3
    assert tokens[1] == '='
    node_pred, node_label = tokens[0], tokens[2]

    def node_label_bind_pred(n, m=None, l=None):
        if not node_pred(n, m, l):
            return False
        if l is None:
            raise TgrepException(
                'cannot bind node_label {0}: label_dict is None'.format(
                    node_label
                )
            )
        # bind `n` into the dictionary `l`
        l[node_label] = n
        return True

    return node_label_bind_pred
806
807
def _tgrep_rel_disjunction_action(_s, _l, tokens):
    '''
    Builds a lambda function representing a predicate on a tree node
    from the disjunction of several other such lambda functions.

    ``tokens`` is a list of predicates, possibly interleaved with
    ``'|'`` tokens, which are discarded.  A single predicate is
    returned unchanged.  For two or more, the returned predicate
    takes ``(n, m=None, l=None)``, tries the alternatives in order and
    returns the first truthy result (or the last result if none is
    truthy), exactly like chaining them with ``or``.

    Any number of alternatives is supported; previously more than two
    caused ``None`` to be returned and the disjunction to be lost.
    An empty token list still yields ``None``.
    '''
    # filter out the pipe
    alternatives = [x for x in tokens if x != '|']
    if not alternatives:
        return None
    if len(alternatives) == 1:
        return alternatives[0]

    def disjunction_pred(n, m=None, l=None):
        result = False
        for alternative in alternatives:
            result = alternative(n, m, l)
            if result:
                break
        return result

    return disjunction_pred
822
823
def _macro_defn_action(_s, _l, tokens):
    '''
    Builds a one-entry dictionary defining a macro.

    ``tokens`` must be ``'@'``, the macro name, and the macro's
    definition; the result maps the name onto the definition.
    '''
    assert len(tokens) == 3
    marker, name, definition = tokens
    assert marker == '@'
    return {name: definition}
831
832
def _tgrep_exprs_action(_s, _l, tokens):
    '''
    This is the top-level node in a tgrep2 search string; the
    predicate function it returns binds together all the state of a
    tgrep2 search string.

    Builds a predicate on a tree node from the disjunction of several
    tgrep expressions.  Also handles macro definitions and macro name
    binding, and node label definitions and node label binding.

    With a single token, the returned predicate calls that token with
    no macros and a fresh, empty label dictionary, ignoring any ``m``
    and ``l`` it is given.  Otherwise ``';'`` tokens are dropped,
    dictionary tokens are merged into the macro dictionary (the
    default for ``m``), and the remaining tokens are tried in order
    with one fresh label dictionary per call.
    '''
    if len(tokens) == 1:

        def single_expr_pred(n, m=None, l=None):
            return tokens[0](n, None, {})

        return single_expr_pred

    # separate macro definitions from tgrep expressions, skipping
    # the semicolons
    macro_dict = {}
    tgrep_exprs = []
    for tok in tokens:
        if tok == ';':
            continue
        if isinstance(tok, dict):
            macro_dict.update(tok)
        else:
            tgrep_exprs.append(tok)

    def top_level_pred(n, m=macro_dict, l=None):
        # create a new scope for the node label dictionary
        label_dict = {}
        # bind macro definitions and OR together all tgrep_exprs
        for predicate in tgrep_exprs:
            if predicate(n, m, label_dict):
                return True
        return False

    return top_level_pred
862
863
def _build_tgrep_parser(set_parse_actions=True):
    '''
    Builds a pyparsing-based parser object for tokenizing and
    interpreting tgrep search strings.

    :param set_parse_actions: if true, the ``_tgrep_*_action`` and
        ``_macro_defn_action`` callbacks are attached to the grammar
        elements; if false, no parse actions are attached and the
        grammar is returned as built.
    :type set_parse_actions: bool
    :return: the top-level grammar element, configured to ignore
        ``#`` comments running to the end of the line
    '''
    # relation operator: optional '!' prefix, then one of $ % , . < >
    # followed by any run of further operator characters, digits, '-',
    # apostrophe or ':'
    tgrep_op = pyparsing.Optional('!') + pyparsing.Regex('[$%,.<>][%,.<>0-9-\':]*')
    # double-quoted string and slash-delimited regex, both with
    # backslash escapes; unquoteResults=False keeps the delimiters in
    # the matched token
    tgrep_qstring = pyparsing.QuotedString(
        quoteChar='"', escChar='\\', unquoteResults=False
    )
    tgrep_node_regex = pyparsing.QuotedString(
        quoteChar='/', escChar='\\', unquoteResults=False
    )
    # the same two forms prefixed with 'i@' (going by the names,
    # presumably the case-insensitive variants)
    tgrep_qstring_icase = pyparsing.Regex('i@\\"(?:[^"\\n\\r\\\\]|(?:\\\\.))*\\"')
    tgrep_node_regex_icase = pyparsing.Regex('i@\\/(?:[^/\\n\\r\\\\]|(?:\\\\.))*\\/')
    # bare node literal: a run of characters containing neither
    # whitespace nor any of the punctuation characters listed here
    tgrep_node_literal = pyparsing.Regex('[^][ \r\t\n;:.,&|<>()$!@%\'^=]+')
    # forward declarations for the recursive parts of the grammar;
    # their definitions are supplied with ``<<`` further down
    tgrep_expr = pyparsing.Forward()
    tgrep_relations = pyparsing.Forward()
    tgrep_parens = pyparsing.Literal('(') + tgrep_expr + ')'
    # NLTK tree position selector: 'N(' + optional "digits ," followed
    # by an optional comma-separated digit list with optional trailing
    # comma + ')', e.g. N(), N(0,), N(0,0)
    tgrep_nltk_tree_pos = (
        pyparsing.Literal('N(')
        + pyparsing.Optional(
            pyparsing.Word(pyparsing.nums)
            + ','
            + pyparsing.Optional(
                pyparsing.delimitedList(pyparsing.Word(pyparsing.nums), delim=',')
                + pyparsing.Optional(',')
            )
        )
        + ')'
    )
    # node labels are alphanumeric; a label use is '=' directly
    # followed by the label, combined into a single token
    tgrep_node_label = pyparsing.Regex('[A-Za-z0-9]+')
    tgrep_node_label_use = pyparsing.Combine('=' + tgrep_node_label)
    # see _tgrep_segmented_pattern_action
    # (a separate copy so that it can carry its own parse action below)
    tgrep_node_label_use_pred = tgrep_node_label_use.copy()
    macro_name = pyparsing.Regex('[^];:.,&|<>()[$!@%\'^=\r\t\n ]+')
    # do not skip any leading whitespace before a macro name
    macro_name.setWhitespaceChars('')
    macro_use = pyparsing.Combine('@' + macro_name)
    # a single node expression; alternatives are tried in the order
    # listed, so the more specific forms come before the bare literal
    tgrep_node_expr = (
        tgrep_node_label_use_pred
        | macro_use
        | tgrep_nltk_tree_pos
        | tgrep_qstring_icase
        | tgrep_node_regex_icase
        | tgrep_qstring
        | tgrep_node_regex
        | '*'
        | tgrep_node_literal
    )
    # a node expression optionally followed, with no intervening
    # whitespace, by '=' and a label
    tgrep_node_expr2 = (
        tgrep_node_expr
        + pyparsing.Literal('=').setWhitespaceChars('')
        + tgrep_node_label.copy().setWhitespaceChars('')
    ) | tgrep_node_expr
    # a node is either a parenthesised expression, or an optional
    # apostrophe followed by '|'-separated node expressions (only the
    # first of which may take the '=label' suffix)
    tgrep_node = tgrep_parens | (
        pyparsing.Optional("'")
        + tgrep_node_expr2
        + pyparsing.ZeroOrMore("|" + tgrep_node_expr)
    )
    # a relation is either an optionally negated bracketed group of
    # relations, or an operator followed by a node
    tgrep_brackets = pyparsing.Optional('!') + '[' + tgrep_relations + ']'
    tgrep_relation = tgrep_brackets | (tgrep_op + tgrep_node)
    # relations written one after another, with or without '&'
    # between them, form a conjunction
    tgrep_rel_conjunction = pyparsing.Forward()
    tgrep_rel_conjunction << (
        tgrep_relation
        + pyparsing.ZeroOrMore(pyparsing.Optional('&') + tgrep_rel_conjunction)
    )
    # conjunctions separated by '|' form the full relation list
    tgrep_relations << tgrep_rel_conjunction + pyparsing.ZeroOrMore(
        "|" + tgrep_relations
    )
    # a tgrep expression is a node with optional relations on it
    tgrep_expr << tgrep_node + pyparsing.Optional(tgrep_relations)
    # additional ':'-separated segments start with a label use
    tgrep_expr_labeled = tgrep_node_label_use + pyparsing.Optional(tgrep_relations)
    tgrep_expr2 = tgrep_expr + pyparsing.ZeroOrMore(':' + tgrep_expr_labeled)
    # macro definition: '@', whitespace (dropped from the tokens), the
    # macro name, then the expression it stands for
    macro_defn = (
        pyparsing.Literal('@') + pyparsing.White().suppress() + macro_name + tgrep_expr2
    )
    # top level: optional leading macro definitions, each terminated
    # by ';', then an expression, then further ';'-separated macro
    # definitions or expressions; trailing semicolons are dropped
    tgrep_exprs = (
        pyparsing.Optional(macro_defn + pyparsing.ZeroOrMore(';' + macro_defn) + ';')
        + tgrep_expr2
        + pyparsing.ZeroOrMore(';' + (macro_defn | tgrep_expr2))
        + pyparsing.ZeroOrMore(';').suppress()
    )
    if set_parse_actions:
        # attach the callbacks that turn matched tokens into predicates
        tgrep_node_label_use.setParseAction(_tgrep_node_label_use_action)
        tgrep_node_label_use_pred.setParseAction(_tgrep_node_label_pred_use_action)
        macro_use.setParseAction(_tgrep_macro_use_action)
        tgrep_node.setParseAction(_tgrep_node_action)
        tgrep_node_expr2.setParseAction(_tgrep_bind_node_label_action)
        tgrep_parens.setParseAction(_tgrep_parens_action)
        tgrep_nltk_tree_pos.setParseAction(_tgrep_nltk_tree_pos_action)
        tgrep_relation.setParseAction(_tgrep_relation_action)
        tgrep_rel_conjunction.setParseAction(_tgrep_conjunction_action)
        tgrep_relations.setParseAction(_tgrep_rel_disjunction_action)
        macro_defn.setParseAction(_macro_defn_action)
        # the whole expression is also the conjunction of two
        # predicates: the first node predicate, and the remaining
        # relation predicates
        tgrep_expr.setParseAction(_tgrep_conjunction_action)
        tgrep_expr_labeled.setParseAction(_tgrep_segmented_pattern_action)
        # ':'-separated segments reuse the conjunction action, with
        # ':' as the join character
        tgrep_expr2.setParseAction(
            functools.partial(_tgrep_conjunction_action, join_char=':')
        )
        tgrep_exprs.setParseAction(_tgrep_exprs_action)
    # '#' starts a comment that runs to the end of the line
    return tgrep_exprs.ignore('#' + pyparsing.restOfLine)
966
967
def tgrep_tokenize(tgrep_string):
    '''
    Splits a TGrep search string into a list of tokens.

    Byte strings are decoded first.  The grammar is built without
    parse actions, so the result is the list of matched tokens rather
    than a predicate function.
    '''
    if isinstance(tgrep_string, binary_type):
        tgrep_string = tgrep_string.decode()
    tokenizer = _build_tgrep_parser(False)
    parsed = tokenizer.parseString(tgrep_string)
    return list(parsed)
976
977
def tgrep_compile(tgrep_string):
    '''
    Compiles a TGrep search string into a predicate function.

    Byte strings are decoded first.  The whole string must match the
    grammar (``parseAll=True``); the first element of the parse result
    is returned.
    '''
    if isinstance(tgrep_string, binary_type):
        tgrep_string = tgrep_string.decode()
    compiler = _build_tgrep_parser(True)
    parsed = compiler.parseString(tgrep_string, parseAll=True)
    return list(parsed)[0]
987
988
def treepositions_no_leaves(tree):
    '''
    Returns the tree positions of the given tree that belong to
    non-leaf nodes, in the order ``tree.treepositions()`` lists them.
    '''
    all_positions = tree.treepositions()
    # a position is internal exactly when it is a proper prefix of
    # some other position; leaves are never a prefix of anything
    internal = {
        position[:cut] for position in all_positions for cut in range(len(position))
    }
    return [position for position in all_positions if position in internal]
1002
1003
def tgrep_positions(pattern, trees, search_leaves=True):
    """
    Yield, for each tree, the list of tree positions matching the pattern.

    One list is yielded per input tree.  If an ``AttributeError`` is
    raised while processing a tree, an empty list is yielded for that
    tree instead.

    :param pattern: a tgrep search pattern
    :type pattern: str or output of tgrep_compile()
    :param trees: a sequence of NLTK trees (usually ParentedTrees)
    :type trees: iter(ParentedTree) or iter(Tree)
    :param search_leaves: whether to return matching leaf nodes
    :type search_leaves: bool
    :rtype: iter(tree positions)
    """

    # string patterns are compiled once, up front
    matcher = pattern
    if isinstance(matcher, (binary_type, text_type)):
        matcher = tgrep_compile(matcher)

    for tree in trees:
        try:
            candidates = (
                tree.treepositions() if search_leaves else treepositions_no_leaves(tree)
            )
            yield [where for where in candidates if matcher(tree[where])]
        except AttributeError:
            yield []
1029
1030
def tgrep_nodes(pattern, trees, search_leaves=True):
    """
    Return the tree nodes in the trees which match the given pattern.

    One list of matching nodes is yielded per input tree.  If an
    ``AttributeError`` is raised while processing a tree, an empty
    list is yielded for that tree instead.

    :param pattern: a tgrep search pattern
    :type pattern: str or output of tgrep_compile()
    :param trees: a sequence of NLTK trees (usually ParentedTrees)
    :type trees: iter(ParentedTree) or iter(Tree)
    :param search_leaves: whether to return matching leaf nodes
    :type search_leaves: bool
    :rtype: iter(tree nodes)
    """

    if isinstance(pattern, (binary_type, text_type)):
        pattern = tgrep_compile(pattern)

    for tree in trees:
        try:
            if search_leaves:
                positions = tree.treepositions()
            else:
                positions = treepositions_no_leaves(tree)
            # look each node up once, so the node that was tested is
            # the node that is returned
            candidates = (tree[position] for position in positions)
            yield [node for node in candidates if pattern(node)]
        except AttributeError:
            yield []
1056