1# Natural Language Toolkit: Logic
2#
3# Author:     Peter Wang
4# Updated by: Dan Garrette <dhgarrette@gmail.com>
5#
6# Copyright (C) 2001-2019 NLTK Project
7# URL: <http://nltk.org>
8# For license information, see LICENSE.TXT
9
10"""
11An implementation of the Hole Semantics model, following Blackburn and Bos,
12Representation and Inference for Natural Language (CSLI, 2005).
13
14The semantic representations are built by the grammar hole.fcfg.
15This module contains driver code to read in sentences and parse them
16according to a hole semantics grammar.
17
18After parsing, the semantic representation is in the form of an underspecified
19representation that is not easy to read.  We use a "plugging" algorithm to
20convert that representation into first-order logic formulas.
21"""
22from __future__ import print_function, unicode_literals
23
24from functools import reduce
25
26from six import itervalues
27
28from nltk import compat
29from nltk.parse import load_parser
30
31from nltk.sem.skolemize import skolemize
32from nltk.sem.logic import (
33    AllExpression,
34    AndExpression,
35    ApplicationExpression,
36    ExistsExpression,
37    IffExpression,
38    ImpExpression,
39    LambdaExpression,
40    NegatedExpression,
41    OrExpression,
42)
43
44
45# Note that in this code there may be multiple types of trees being referred to:
46#
47# 1. parse trees
48# 2. the underspecified representation
49# 3. first-order logic formula trees
50# 4. the search space when plugging (search tree)
51#
52
53
class Constants(object):
    """
    Names of the special predicates that may occur in an underspecified
    representation, plus a table describing how the formula-building
    predicates are turned into logic expressions.
    """

    # Quantifiers.
    ALL = 'ALL'
    EXISTS = 'EXISTS'

    # Boolean connectives.
    NOT = 'NOT'
    AND = 'AND'
    OR = 'OR'
    IMP = 'IMP'
    IFF = 'IFF'

    # Predicate application.
    PRED = 'PRED'

    # Structural predicates (constraints and node declarations); these have
    # no entry in MAP.
    LEQ = 'LEQ'
    HOLE = 'HOLE'
    LABEL = 'LABEL'

    # Predicate name -> callable that builds the corresponding expression.
    # The quantifier entries receive the bound variable wrapped in an
    # expression and therefore unwrap it through its ``variable`` attribute.
    MAP = {
        ALL: lambda bound, body: AllExpression(bound.variable, body),
        EXISTS: lambda bound, body: ExistsExpression(bound.variable, body),
        NOT: NegatedExpression,
        AND: AndExpression,
        OR: OrExpression,
        IMP: ImpExpression,
        IFF: IffExpression,
        PRED: ApplicationExpression,
    }
77
78
class HoleSemantics(object):
    """
    This class holds the broken-down components of a hole semantics, i.e. it
    extracts the holes, labels, logic formula fragments and constraints out of
    the big conjunction produced by the hole semantics grammar.  It then
    provides some operations on the semantics dealing with holes, labels
    and finding legal ways to plug holes with labels.
    """

    def __init__(self, usr):
        """
        Constructor.  `usr' is a ``sem.Expression`` representing an
        Underspecified Representation Structure (USR).  A USR has the following
        special predicates:
        ALL(l,v,n),
        EXISTS(l,v,n),
        NOT(l,n),
        AND(l,n,n),
        OR(l,n,n),
        IMP(l,n,n),
        IFF(l,n,n),
        PRED(l,v,n,v[,v]*) where the brackets and star indicate zero or more repetitions,
        LEQ(n,n),
        HOLE(n),
        LABEL(n)
        where l is the label of the node described by the predicate, n is either
        a label or a hole, and v is a variable.
        """
        self.holes = set()
        self.labels = set()
        self.fragments = {}  # mapping of label -> formula fragment
        self.constraints = set()  # set of Constraints
        self._break_down(usr)
        self.top_most_labels = self._find_top_most_labels()
        self.top_hole = self._find_top_hole()

    def is_node(self, x):
        """
        Return true if x is a node (label or hole) in this semantic
        representation.
        """
        # Two membership tests avoid building a throw-away union set on
        # every call.
        return x in self.labels or x in self.holes

    def _break_down(self, usr):
        """
        Extract holes, labels, formula fragments and constraints from the hole
        semantics underspecified representation (USR).

        Conjunctions are traversed recursively; every application is filed
        according to the name of its predicate.  Anything else raises
        ``ValueError``.
        """
        if isinstance(usr, AndExpression):
            self._break_down(usr.first)
            self._break_down(usr.second)
        elif isinstance(usr, ApplicationExpression):
            func, args = usr.uncurry()
            name = func.variable.name
            if name == Constants.LEQ:
                self.constraints.add(Constraint(args[0], args[1]))
            elif name == Constants.HOLE:
                self.holes.add(args[0])
            elif name == Constants.LABEL:
                self.labels.add(args[0])
            else:
                # Any other predicate describes a formula fragment whose
                # first argument is the label naming it.
                label = args[0]
                assert label not in self.fragments
                self.fragments[label] = (func, args[1:])
        else:
            # Report the offending expression itself so that the caller
            # really gets a ValueError describing the problem.
            raise ValueError('Unexpected expression in USR: %s' % usr)

    def _find_top_nodes(self, node_list):
        """
        Return a copy of the set `node_list' without the nodes that occur as
        an argument of some formula fragment.
        """
        top_nodes = node_list.copy()
        for _, args in self.fragments.values():
            # The label was already stripped from the stored arguments, so
            # everything left is a reference made by the fragment.
            top_nodes.difference_update(args)
        return top_nodes

    def _find_top_most_labels(self):
        """
        Return the set of labels which are not referenced directly as part of
        another formula fragment.  These will be the top-most labels for the
        subtree that they are part of.
        """
        return self._find_top_nodes(self.labels)

    def _find_top_hole(self):
        """
        Return the hole that will be the top of the formula tree.
        """
        top_holes = self._find_top_nodes(self.holes)
        assert len(top_holes) == 1  # it must be unique
        return top_holes.pop()

    def pluggings(self):
        """
        Calculate and return all the legal pluggings (mappings of labels to
        holes) of this semantics given the constraints.
        """
        record = []
        self._plug_nodes([(self.top_hole, [])], self.top_most_labels, {}, record)
        return record

    def _plug_nodes(self, queue, potential_labels, plug_acc, record):
        """
        Plug the nodes in `queue' with the labels in `potential_labels'.

        Each element of `queue' is a tuple of the node to plug and the list of
        ancestor holes from the root of the graph to that node.

        `potential_labels' is a set of the labels which are still available for
        plugging.

        `plug_acc' is the incomplete mapping of holes to labels made on the
        current branch of the search tree so far.

        `record' is a list of all the complete pluggings that we have found in
        total so far.  It is the only parameter that is destructively updated.
        """
        if not queue:
            raise Exception('queue empty')

        (node, ancestors) = queue[0]
        if node in self.holes:
            # The node is a hole, try to plug it.
            self._plug_hole(
                node, ancestors, queue[1:], potential_labels, plug_acc, record
            )
        else:
            assert node in self.labels
            # The node is a label.  Replace it in the queue by the holes and
            # labels in the formula fragment named by that label.
            args = self.fragments[node][1]
            head = [(a, ancestors) for a in args if self.is_node(a)]
            self._plug_nodes(head + queue[1:], potential_labels, plug_acc, record)

    def _plug_hole(self, hole, ancestors0, queue, potential_labels0, plug_acc0, record):
        """
        Try all possible ways of plugging a single hole.
        See _plug_nodes for the meanings of the parameters.
        """
        # Add the current hole we're trying to plug into the list of ancestors.
        assert hole not in ancestors0
        ancestors = [hole] + ancestors0

        # Try each potential label in this hole in turn.
        for label in potential_labels0:
            # Is the label valid in this hole?
            if self._violates_constraints(label, ancestors):
                continue

            # Work on copies so that sibling branches of the search tree are
            # not affected by the choice made here.
            plug_acc = plug_acc0.copy()
            plug_acc[hole] = label
            potential_labels = potential_labels0.copy()
            potential_labels.remove(label)

            if len(potential_labels) == 0:
                # No more potential labels.  That must mean all the holes have
                # been filled so we have found a legal plugging so remember it.
                #
                # Note that the queue might not be empty because there might
                # be labels on there that point to formula fragments with
                # no holes in them.  _sanity_check_plugging will make sure
                # all holes are filled.
                self._sanity_check_plugging(plug_acc, self.top_hole, [])
                record.append(plug_acc)
            else:
                # Recursively try to fill in the rest of the holes in the
                # queue.  The label we just plugged into the hole could have
                # holes of its own, so it is added at the end of the queue.
                # Putting it on the end of the queue gives us a breadth-first
                # search, so that all the holes at level i of the formula
                # tree are filled before filling level i+1.
                # A depth-first search would work as well since the trees must
                # be finite but the bookkeeping would be harder.
                self._plug_nodes(
                    queue + [(label, ancestors)], potential_labels, plug_acc, record
                )

    def _violates_constraints(self, label, ancestors):
        """
        Return True if the `label' cannot be placed underneath the holes given
        by the list `ancestors' because it would violate the constraints
        imposed on it.
        """
        for c in self.constraints:
            if c.lhs == label:
                if c.rhs not in ancestors:
                    return True
        return False

    def _sanity_check_plugging(self, plugging, node, ancestors):
        """
        Make sure that a given plugging is legal.  We recursively go through
        each node and make sure that no constraints are violated.
        We also check that all holes have been filled.
        """
        if node in self.holes:
            ancestors = [node] + ancestors
            # A hole without an entry in the plugging fails here (KeyError).
            label = plugging[node]
        else:
            label = node
        assert label in self.labels
        for c in self.constraints:
            if c.lhs == label:
                assert c.rhs in ancestors
        args = self.fragments[label][1]
        for arg in args:
            if self.is_node(arg):
                self._sanity_check_plugging(plugging, arg, [label] + ancestors)

    def formula_tree(self, plugging):
        """
        Return the first-order logic formula tree for this underspecified
        representation using the plugging given.
        """
        return self._formula_tree(plugging, self.top_hole)

    def _formula_tree(self, plugging, node):
        """
        Build the formula below `node'.  A plugged hole is replaced by the
        label plugged into it, a label is expanded into its fragment and
        anything else is returned unchanged.
        """
        if node in plugging:
            return self._formula_tree(plugging, plugging[node])
        if node in self.fragments:
            pred, args = self.fragments[node]
            children = [self._formula_tree(plugging, arg) for arg in args]
            name = pred.variable.name
            build = Constants.MAP[name]
            if name == Constants.NOT and len(children) == 1:
                # reduce() hands back a lone element untouched, which would
                # silently drop the negation; apply the operator directly.
                return build(children[0])
            return reduce(build, children)
        return node
303
304
@compat.python_2_unicode_compatible
class Constraint(object):
    """
    A constraint of the form (L =< N), where L is a label and N is a node
    (a label or a hole).  Constraints compare equal when they belong to the
    same class and both of their sides are equal.
    """

    def __init__(self, lhs, rhs):
        # lhs: the constrained label; rhs: the node it is related to.
        self.lhs, self.rhs = lhs, rhs

    def __eq__(self, other):
        return (
            self.__class__ == other.__class__
            and self.lhs == other.lhs
            and self.rhs == other.rhs
        )

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # Hash the textual form of the constraint.
        return hash(repr(self))

    def __repr__(self):
        sides = (self.lhs, self.rhs)
        return '(%s < %s)' % sides
330
331
def hole_readings(sentence, grammar_filename=None, verbose=False):
    """
    Parse `sentence' with a hole semantics grammar and return the formulas
    obtained from every plugging of every parse.

    :param sentence: the sentence to analyse; it is tokenized by splitting
        on whitespace.
    :param grammar_filename: the grammar to load; a false value selects
        'grammars/sample_grammars/hole.fcfg'.
    :param verbose: when true, print the intermediate representations.
    :return: a list holding the readings of all parses, in parse order.
    """
    grammar_filename = grammar_filename or 'grammars/sample_grammars/hole.fcfg'

    if verbose:
        print('Reading grammar file', grammar_filename)

    parser = load_parser(grammar_filename)

    # Parse the whitespace-separated tokens of the sentence.
    parses = list(parser.parse(sentence.split()))
    if verbose:
        print('Got %d different parses' % len(parses))

    collected = []
    for parse in parses:
        # The semantic feature sits on the root of the parse tree.
        usr = parse.label()['SEM'].simplify()

        # Show the raw semantic representation.
        if verbose:
            print('Raw:       ', usr)

        # Strip any leading lambda abstractions, then skolemize.
        while isinstance(usr, LambdaExpression):
            usr = usr.term
        skolemized = skolemize(usr)

        if verbose:
            print('Skolemized:', skolemized)

        # Split the representation into holes, labels, formula fragments
        # and constraints.
        hole_sem = HoleSemantics(skolemized)

        # Optionally show the components that were found.
        if verbose:
            print('Holes:       ', hole_sem.holes)
            print('Labels:      ', hole_sem.labels)
            print('Constraints: ', hole_sem.constraints)
            print('Top hole:    ', hole_sem.top_hole)
            print('Top labels:  ', hole_sem.top_most_labels)
            print('Fragments:')
            for frag_label, fragment in hole_sem.fragments.items():
                print('\t%s: %s' % (frag_label, fragment))

        # Enumerate the pluggings and build one formula tree for each.
        readings = [
            hole_sem.formula_tree(plugging) for plugging in hole_sem.pluggings()
        ]

        # Print the formulas in a textual format.
        if verbose:
            for number, reading in enumerate(readings):
                print()
                print('%d. %s' % (number, reading))
            print()

        collected += readings

    return collected
395
396
if __name__ == '__main__':
    # Demo: print the readings of two sample sentences, with a blank line
    # between the two groups.
    demo_sentences = ['a dog barks', 'every girl chases a dog']
    for position, demo_sentence in enumerate(demo_sentences):
        if position:
            print()
        for r in hole_readings(demo_sentence):
            print(r)
403