1# Natural Language Toolkit: Logic
2#
3# Author: Dan Garrette <dhgarrette@gmail.com>
4#
5# Copyright (C) 2001-2019 NLTK Project
6# URL: <http://nltk.org>
7# For license information, see LICENSE.TXT
8
9"""
10A version of first order predicate logic, built on
11top of the typed lambda calculus.
12"""
13from __future__ import print_function, unicode_literals
14
15import re
16import operator
17from collections import defaultdict
18from functools import reduce, total_ordering
19
20from six import string_types
21
22from nltk.util import Trie
23from nltk.internals import Counter
24from nltk.compat import python_2_unicode_compatible
25
# Pseudo-operator name used for function application: it is a key in
# ``LogicParser.operator_precedence`` and is passed as the parsing context
# while reading function arguments.
APP = 'APP'

# Module-wide counter whose ``get()`` numbers freshly generated variables
# (see ``unique_variable`` and ``skolem_function``).
_counter = Counter()
29
30
class Tokens(object):
    """
    Namespace of the token strings understood by ``LogicParser``.

    Each construct has a single representative spelling (e.g. ``AND``) and a
    ``*_LIST`` holding every spelling the parser accepts for it
    (e.g. ``AND_LIST``).
    """

    LAMBDA = '\\'
    LAMBDA_LIST = ['\\']

    # Quantifiers
    EXISTS = 'exists'
    EXISTS_LIST = ['some', 'exists', 'exist']
    ALL = 'all'
    ALL_LIST = ['all', 'forall']

    # Punctuation
    DOT = '.'
    OPEN = '('
    CLOSE = ')'
    COMMA = ','

    # Operations
    NOT = '-'
    NOT_LIST = ['not', '-', '!']
    AND = '&'
    AND_LIST = ['and', '&', '^']
    OR = '|'
    OR_LIST = ['or', '|']
    IMP = '->'
    IMP_LIST = ['implies', '->', '=>']
    IFF = '<->'
    IFF_LIST = ['iff', '<->', '<=>']
    EQ = '='
    EQ_LIST = ['=', '==']
    NEQ = '!='
    NEQ_LIST = ['!=']

    # Collections of tokens
    BINOPS = AND_LIST + OR_LIST + IMP_LIST + IFF_LIST
    QUANTS = EXISTS_LIST + ALL_LIST
    PUNCT = [DOT, OPEN, CLOSE, COMMA]

    # Every reserved token; any other token is treated as a variable or
    # constant name by ``LogicParser.isvariable``.
    TOKENS = BINOPS + EQ_LIST + NEQ_LIST + QUANTS + LAMBDA_LIST + PUNCT + NOT_LIST

    # Special
    # The reserved tokens made up only of punctuation characters (i.e. not the
    # word forms such as 'and' or 'exists').  By default the parser's
    # tokenizer (``LogicParser.get_all_symbols``/``process``) splits these out
    # of the input even when they are not surrounded by whitespace.
    SYMBOLS = [x for x in TOKENS if re.match(r'^[-\\.(),!&^|>=<]*$', x)]
72
73
def boolean_ops():
    """
    Print a two-column table of the boolean operators and their symbols.
    """
    rows = [
        ("negation", Tokens.NOT),
        ("conjunction", Tokens.AND),
        ("disjunction", Tokens.OR),
        ("implication", Tokens.IMP),
        ("equivalence", Tokens.IFF),
    ]
    for label, symbol in rows:
        print("%-15s\t%s" % (label, symbol))
81
82
def equality_preds():
    """
    Print a two-column table of the equality predicates and their symbols.
    """
    rows = (
        ("equality", Tokens.EQ),
        ("inequality", Tokens.NEQ),
    )
    for label, symbol in rows:
        print("%-15s\t%s" % (label, symbol))
90
91
def binding_ops():
    """
    Print a two-column table of the variable-binding operators and their symbols.
    """
    rows = (
        ("existential", Tokens.EXISTS),
        ("universal", Tokens.ALL),
        ("lambda", Tokens.LAMBDA),
    )
    for label, symbol in rows:
        print("%-15s\t%s" % (label, symbol))
99
100
@python_2_unicode_compatible
class LogicParser(object):
    """A lambda calculus expression parser.

    ``parse`` first splits the input into tokens (``process``) and then reads
    the token buffer.  ``process_next_expression`` reads one leading
    expression (dispatched by ``handle``) and then attaches any following
    equality, application and boolean operators whose precedence beats the
    enclosing context (see ``has_priority``; a lower number in
    ``operator_precedence`` binds more tightly).

    The ``handle_*``, ``make_*`` and ``get_*_factory`` methods are hooks for
    subclasses that support other operators or expression classes.
    """

    def __init__(self, type_check=False):
        """
        :param type_check: whether ``parse`` should type check each parsed
            expression (by calling its ``typecheck`` method) before returning it
        :type type_check: bool
        """
        assert isinstance(type_check, bool)

        self._currentIndex = 0
        self._buffer = []
        self.type_check = type_check

        # A list of quote-character 4-tuples:
        #   (start char, end char, escape char, include_quotes)
        # Quotes mark a token as atomic: special characters inside it are not
        # split out.  The escape character allows the end character to appear
        # inside the quote.  If include_quotes is True, the resulting token
        # keeps the quote and escape characters.  Empty by default; this
        # attribute exists to be overridden by subclasses.
        self.quote_chars = []

        # Lower numbers bind more tightly.  ``None`` is the outermost context
        # (a whole expression, or the inside of parentheses).
        self.operator_precedence = dict(
            [(x, 1) for x in Tokens.LAMBDA_LIST]
            + [(x, 2) for x in Tokens.NOT_LIST]
            + [(APP, 3)]
            + [(x, 4) for x in Tokens.EQ_LIST + Tokens.NEQ_LIST]
            + [(x, 5) for x in Tokens.QUANTS]
            + [(x, 6) for x in Tokens.AND_LIST]
            + [(x, 7) for x in Tokens.OR_LIST]
            + [(x, 8) for x in Tokens.IMP_LIST]
            + [(x, 9) for x in Tokens.IFF_LIST]
            + [(None, 10)]
        )
        # Operations that may nest inside an operation of equal precedence
        # (see ``has_priority``).
        self.right_associated_operations = [APP]

    def parse(self, data, signature=None):
        """
        Parse the expression.

        :param data: str for the input to be parsed
        :param signature: ``dict<str, str>`` that maps variable names to type
            strings; only used when ``type_check`` is enabled
        :returns: a parsed Expression
        :raises LogicalExpressionException: if ``data`` cannot be parsed.  When
            the error carries a position, the message shows the input with a
            caret under the offending location.
        """
        data = data.rstrip()

        self._currentIndex = 0
        self._buffer, mapping = self.process(data)

        try:
            result = self.process_next_expression(None)
            if self.inRange(0):
                # Tokens are left over after a complete expression.
                raise UnexpectedTokenException(self._currentIndex + 1, self.token(0))
        except LogicalExpressionException as e:
            if e.index is None:
                # Without a position no caret can be drawn; propagate as-is.
                raise
            # ``mapping`` translates the (1-based) token index of the error
            # into a character offset in ``data``.
            msg = '%s\n%s\n%s^' % (e, data, ' ' * mapping[e.index - 1])
            raise LogicalExpressionException(None, msg)

        if self.type_check:
            result.typecheck(signature)

        return result

    def process(self, data):
        """Split the data into tokens.

        :return: ``(tokens, mapping)`` where ``mapping`` maps each token's
            position in ``tokens`` to its starting character offset in
            ``data``; two extra entries cover the positions just past the end
            so that end-of-input errors can be pointed at.
        """
        out = []
        mapping = {}
        tokenTrie = Trie(self.get_all_symbols())
        token = ''
        data_idx = 0
        token_start_idx = data_idx
        while data_idx < len(data):
            cur_data_idx = data_idx
            quoted_token, data_idx = self.process_quoted_token(data_idx, data)
            if quoted_token:
                if not token:
                    token_start_idx = cur_data_idx
                token += quoted_token
                continue

            # Walk the symbol trie as far as the input matches; if the walk
            # ends on a complete symbol, that symbol becomes its own token.
            st = tokenTrie
            c = data[data_idx]
            symbol = ''
            while c in st:
                symbol += c
                st = st[c]
                if len(data) - data_idx > len(symbol):
                    c = data[data_idx + len(symbol)]
                else:
                    break
            if Trie.LEAF in st:
                # token is a complete symbol
                if token:
                    mapping[len(out)] = token_start_idx
                    out.append(token)
                    token = ''
                mapping[len(out)] = data_idx
                out.append(symbol)
                data_idx += len(symbol)
            else:
                # Any ASCII whitespace ends the current token.  '\r' is
                # included so Windows line endings do not leak into tokens.
                if data[data_idx] in ' \t\n\r\f\v':
                    if token:
                        mapping[len(out)] = token_start_idx
                        out.append(token)
                        token = ''
                else:
                    if not token:
                        token_start_idx = data_idx
                    token += data[data_idx]
                data_idx += 1
        if token:
            mapping[len(out)] = token_start_idx
            out.append(token)
        mapping[len(out)] = len(data)
        mapping[len(out) + 1] = len(data) + 1
        return out, mapping

    def process_quoted_token(self, data_idx, data):
        """Read a quoted token starting at ``data_idx``, if one starts there.

        :return: ``(token, next_index)``; ``token`` is ``''`` and
            ``next_index == data_idx`` when ``data[data_idx]`` does not open
            any quote listed in ``quote_chars``.
        :raises LogicalExpressionException: if the input ends before the
            closing quote, if an escape character is the last character, or
            if the quoted token is empty.
        """
        token = ''
        c = data[data_idx]
        i = data_idx
        for start, end, escape, incl_quotes in self.quote_chars:
            if c == start:
                if incl_quotes:
                    token += c
                i += 1
                while True:
                    # Checked before every read, so an opening quote that is
                    # the last input character is reported instead of
                    # failing with an IndexError.
                    if len(data) == i:
                        raise LogicalExpressionException(
                            None, "End of input reached.  " "Expected: [%s]" % end
                        )
                    if data[i] == end:
                        break
                    if data[i] == escape:
                        if incl_quotes:
                            token += data[i]
                        i += 1
                        if len(data) == i:  # if there are no more chars
                            raise LogicalExpressionException(
                                None,
                                "End of input reached.  "
                                "Escape character [%s] found at end." % escape,
                            )
                        token += data[i]
                    else:
                        token += data[i]
                    i += 1
                if incl_quotes:
                    token += data[i]
                i += 1
                if not token:
                    raise LogicalExpressionException(None, 'Empty quoted token found')
                break
        return token, i

    def get_all_symbols(self):
        """This method exists to be overridden"""
        return Tokens.SYMBOLS

    def inRange(self, location):
        """Return TRUE if the given location is within the buffer"""
        return self._currentIndex + location < len(self._buffer)

    def token(self, location=None):
        """Get the next waiting token.  If a location is given, then
        return the token at currentIndex+location without advancing
        currentIndex; setting it gives lookahead/lookback capability.

        :raises ExpectedMoreTokensException: if the requested position is
            past the end of the token buffer.
        """
        try:
            if location is None:
                tok = self._buffer[self._currentIndex]
                self._currentIndex += 1
            else:
                tok = self._buffer[self._currentIndex + location]
            return tok
        except IndexError:
            raise ExpectedMoreTokensException(self._currentIndex + 1)

    def isvariable(self, tok):
        """Return True if ``tok`` is not one of the reserved ``Tokens``."""
        return tok not in Tokens.TOKENS

    def process_next_expression(self, context):
        """Parse the next complete expression from the stream and return it."""
        try:
            tok = self.token()
        except ExpectedMoreTokensException:
            raise ExpectedMoreTokensException(
                self._currentIndex + 1, message='Expression expected.'
            )

        accum = self.handle(tok, context)

        if not accum:
            # ``handle`` returns None for tokens that cannot start an
            # expression (e.g. ')' or ',').
            raise UnexpectedTokenException(
                self._currentIndex, tok, message='Expression expected.'
            )

        return self.attempt_adjuncts(accum, context)

    def handle(self, tok, context):
        """This method is intended to be overridden for logics that
        use different operators or expressions"""
        if self.isvariable(tok):
            return self.handle_variable(tok, context)

        elif tok in Tokens.NOT_LIST:
            return self.handle_negation(tok, context)

        elif tok in Tokens.LAMBDA_LIST:
            return self.handle_lambda(tok, context)

        elif tok in Tokens.QUANTS:
            return self.handle_quant(tok, context)

        elif tok == Tokens.OPEN:
            return self.handle_open(tok, context)

    def attempt_adjuncts(self, expression, context):
        """Repeatedly attach equality, application and boolean operators to
        ``expression`` until a full pass consumes no more tokens."""
        cur_idx = None
        while cur_idx != self._currentIndex:  # while adjuncts are added
            cur_idx = self._currentIndex
            expression = self.attempt_EqualityExpression(expression, context)
            expression = self.attempt_ApplicationExpression(expression, context)
            expression = self.attempt_BooleanExpression(expression, context)
        return expression

    def handle_negation(self, tok, context):
        return self.make_NegatedExpression(self.process_next_expression(Tokens.NOT))

    def make_NegatedExpression(self, expression):
        return NegatedExpression(expression)

    def handle_variable(self, tok, context):
        # It's either: 1) a predicate expression: sees(x,y)
        #             2) an application expression: P(x)
        #             3) a solo variable: john OR x
        accum = self.make_VariableExpression(tok)
        if self.inRange(0) and self.token(0) == Tokens.OPEN:
            # The predicate has arguments
            if not isinstance(accum, FunctionVariableExpression) and not isinstance(
                accum, ConstantExpression
            ):
                raise LogicalExpressionException(
                    self._currentIndex,
                    "'%s' is an illegal predicate name.  "
                    "Individual variables may not be used as "
                    "predicates." % tok,
                )
            self.token()  # swallow the Open Paren

            # curry the arguments
            accum = self.make_ApplicationExpression(
                accum, self.process_next_expression(APP)
            )
            while self.inRange(0) and self.token(0) == Tokens.COMMA:
                self.token()  # swallow the comma
                accum = self.make_ApplicationExpression(
                    accum, self.process_next_expression(APP)
                )
            self.assertNextToken(Tokens.CLOSE)
        return accum

    def get_next_token_variable(self, description):
        """Consume the next token and return it as a ``Variable``.

        :param description: how the variable is being used (e.g.
            'abstracted'); only used in the error message
        :raises LogicalExpressionException: if the token is a constant
        """
        try:
            tok = self.token()
        except ExpectedMoreTokensException as e:
            raise ExpectedMoreTokensException(e.index, 'Variable expected.')
        if isinstance(self.make_VariableExpression(tok), ConstantExpression):
            raise LogicalExpressionException(
                self._currentIndex,
                "'%s' is an illegal variable name.  "
                "Constants may not be %s." % (tok, description),
            )
        return Variable(tok)

    def _read_bound_variables(self, description):
        """Read the variable list of a binder (``\\x y.`` or ``some x y.``)
        and swallow the optional dot; return the variables in input order."""
        variables = [self.get_next_token_variable(description)]
        while True:
            if not self.inRange(0) or (
                self.token(0) == Tokens.DOT and not self.inRange(1)
            ):
                raise ExpectedMoreTokensException(
                    self._currentIndex + 2, message="Expression expected."
                )
            if not self.isvariable(self.token(0)):
                break
            # Support expressions like: \x y.M == \x.\y.M
            variables.append(self.get_next_token_variable(description))
        if self.inRange(0) and self.token(0) == Tokens.DOT:
            self.token()  # swallow the dot
        return variables

    def handle_lambda(self, tok, context):
        # Expression is a lambda expression
        if not self.inRange(0):
            raise ExpectedMoreTokensException(
                self._currentIndex + 2,
                message="Variable and Expression expected following lambda operator.",
            )
        variables = self._read_bound_variables('abstracted')

        accum = self.process_next_expression(tok)
        # Wrap innermost-first so that \x y.M becomes \x.(\y.M).
        while variables:
            accum = self.make_LambdaExpression(variables.pop(), accum)
        return accum

    def handle_quant(self, tok, context):
        # Expression is a quantified expression: some x.M
        factory = self.get_QuantifiedExpression_factory(tok)

        if not self.inRange(0):
            raise ExpectedMoreTokensException(
                self._currentIndex + 2,
                message="Variable and Expression expected following quantifier '%s'."
                % tok,
            )
        variables = self._read_bound_variables('quantified')

        accum = self.process_next_expression(tok)
        # Wrap innermost-first so that some x y.M becomes some x.(some y.M).
        while variables:
            accum = self.make_QuanifiedExpression(factory, variables.pop(), accum)
        return accum

    def get_QuantifiedExpression_factory(self, tok):
        """This method serves as a hook for other logic parsers that
        have different quantifiers"""
        if tok in Tokens.EXISTS_LIST:
            return ExistsExpression
        elif tok in Tokens.ALL_LIST:
            return AllExpression
        else:
            self.assertToken(tok, Tokens.QUANTS)

    def make_QuanifiedExpression(self, factory, variable, term):
        # NOTE: the misspelled name is kept because subclasses may override it.
        return factory(variable, term)

    def handle_open(self, tok, context):
        # Expression is in parens
        accum = self.process_next_expression(None)
        self.assertNextToken(Tokens.CLOSE)
        return accum

    def attempt_EqualityExpression(self, expression, context):
        """Attempt to make an equality expression.  If the next token is an
        equality operator, then an EqualityExpression will be returned.
        Otherwise, the parameter will be returned."""
        if self.inRange(0):
            tok = self.token(0)
            if tok in Tokens.EQ_LIST + Tokens.NEQ_LIST and self.has_priority(
                tok, context
            ):
                self.token()  # swallow the "=" or "!="
                expression = self.make_EqualityExpression(
                    expression, self.process_next_expression(tok)
                )
                if tok in Tokens.NEQ_LIST:
                    # a != b is represented as -(a = b)
                    expression = self.make_NegatedExpression(expression)
        return expression

    def make_EqualityExpression(self, first, second):
        """This method serves as a hook for other logic parsers that
        have different equality expression classes"""
        return EqualityExpression(first, second)

    def attempt_BooleanExpression(self, expression, context):
        """Attempt to make a boolean expression.  If the next token is a boolean
        operator, then a BooleanExpression will be returned.  Otherwise, the
        parameter will be returned."""
        while self.inRange(0):
            tok = self.token(0)
            factory = self.get_BooleanExpression_factory(tok)
            if factory and self.has_priority(tok, context):
                self.token()  # swallow the operator
                expression = self.make_BooleanExpression(
                    factory, expression, self.process_next_expression(tok)
                )
            else:
                break
        return expression

    def get_BooleanExpression_factory(self, tok):
        """This method serves as a hook for other logic parsers that
        have different boolean operators"""
        if tok in Tokens.AND_LIST:
            return AndExpression
        elif tok in Tokens.OR_LIST:
            return OrExpression
        elif tok in Tokens.IMP_LIST:
            return ImpExpression
        elif tok in Tokens.IFF_LIST:
            return IffExpression
        else:
            return None

    def make_BooleanExpression(self, factory, first, second):
        return factory(first, second)

    def attempt_ApplicationExpression(self, expression, context):
        """Attempt to make an application expression.  The next tokens are
        a list of arguments in parens, then the argument expression is a
        function being applied to the arguments.  Otherwise, return the
        argument expression."""
        if self.has_priority(APP, context):
            if self.inRange(0) and self.token(0) == Tokens.OPEN:
                if (
                    not isinstance(expression, LambdaExpression)
                    and not isinstance(expression, ApplicationExpression)
                    and not isinstance(expression, FunctionVariableExpression)
                    and not isinstance(expression, ConstantExpression)
                ):
                    raise LogicalExpressionException(
                        self._currentIndex,
                        ("The function '%s" % expression)
                        + "' is not a Lambda Expression, an "
                        "Application Expression, or a "
                        "functional predicate, so it may "
                        "not take arguments.",
                    )
                self.token()  # swallow the open paren
                # curry the arguments
                accum = self.make_ApplicationExpression(
                    expression, self.process_next_expression(APP)
                )
                while self.inRange(0) and self.token(0) == Tokens.COMMA:
                    self.token()  # swallow the comma
                    accum = self.make_ApplicationExpression(
                        accum, self.process_next_expression(APP)
                    )
                self.assertNextToken(Tokens.CLOSE)
                return accum
        return expression

    def make_ApplicationExpression(self, function, argument):
        return ApplicationExpression(function, argument)

    def make_VariableExpression(self, name):
        return VariableExpression(Variable(name))

    def make_LambdaExpression(self, variable, term):
        return LambdaExpression(variable, term)

    def has_priority(self, operation, context):
        """Return True if ``operation`` should be parsed inside ``context``:
        it binds more tightly, or it has the same precedence and is listed in
        ``right_associated_operations``."""
        return self.operator_precedence[operation] < self.operator_precedence[
            context
        ] or (
            operation in self.right_associated_operations
            and self.operator_precedence[operation] == self.operator_precedence[context]
        )

    def assertNextToken(self, expected):
        """Consume the next token and check it against ``expected`` (a token
        or a list of allowed tokens)."""
        try:
            tok = self.token()
        except ExpectedMoreTokensException as e:
            raise ExpectedMoreTokensException(
                e.index, message="Expected token '%s'." % expected
            )
        self.assertToken(tok, expected)

    def assertToken(self, tok, expected):
        """Raise ``UnexpectedTokenException`` unless ``tok`` matches
        ``expected`` (a token or a list of allowed tokens)."""
        if isinstance(expected, list):
            if tok not in expected:
                raise UnexpectedTokenException(self._currentIndex, tok, expected)
        else:
            if tok != expected:
                raise UnexpectedTokenException(self._currentIndex, tok, expected)

    def __repr__(self):
        if self.inRange(0):
            msg = 'Next token: ' + self.token(0)
        else:
            msg = 'No more tokens'
        return '<' + self.__class__.__name__ + ': ' + msg + '>'
589
590
def read_logic(s, logic_parser=None, encoding=None):
    """
    Convert a file of First Order Formulas into a list of {Expression}s.

    Each line is stripped; blank lines and lines starting with ``#`` are
    skipped, and every other line is parsed as one formula.

    :param s: the contents of the file
    :type s: str
    :param logic_parser: The parser to be used to parse the logical expression
        (a new ``LogicParser`` when omitted)
    :type logic_parser: LogicParser
    :param encoding: the encoding of the input string, if it is binary
    :type encoding: str
    :return: a list of parsed formulas.
    :rtype: list(Expression)
    :raises ValueError: if a line cannot be parsed; the message gives the
        1-based line number and the offending (stripped) line.
    """
    if encoding is not None:
        s = s.decode(encoding)
    if logic_parser is None:
        logic_parser = LogicParser()

    statements = []
    # Number lines from 1 so the error message matches what an editor shows.
    for linenum, line in enumerate(s.splitlines(), 1):
        line = line.strip()
        if line.startswith('#') or line == '':
            continue
        try:
            statements.append(logic_parser.parse(line))
        except LogicalExpressionException:
            raise ValueError('Unable to parse line %s: %s' % (linenum, line))
    return statements
619
620
@total_ordering
@python_2_unicode_compatible
class Variable(object):
    """
    A named logical variable.  Two variables are equal (and hash alike)
    exactly when their names are equal; ordering is by name.
    """

    def __init__(self, name):
        """
        :param name: the name of the variable
        """
        assert isinstance(name, string_types), "%s is not a string" % name
        self.name = name

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        if not isinstance(other, Variable):
            return False
        return self.name == other.name

    def __ne__(self, other):
        return not (self == other)

    def __lt__(self, other):
        # Only variables are orderable against each other.
        if isinstance(other, Variable):
            return self.name < other.name
        raise TypeError

    def substitute_bindings(self, bindings):
        """Return this variable's binding in ``bindings``, or itself if unbound."""
        return bindings.get(self, self)

    def __repr__(self):
        return "Variable('%s')" % self.name

    def __str__(self):
        return self.name
653
654
def unique_variable(pattern=None, ignore=None):
    """
    Return a new, unique variable.

    :param pattern: ``Variable`` that is being replaced.  The new variable must
        be the same kind (individual, function or event variable).  When
        ``None``, an individual variable is generated.
    :param ignore: a set of ``Variable`` objects that should not be returned
        from this function.
    :rtype: Variable
    :raises AssertionError: if ``pattern`` is a constant, since a fresh
        constant cannot be generated.
    """
    if pattern is None or is_indvar(pattern.name):
        prefix = 'z'
    elif is_funcvar(pattern.name):
        prefix = 'F'
    elif is_eventvar(pattern.name):
        prefix = 'e0'
    else:
        # Raised explicitly rather than via ``assert`` so it still fires
        # under ``python -O`` (where ``prefix`` would otherwise be unbound);
        # the exception type is unchanged for existing callers.
        raise AssertionError("Cannot generate a unique constant")

    # Keep drawing counter values until the name is not in ``ignore``.
    while True:
        candidate = Variable("%s%s" % (prefix, _counter.get()))
        if ignore is None or candidate not in ignore:
            return candidate
681
682
def skolem_function(univ_scope=None):
    """
    Return a skolem function over the variables in ``univ_scope``.

    A fresh function variable ``F<n>`` is created and applied, one argument
    at a time, to each variable in ``univ_scope``.

    :param univ_scope: the universally quantified variables in scope (may be
        ``None`` or empty, in which case the bare function variable is
        returned)
    """
    function_var = Variable('F%s' % _counter.get())
    skolem = VariableExpression(function_var)
    for scoped in list(univ_scope or ()):
        skolem = skolem(VariableExpression(scoped))
    return skolem
693
694
@python_2_unicode_compatible
class Type(object):
    """Base class of the types of the typed lambda calculus.

    Subclasses provide ``__str__``; ``repr`` and hashing are both derived
    from that string form.
    """

    @classmethod
    def fromstring(cls, s):
        """Build a ``Type`` from its string form via ``read_type``."""
        return read_type(s)

    def __hash__(self):
        return hash("%s" % self)

    def __repr__(self):
        return "%s" % self
706
707
@python_2_unicode_compatible
class ComplexType(Type):
    """A function type ``<first,second>`` mapping ``first`` to ``second``."""

    def __init__(self, first, second):
        assert isinstance(first, Type), "%s is not a Type" % first
        assert isinstance(second, Type), "%s is not a Type" % second
        self.first = first
        self.second = second

    __hash__ = Type.__hash__

    def __eq__(self, other):
        if not isinstance(other, ComplexType):
            return False
        return self.first == other.first and self.second == other.second

    def __ne__(self, other):
        return not self == other

    def matches(self, other):
        """Component-wise match against another complex type."""
        if not isinstance(other, ComplexType):
            return self == ANY_TYPE
        return self.first.matches(other.first) and self.second.matches(other.second)

    def resolve(self, other):
        """Return the most specific type compatible with both, or None."""
        if other == ANY_TYPE:
            return self
        if isinstance(other, ComplexType):
            first = self.first.resolve(other.first)
            second = self.second.resolve(other.second)
            return ComplexType(first, second) if first and second else None
        if self == ANY_TYPE:
            return other
        return None

    def __str__(self):
        if self == ANY_TYPE:
            return "%s" % ANY_TYPE
        return '<%s,%s>' % (self.first, self.second)

    def str(self):
        """Return the long-form ``(A -> B)`` rendering of this type."""
        if self == ANY_TYPE:
            return ANY_TYPE.str()
        return '(%s -> %s)' % (self.first.str(), self.second.str())
760
761
class BasicType(Type):
    """Base class of the atomic types; equality is by string form."""

    __hash__ = Type.__hash__

    def __eq__(self, other):
        if isinstance(other, BasicType):
            return ("%s" % self) == ("%s" % other)
        return False

    def __ne__(self, other):
        return not self == other

    def matches(self, other):
        """A basic type matches the wildcard type or an equal type."""
        return other == ANY_TYPE or self == other

    def resolve(self, other):
        """Return ``self`` if it matches ``other``, else None."""
        return self if self.matches(other) else None
779
780
@python_2_unicode_compatible
class EntityType(BasicType):
    """The basic type of individuals, written ``e``."""

    def str(self):
        """Return the long-form name of this type."""
        return 'IND'

    def __str__(self):
        return 'e'
788
789
@python_2_unicode_compatible
class TruthValueType(BasicType):
    """The basic type of truth values, written ``t``."""

    def str(self):
        """Return the long-form name of this type."""
        return 'BOOL'

    def __str__(self):
        return 't'
797
798
@python_2_unicode_compatible
class EventType(BasicType):
    """The basic type of events, written ``v``."""

    def str(self):
        """Return the long-form name of this type."""
        return 'EVENT'

    def __str__(self):
        return 'v'
806
807
@python_2_unicode_compatible
class AnyType(BasicType, ComplexType):
    """
    The wildcard type, written ``?``.

    It is both a ``BasicType`` and a ``ComplexType``: it matches every type,
    resolves to whatever it is resolved against, and its ``first`` and
    ``second`` components are itself.
    """

    def __init__(self):
        # Deliberately skips ``ComplexType.__init__``: ``first`` and
        # ``second`` are supplied by the read-only properties below.
        pass

    @property
    def first(self):
        return self

    @property
    def second(self):
        return self

    def __eq__(self, other):
        # Equal to any AnyType; otherwise defer to ``other``'s own __eq__.
        return isinstance(other, AnyType) or other.__eq__(self)

    def __ne__(self, other):
        return not self == other

    # Defining __eq__ would otherwise make instances unhashable in Python 3.
    __hash__ = Type.__hash__

    def matches(self, other):
        return True

    def resolve(self, other):
        return other

    def __str__(self):
        return '?'

    def str(self):
        return 'ANY'
840
841
# Shared instances of the basic types; ``read_type`` returns these objects
# rather than constructing new ones.
TRUTH_TYPE = TruthValueType()
ENTITY_TYPE = EntityType()
EVENT_TYPE = EventType()
ANY_TYPE = AnyType()
846
847
def read_type(type_string):
    """
    Parse a type from its string form (the inverse of ``"%s" % some_type``).

    Basic types are written ``e`` (entity), ``t`` (truth value), ``v``
    (event) and ``?`` (any type); only the first character is examined for
    them.  Complex types are written ``<first,second>`` and may nest, e.g.
    ``<<e,t>,t>``.  Spaces are ignored.

    :param type_string: the string to parse
    :type type_string: str
    :return: the parsed ``Type`` (basic types are the shared module-level
        instances)
    :raises LogicalExpressionException: if the string is empty or starts with
        an unrecognized character.
    """
    assert isinstance(type_string, string_types)
    type_string = type_string.replace(' ', '')  # remove spaces

    if not type_string:
        # Also reached for malformed complex types such as '<,t>'.
        raise LogicalExpressionException(None, "Unexpected end of type string.")

    if type_string[0] == '<':
        assert type_string[-1] == '>'
        # Find the comma that separates the two components at nesting depth 1.
        paren_count = 0
        for i, char in enumerate(type_string):
            if char == '<':
                paren_count += 1
            elif char == '>':
                paren_count -= 1
                assert paren_count > 0
            elif char == ',':
                if paren_count == 1:
                    break
        return ComplexType(
            read_type(type_string[1:i]), read_type(type_string[i + 1 : -1])
        )
    elif type_string[0] == "%s" % ENTITY_TYPE:
        return ENTITY_TYPE
    elif type_string[0] == "%s" % TRUTH_TYPE:
        return TRUTH_TYPE
    elif type_string[0] == "%s" % EVENT_TYPE:
        # Without this branch str(EVENT_TYPE) could not be read back.
        return EVENT_TYPE
    elif type_string[0] == "%s" % ANY_TYPE:
        return ANY_TYPE
    else:
        raise LogicalExpressionException(None, "Unexpected character: '%s'." % type_string[0])
875
876
class TypeException(Exception):
    """Base class for the type errors raised by this module."""

    def __init__(self, msg):
        Exception.__init__(self, msg)
880
881
class InconsistentTypeHierarchyException(TypeException):
    """Raised when one variable is used with conflicting types."""

    def __init__(self, variable, expression=None):
        if not expression:
            msg = (
                "The variable '%s' was found in multiple places with different"
                " types." % (variable)
            )
        else:
            msg = (
                "The variable '%s' was found in multiple places with different"
                " types in '%s'." % (variable, expression)
            )
        super(InconsistentTypeHierarchyException, self).__init__(msg)
895
896
class TypeResolutionException(TypeException):
    """Raised when an expression's current type cannot be unified with a
    requested type."""

    def __init__(self, expression, other_type):
        template = "The type of '%s', '%s', cannot be resolved with type '%s'"
        details = (expression, expression.type, other_type)
        super(TypeResolutionException, self).__init__(template % details)
903
904
class IllegalTypeException(TypeException):
    """Raised when an expression is asked to take a type that its kind forbids."""

    def __init__(self, expression, other_type, allowed_type):
        kind = expression.__class__.__name__
        template = "Cannot set type of %s '%s' to '%s'; must match type '%s'."
        message = template % (kind, expression, other_type, allowed_type)
        super(IllegalTypeException, self).__init__(message)
911
912
def typecheck(expressions, signature=None):
    """
    Ensure correct typing across a collection of ``Expression`` objects.

    Each expression is type-checked in turn.  The signature returned by one
    call is passed into the next, so it accumulates every type mapping found
    so far.  The resulting master signature is then applied again to all but
    the last expression, which was already checked against it.

    :param expressions: any iterable of expressions.  It is materialized into
        a list because it is traversed twice.
    :param signature: dict that maps variable names to types (or string
        representations of types)
    :return: the master signature (``signature`` itself if ``expressions``
        is empty)
    """
    # Materialize once: the collection is walked twice and sliced below,
    # which would fail for generators, sets and other non-sequence iterables.
    expressions = list(expressions)
    # typecheck and create master signature
    for expression in expressions:
        signature = expression.typecheck(signature)
    # apply master signature to all expressions; the last one already saw it
    for expression in expressions[:-1]:
        expression.typecheck(signature)
    return signature
927
928
class SubstituteBindingsI(object):
    """
    Interface for objects whose variables can be replaced according to a
    mapping of bindings.  Implementers must override both methods.
    """

    def substitute_bindings(self, bindings):
        """
        :return: the object obtained by replacing every variable that has an
            entry in ``bindings`` with its bound value.  Aliases are assumed
            to be resolved already (unverified).
        :rtype: (any)
        """
        raise NotImplementedError

    def variables(self):
        """
        :return: every variable occurring in this object.
        """
        raise NotImplementedError
949
950
@python_2_unicode_compatible
class Expression(SubstituteBindingsI):
    """This is the base abstract object for all logical expressions.

    Python operators are overloaded to build new expressions: ``-e``
    (negation), ``a & b`` (conjunction), ``a | b`` (disjunction),
    ``a > b`` (implication), ``a < b`` (biconditional) and ``f(x, ...)``
    (application).
    """

    # Parsers shared by every Expression class and used by fromstring();
    # the second one is constructed with type_check=True.
    _logic_parser = LogicParser()
    _type_checking_logic_parser = LogicParser(type_check=True)

    @classmethod
    def fromstring(cls, s, type_check=False, signature=None):
        """
        Parse ``s`` into an ``Expression``.

        :param s: the string to parse
        :param type_check: if true, use the type-checking parser
        :param signature: passed through unchanged to the parser's ``parse``
        """
        if type_check:
            return cls._type_checking_logic_parser.parse(s, signature)
        else:
            return cls._logic_parser.parse(s, signature)

    def __call__(self, other, *additional):
        """
        Apply this expression to one or more arguments, one at a time:
        ``e(a, b)`` is ``e.applyto(a)`` applied in turn to ``b``.  No
        beta-reduction happens here; use ``simplify()`` for that.
        """
        accum = self.applyto(other)
        for a in additional:
            accum = accum(a)
        return accum

    def applyto(self, other):
        """Return the (unsimplified) application of this expression to ``other``."""
        assert isinstance(other, Expression), "%s is not an Expression" % other
        return ApplicationExpression(self, other)

    def __neg__(self):
        """``-e`` builds the negation of ``e``."""
        return NegatedExpression(self)

    def negate(self):
        """If this is a negated expression, remove the negation.
        Otherwise add a negation."""
        # Base case: always wrap in a negation.  NOTE(review): presumably
        # NegatedExpression overrides this to unwrap; confirm.
        return -self

    def __and__(self, other):
        """``a & b`` builds a conjunction; non-Expressions raise TypeError."""
        if not isinstance(other, Expression):
            raise TypeError("%s is not an Expression" % other)
        return AndExpression(self, other)

    def __or__(self, other):
        """``a | b`` builds a disjunction; non-Expressions raise TypeError."""
        if not isinstance(other, Expression):
            raise TypeError("%s is not an Expression" % other)
        return OrExpression(self, other)

    def __gt__(self, other):
        """Overloaded: ``a > b`` builds the implication ``a -> b``.
        It is NOT an ordering comparison."""
        if not isinstance(other, Expression):
            raise TypeError("%s is not an Expression" % other)
        return ImpExpression(self, other)

    def __lt__(self, other):
        """Overloaded: ``a < b`` builds the biconditional ``a <-> b``.
        It is NOT an ordering comparison."""
        if not isinstance(other, Expression):
            raise TypeError("%s is not an Expression" % other)
        return IffExpression(self, other)

    def __eq__(self, other):
        """Structural equality; every concrete subclass must implement it."""
        raise NotImplementedError()

    def __ne__(self, other):
        # Defined explicitly so that ``!=`` always mirrors ``__eq__``.
        return not self == other

    def equiv(self, other, prover=None):
        """
        Check for logical equivalence.
        Pass the expression (self <-> other) to the theorem prover.
        If the prover says it is valid, then the self and other are equal.

        :param other: an ``Expression`` to check equality against
        :param prover: a ``nltk.inference.api.Prover``
        """
        assert isinstance(other, Expression), "%s is not an Expression" % other

        if prover is None:
            # Imported lazily; Prover9 is only needed when no prover is given.
            from nltk.inference import Prover9

            prover = Prover9()
        bicond = IffExpression(self.simplify(), other.simplify())
        return prover.prove(bicond)

    def __hash__(self):
        """Hash the ``repr()``, so expressions that print identically hash
        identically."""
        return hash(repr(self))

    def substitute_bindings(self, bindings):
        """
        Replace every variable from ``self.variables()`` that has an entry in
        ``bindings`` with its bound value, then return the simplified result.

        A ``Variable`` value is wrapped with ``make_VariableExpression``.  Any
        other value must already be an ``Expression``; otherwise ValueError
        is raised.  Each value has the bindings substituted into it before
        it is inserted.
        """
        expr = self
        for var in expr.variables():
            if var in bindings:
                val = bindings[var]
                if isinstance(val, Variable):
                    val = self.make_VariableExpression(val)
                elif not isinstance(val, Expression):
                    raise ValueError(
                        'Can not substitute a non-expression '
                        'value into an expression: %r' % (val,)
                    )
                # Substitute bindings in the target value.
                val = val.substitute_bindings(bindings)
                # Replace var w/ the target value.
                expr = expr.replace(var, val)
        return expr.simplify()

    def typecheck(self, signature=None):
        """
        Infer and check types.  Raise exceptions if necessary.

        :param signature: dict that maps variable names to types (or string
            representations of types)
        :return: the signature, plus any additional type mappings
        """
        # Seed the working signature (name -> list of variable expressions)
        # with one typed variable expression per user-supplied entry.
        sig = defaultdict(list)
        if signature:
            for key in signature:
                val = signature[key]
                varEx = VariableExpression(Variable(key))
                if isinstance(val, Type):
                    varEx.type = val
                else:
                    varEx.type = read_type(val)
                sig[key].append(varEx)

        self._set_type(signature=sig)

        # All entries under a name share one resolved type; report the first.
        return dict((key, sig[key][0].type) for key in sig)

    def findtype(self, variable):
        """
        Find the type of the given variable as it is used in this expression.
        For example, finding the type of "P" in "P(x) & Q(x,y)" yields "<e,t>"

        :param variable: Variable
        """
        raise NotImplementedError()

    def _set_type(self, other_type=ANY_TYPE, signature=None):
        """
        Set the type of this expression to be the given type.  Raise type
        exceptions where applicable.

        :param other_type: Type
        :param signature: dict(str -> list(AbstractVariableExpression))
        """
        raise NotImplementedError()

    def replace(self, variable, expression, replace_bound=False, alpha_convert=True):
        """
        Replace every instance of 'variable' with 'expression'
        :param variable: ``Variable`` The variable to replace
        :param expression: ``Expression`` The expression with which to replace it
        :param replace_bound: bool Should bound variables be replaced?
        :param alpha_convert: bool Alpha convert automatically to avoid name clashes?
        """
        assert isinstance(variable, Variable), "%s is not a Variable" % variable
        assert isinstance(expression, Expression), (
            "%s is not an Expression" % expression
        )

        # Generic case: recurse into every subexpression and rebuild an
        # instance of the same class from the results.
        return self.visit_structured(
            lambda e: e.replace(variable, expression, replace_bound, alpha_convert),
            self.__class__,
        )

    def normalize(self, newvars=None):
        """Rename auto-generated unique variables.

        Every individual variable occurring in the expression is renamed,
        including binders (``replace_bound=True``).  Event variables become
        ``e0<n>`` and other individual variables become ``z<n>``, where
        ``<n>`` is the 1-based position in sorted ``Variable`` order (one
        counter shared by both kinds).  ``newvars`` is currently unused.
        """

        def get_indiv_vars(e):
            if isinstance(e, IndividualVariableExpression):
                return set([e])
            elif isinstance(e, AbstractVariableExpression):
                return set()
            else:
                return e.visit(
                    get_indiv_vars, lambda parts: reduce(operator.or_, parts, set())
                )

        result = self
        for i, e in enumerate(sorted(get_indiv_vars(self), key=lambda e: e.variable)):
            if isinstance(e, EventVariableExpression):
                newVar = e.__class__(Variable('e0%s' % (i + 1)))
            elif isinstance(e, IndividualVariableExpression):
                newVar = e.__class__(Variable('z%s' % (i + 1)))
            else:
                # Defensive: get_indiv_vars only yields individual variables.
                newVar = e
            result = result.replace(e.variable, newVar, True)
        return result

    def visit(self, function, combinator):
        """
        Recursively visit subexpressions.  Apply 'function' to each
        subexpression and pass the result of each function application
        to the 'combinator' for aggregation:

            return combinator(map(function, self.subexpressions))

        Bound variables are neither applied upon by the function nor given to
        the combinator.
        :param function: ``Function<Expression,T>`` to call on each subexpression
        :param combinator: ``Function<list<T>,R>`` to combine the results of the
        function calls
        :return: result of combination ``R``
        """
        raise NotImplementedError()

    def visit_structured(self, function, combinator):
        """
        Recursively visit subexpressions.  Apply 'function' to each
        subexpression and pass the result of each function application
        to the 'combinator' for aggregation.  The combinator must have
        the same signature as the constructor.  The function is not
        applied to bound variables, but they are passed to the
        combinator.
        :param function: ``Function`` to call on each subexpression
        :param combinator: ``Function`` with the same signature as the
        constructor, to combine the results of the function calls
        :return: result of combination
        """
        return self.visit(function, lambda parts: combinator(*parts))

    def __repr__(self):
        """Return ``<ClassName str(self)>``."""
        return '<%s %s>' % (self.__class__.__name__, self)

    def __str__(self):
        """Delegate to ``self.str()``.  ``Expression`` itself does not define
        ``str()``; the subclasses visible here override ``__str__`` directly."""
        return self.str()

    def variables(self):
        """
        Return a set of all the variables for binding substitution.
        The variables returned include all free (non-bound) individual
        variables and any variable starting with '?' or '@'.
        :return: set of ``Variable`` objects
        """
        return self.free() | set(
            p for p in self.predicates() | self.constants() if re.match('^[?@]', p.name)
        )

    def free(self):
        """
        Return a set of all the free (non-bound) variables.  This includes
        both individual and predicate variables, but not constants.
        :return: set of ``Variable`` objects
        """
        return self.visit(
            lambda e: e.free(), lambda parts: reduce(operator.or_, parts, set())
        )

    def constants(self):
        """
        Return a set of individual constants (non-predicates).
        :return: set of ``Variable`` objects
        """
        return self.visit(
            lambda e: e.constants(), lambda parts: reduce(operator.or_, parts, set())
        )

    def predicates(self):
        """
        Return a set of predicates (constants, not variables).
        :return: set of ``Variable`` objects
        """
        return self.visit(
            lambda e: e.predicates(), lambda parts: reduce(operator.or_, parts, set())
        )

    def simplify(self):
        """
        :return: beta-converted version of this expression
        """
        return self.visit_structured(lambda e: e.simplify(), self.__class__)

    def make_VariableExpression(self, variable):
        """Hook used by ``substitute_bindings`` to wrap a ``Variable``;
        subclasses may override it to build their own variable classes."""
        return VariableExpression(variable)
1217
1218
@python_2_unicode_compatible
class ApplicationExpression(Expression):
    r"""
    This class is used to represent two related types of logical expressions.

    The first is a Predicate Expression, such as "P(x,y)".  A predicate
    expression is comprised of a ``FunctionVariableExpression`` or
    ``ConstantExpression`` as the predicate and a list of Expressions as the
    arguments.

    The second is an application of one expression to another, such as
    "(\x.dog(x))(fido)".

    The reason Predicate Expressions are treated as Application Expressions is
    that the Variable Expression predicate of the expression may be replaced
    with another Expression, such as a LambdaExpression, which would mean that
    the Predicate should be thought of as being applied to the arguments.

    The logical expression reader will always curry arguments in an application expression.
    So, "\x y.see(x,y)(john,mary)" will be represented internally as
    "((\x y.(see(x))(y))(john))(mary)".  This simplifies the internals since
    there will always be exactly one argument in an application.

    The str() method will usually print the curried forms of application
    expressions.  The one exception is when the application expression is
    really a predicate expression (ie, underlying function is an
    ``AbstractVariableExpression``).  This means that the example from above
    will be returned as "(\x y.see(x,y)(john))(mary)".
    """

    def __init__(self, function, argument):
        """
        :param function: ``Expression``, for the function expression
        :param argument: ``Expression``, for the argument
        """
        assert isinstance(function, Expression), "%s is not an Expression" % function
        assert isinstance(argument, Expression), "%s is not an Expression" % argument
        self.function = function
        self.argument = argument

    def simplify(self):
        """Beta-reduce: if the simplified function is a lambda, substitute the
        simplified argument into its body and keep simplifying."""
        function = self.function.simplify()
        argument = self.argument.simplify()
        if isinstance(function, LambdaExpression):
            return function.term.replace(function.variable, argument).simplify()
        else:
            return self.__class__(function, argument)

    @property
    def type(self):
        # The result type of a function type <a,b> is b; anything else is
        # not yet known.
        if isinstance(self.function.type, ComplexType):
            return self.function.type.second
        else:
            return ANY_TYPE

    def _set_type(self, other_type=ANY_TYPE, signature=None):
        """:see Expression._set_type()"""
        assert isinstance(other_type, Type)

        if signature is None:
            signature = defaultdict(list)

        # Type the argument first so the function can be required to have
        # type <argument type, other_type>.
        self.argument._set_type(ANY_TYPE, signature)
        try:
            self.function._set_type(
                ComplexType(self.argument.type, other_type), signature
            )
        except TypeResolutionException:
            # Re-raise with a message that names both function and argument.
            raise TypeException(
                "The function '%s' is of type '%s' and cannot be applied "
                "to '%s' of type '%s'.  Its argument must match type '%s'."
                % (
                    self.function,
                    self.function.type,
                    self.argument,
                    self.argument.type,
                    self.function.type.first,
                )
            )

    def findtype(self, variable):
        """:see Expression.findtype()

        Returns the single informative type found for ``variable`` among the
        function and its argument(s).  If there is none, or there are several
        that do not match one another, the type cannot be determined here
        and ``ANY_TYPE`` is returned.
        """
        assert isinstance(variable, Variable), "%s is not a Variable" % variable
        if self.is_atom():
            function, args = self.uncurry()
        else:
            # It's not a predicate expression ("P(x,y)"), so leave args curried
            function = self.function
            args = [self.argument]

        found = [arg.findtype(variable) for arg in [function] + args]

        # Collect the distinct informative types.  A type that matches one
        # already collected adds nothing.  A type that matches none is a
        # conflict and must be kept so that the ambiguity is detected below.
        unique = []
        for f in found:
            if f != ANY_TYPE and not any(f.matches(u) for u in unique):
                unique.append(f)

        if len(unique) == 1:
            return unique[0]
        else:
            return ANY_TYPE

    def constants(self):
        """:see: Expression.constants()"""
        # A variable in function position is a predicate, not an individual
        # constant.
        if isinstance(self.function, AbstractVariableExpression):
            function_constants = set()
        else:
            function_constants = self.function.constants()
        return function_constants | self.argument.constants()

    def predicates(self):
        """:see: Expression.predicates()"""
        if isinstance(self.function, ConstantExpression):
            function_preds = set([self.function.variable])
        else:
            function_preds = self.function.predicates()
        return function_preds | self.argument.predicates()

    def visit(self, function, combinator):
        """:see: Expression.visit()"""
        return combinator([function(self.function), function(self.argument)])

    def __eq__(self, other):
        return (
            isinstance(other, ApplicationExpression)
            and self.function == other.function
            and self.argument == other.argument
        )

    def __ne__(self, other):
        return not self == other

    __hash__ = Expression.__hash__

    def __str__(self):
        # uncurry the arguments and find the base function
        if self.is_atom():
            function, args = self.uncurry()
            arg_str = ','.join("%s" % arg for arg in args)
        else:
            # Leave arguments curried
            function = self.function
            arg_str = "%s" % self.argument

        function_str = "%s" % function
        parenthesize_function = False
        if isinstance(function, LambdaExpression):
            if isinstance(function.term, ApplicationExpression):
                if not isinstance(function.term.function, AbstractVariableExpression):
                    parenthesize_function = True
            elif not isinstance(function.term, BooleanExpression):
                parenthesize_function = True
        elif isinstance(function, ApplicationExpression):
            parenthesize_function = True

        if parenthesize_function:
            function_str = Tokens.OPEN + function_str + Tokens.CLOSE

        return function_str + Tokens.OPEN + arg_str + Tokens.CLOSE

    def uncurry(self):
        """
        Uncurry this application expression

        return: A tuple (base-function, arg-list)
        """
        function = self.function
        args = [self.argument]
        while isinstance(function, ApplicationExpression):
            # (\x.\y.sees(x,y)(john))(mary)
            args.insert(0, function.argument)
            function = function.function
        return (function, args)

    @property
    def pred(self):
        """
        Return uncurried base-function.
        If this is an atom, then the result will be a variable expression.
        Otherwise, it will be a lambda expression.
        """
        return self.uncurry()[0]

    @property
    def args(self):
        """
        Return uncurried arg-list
        """
        return self.uncurry()[1]

    def is_atom(self):
        """
        Is this expression an atom (as opposed to a lambda expression applied
        to a term)?
        """
        return isinstance(self.pred, AbstractVariableExpression)
1420
1421
@total_ordering
@python_2_unicode_compatible
class AbstractVariableExpression(Expression):
    """This class represents a variable to be used as a predicate or entity"""

    def __init__(self, variable):
        """
        :param variable: ``Variable``, for the variable
        """
        assert isinstance(variable, Variable), "%s is not a Variable" % variable
        self.variable = variable

    def simplify(self):
        """A bare variable has nothing to beta-reduce; return it unchanged."""
        return self

    def replace(self, variable, expression, replace_bound=False, alpha_convert=True):
        """:see: Expression.replace()"""
        assert isinstance(variable, Variable), "%s is not an Variable" % variable
        assert isinstance(expression, Expression), (
            "%s is not an Expression" % expression
        )
        # A leaf is either exactly the target (swap it out) or left untouched.
        if self.variable == variable:
            return expression
        else:
            return self

    def _set_type(self, other_type=ANY_TYPE, signature=None):
        """:see Expression._set_type()

        Resolve ``other_type`` against the type of every expression already
        recorded in ``signature`` under this variable's name, record this
        expression there as well, and then give all of them the resolved
        type.

        :raise InconsistentTypeHierarchyException: if the types cannot be
            resolved with one another
        """
        assert isinstance(other_type, Type)

        if signature is None:
            signature = defaultdict(list)

        resolution = other_type
        for varEx in signature[self.variable.name]:
            resolution = varEx.type.resolve(resolution)
            # A falsy result means the two types could not be unified.
            if not resolution:
                raise InconsistentTypeHierarchyException(self)

        signature[self.variable.name].append(self)
        # Propagate so every occurrence of this name shares a single type.
        for varEx in signature[self.variable.name]:
            varEx.type = resolution

    def findtype(self, variable):
        """:see Expression.findtype()"""
        assert isinstance(variable, Variable), "%s is not a Variable" % variable
        if self.variable == variable:
            return self.type
        else:
            return ANY_TYPE

    def predicates(self):
        """:see: Expression.predicates()"""
        return set()

    def __eq__(self, other):
        """Allow equality between instances of ``AbstractVariableExpression``
        subtypes."""
        return (
            isinstance(other, AbstractVariableExpression)
            and self.variable == other.variable
        )

    def __ne__(self, other):
        return not self == other

    def __lt__(self, other):
        """Order variable expressions by their ``Variable``.

        NOTE(review): ``total_ordering`` only fills in comparison methods the
        class does not already have.  ``__gt__`` is inherited from
        ``Expression``, where it builds an implication, so presumably
        ``a > b`` on variable expressions does not compare them.  Confirm
        before relying on ``>``.
        """
        if not isinstance(other, AbstractVariableExpression):
            raise TypeError
        return self.variable < other.variable

    __hash__ = Expression.__hash__

    def __str__(self):
        """Print the underlying variable."""
        return "%s" % self.variable
1497
1498
class IndividualVariableExpression(AbstractVariableExpression):
    """This class represents variables that take the form of a single lowercase
    character (other than 'e') followed by zero or more digits."""

    def _set_type(self, other_type=ANY_TYPE, signature=None):
        """:see Expression._set_type()

        Individual variables always have entity type.  This only checks that
        ``other_type`` is compatible with ``ENTITY_TYPE`` and records this
        expression in ``signature``.

        :raise IllegalTypeException: if ``other_type`` does not match
            ``ENTITY_TYPE``
        """
        assert isinstance(other_type, Type)

        if signature is None:
            signature = defaultdict(list)

        if not other_type.matches(ENTITY_TYPE):
            raise IllegalTypeException(self, other_type, ENTITY_TYPE)

        signature[self.variable.name].append(self)

    def _get_type(self):
        return ENTITY_TYPE

    # Reading ``type`` always yields ENTITY_TYPE.  Assigning ``expr.type = t``
    # calls ``_set_type(t)`` with a fresh, discarded signature, so it only
    # validates ``t`` against ENTITY_TYPE and stores nothing.
    type = property(_get_type, _set_type)

    def free(self):
        """:see: Expression.free()"""
        return set([self.variable])

    def constants(self):
        """:see: Expression.constants()"""
        return set()
1527
1528
class FunctionVariableExpression(AbstractVariableExpression):
    """Variable expression for a predicate/function variable: a single
    uppercase character followed by zero or more digits."""

    # Untyped until type checking assigns a resolved type to the instance.
    type = ANY_TYPE

    def free(self):
        """:see: Expression.free()"""
        # The variable itself is the only free variable here.
        return {self.variable}

    def constants(self):
        """:see: Expression.constants()"""
        return set()
1542
1543
class EventVariableExpression(IndividualVariableExpression):
    """This class represents variables that take the form of a single lowercase
    'e' character followed by zero or more digits."""

    # A plain class attribute that replaces the inherited ``type`` property.
    # Reads give EVENT_TYPE, and assigning ``expr.type`` stores an instance
    # attribute instead of going through ``_set_type``.
    # NOTE(review): ``_set_type`` is inherited unchanged and checks
    # ``other_type`` against ENTITY_TYPE, not EVENT_TYPE.  Presumably a
    # request for EVENT_TYPE is rejected there; confirm this is intended.
    type = EVENT_TYPE
1549
1550
class ConstantExpression(AbstractVariableExpression):
    """This class represents variables that do not take the form of a single
    character followed by zero or more digits."""

    # Default type for constants (individuals).  Type checking may store a
    # different type on the instance.
    type = ENTITY_TYPE

    def _set_type(self, other_type=ANY_TYPE, signature=None):
        """:see Expression._set_type()

        With no constraint (``ANY_TYPE``) a constant is taken to be an
        individual.  Otherwise the requested type is resolved against any
        non-entity type already stored on this expression, and against every
        expression already recorded under the same name in ``signature``.
        All of them are then given the resolved type.

        :raise InconsistentTypeHierarchyException: if the types cannot be
            resolved with one another
        """
        assert isinstance(other_type, Type)

        if signature is None:
            signature = defaultdict(list)

        if other_type == ANY_TYPE:
            # entity type by default, for individuals
            resolution = ENTITY_TYPE
        else:
            resolution = other_type
            if self.type != ENTITY_TYPE:
                resolution = resolution.resolve(self.type)
                # A failed resolution is falsy.  Without this check it would
                # be stored as the type whenever no other occurrence is
                # recorded yet, and would only fail later, obscurely.
                if not resolution:
                    raise InconsistentTypeHierarchyException(self)

        for varEx in signature[self.variable.name]:
            resolution = varEx.type.resolve(resolution)
            if not resolution:
                raise InconsistentTypeHierarchyException(self)

        signature[self.variable.name].append(self)
        # Propagate so every occurrence of this name shares a single type.
        for varEx in signature[self.variable.name]:
            varEx.type = resolution

    def free(self):
        """:see: Expression.free()"""
        return set()

    def constants(self):
        """:see: Expression.constants()"""
        return set([self.variable])
1588
1589
def VariableExpression(variable):
    """
    Factory: wrap ``variable`` in the ``AbstractVariableExpression`` subclass
    that fits the shape of its name.

    The name is tested with ``is_indvar``, ``is_funcvar`` and ``is_eventvar``,
    in that order.  The first predicate that accepts it selects the class.
    A name accepted by none of them becomes a ``ConstantExpression``.
    """
    assert isinstance(variable, Variable), "%s is not a Variable" % variable
    name = variable.name
    candidates = (
        (is_indvar, IndividualVariableExpression),
        (is_funcvar, FunctionVariableExpression),
        (is_eventvar, EventVariableExpression),
    )
    for accepts, expression_class in candidates:
        if accepts(name):
            return expression_class(variable)
    return ConstantExpression(variable)
1604
1605
class VariableBinderExpression(Expression):
    """This is an abstract class for any Expression that binds a variable in an
    Expression.  This includes LambdaExpressions and Quantified Expressions"""

    def __init__(self, variable, term):
        """
        :param variable: ``Variable``, for the variable
        :param term: ``Expression``, for the term
        """
        assert isinstance(variable, Variable), "%s is not a Variable" % variable
        assert isinstance(term, Expression), "%s is not an Expression" % term
        self.variable = variable
        self.term = term

    def replace(self, variable, expression, replace_bound=False, alpha_convert=True):
        """:see: Expression.replace()"""
        assert isinstance(variable, Variable), "%s is not a Variable" % variable
        assert isinstance(expression, Expression), (
            "%s is not an Expression" % expression
        )
        # if the bound variable is the thing being replaced
        if self.variable == variable:
            if replace_bound:
                assert isinstance(expression, AbstractVariableExpression), (
                    "%s is not a AbstractVariableExpression" % expression
                )
                return self.__class__(
                    expression.variable,
                    self.term.replace(variable, expression, True, alpha_convert),
                )
            else:
                # Occurrences inside the term refer to this binder, not to
                # the outer ``variable``, so nothing changes.
                return self
        else:
            # If the bound variable appears free in the expression, it must
            # be alpha-converted first, otherwise it would capture it.
            binder = self
            if alpha_convert and self.variable in expression.free():
                binder = self.alpha_convert(unique_variable(pattern=self.variable))

            # replace in the term
            return self.__class__(
                binder.variable,
                binder.term.replace(variable, expression, replace_bound, alpha_convert),
            )

    def alpha_convert(self, newvar):
        """Rename all occurrences of the variable introduced by this variable
        binder in the expression to ``newvar``.
        :param newvar: ``Variable``, for the new variable
        """
        assert isinstance(newvar, Variable), "%s is not a Variable" % newvar
        return self.__class__(
            newvar, self.term.replace(self.variable, VariableExpression(newvar), True)
        )

    def free(self):
        """:see: Expression.free()"""
        return self.term.free() - set([self.variable])

    def findtype(self, variable):
        """:see Expression.findtype()"""
        assert isinstance(variable, Variable), "%s is not a Variable" % variable
        # Inside this binder the name refers to the bound variable, which
        # says nothing about the outer variable of the same name.
        if variable == self.variable:
            return ANY_TYPE
        else:
            return self.term.findtype(variable)

    def visit(self, function, combinator):
        """:see: Expression.visit()"""
        return combinator([function(self.term)])

    def visit_structured(self, function, combinator):
        """:see: Expression.visit_structured()"""
        return combinator(self.variable, function(self.term))

    def __eq__(self, other):
        r"""Defines equality modulo alphabetic variance.  If we are comparing
        \x.M  and \y.N, then check equality of M and N[x/y]."""
        if not (isinstance(self, other.__class__) or isinstance(other, self.__class__)):
            return False
        if self.variable == other.variable:
            return self.term == other.term
        # Comparing \x.M and \y.N.  If x occurs free in N, relabelling y as x
        # would capture it (e.g. \x.P(x,x) vs \y.P(y,x)).  The two cannot be
        # alphabetic variants, since x is bound in one but free in the other.
        if self.variable in other.term.free():
            return False
        # Otherwise relabel y in N with x and continue.
        varex = VariableExpression(self.variable)
        return self.term == other.term.replace(other.variable, varex)

    def __ne__(self, other):
        return not self == other

    # NOTE(review): Expression.__hash__ hashes repr(), so alphabetic variants
    # such as \x.P(x) and \y.P(y) compare equal but generally hash
    # differently.  Do not rely on them collapsing as set members or dict
    # keys.
    __hash__ = Expression.__hash__
1697
1698
@python_2_unicode_compatible
class LambdaExpression(VariableBinderExpression):
    """Lambda abstraction ``\\x.M`` binding ``variable`` in ``term``."""

    @property
    def type(self):
        # <type of the bound variable as used in the body, type of the body>
        argument_type = self.term.findtype(self.variable)
        return ComplexType(argument_type, self.term.type)

    def _set_type(self, other_type=ANY_TYPE, signature=None):
        """:see Expression._set_type()"""
        assert isinstance(other_type, Type)

        sig = defaultdict(list) if signature is None else signature

        # The body must have the result half of the requested function type.
        self.term._set_type(other_type.second, sig)
        if not self.type.resolve(other_type):
            raise TypeResolutionException(self, other_type)

    def __str__(self):
        # Collapse directly nested lambdas into a single binder list.
        bound = []
        body = self
        while body.__class__ == self.__class__:
            bound.append("%s" % body.variable)
            body = body.term
        head = Tokens.LAMBDA + ' '.join(bound)
        return head + Tokens.DOT + "%s" % body
1728
1729
@python_2_unicode_compatible
class QuantifiedExpression(VariableBinderExpression):
    """Common base for quantified formulas such as ``exists x.M``.

    Subclasses supply the quantifier keyword through ``getQuantifier()``.
    """

    @property
    def type(self):
        # A quantified formula is always a truth value.
        return TRUTH_TYPE

    def _set_type(self, other_type=ANY_TYPE, signature=None):
        """:see Expression._set_type()"""
        assert isinstance(other_type, Type)

        signature = defaultdict(list) if signature is None else signature

        if other_type.matches(TRUTH_TYPE):
            # The scope of the quantifier must itself be a truth value.
            self.term._set_type(TRUTH_TYPE, signature)
            return
        raise IllegalTypeException(self, other_type, TRUTH_TYPE)

    def __str__(self):
        # Directly nested quantifiers of the same class print together:
        # "exists x.exists y.M" is shown as "exists x y.M".
        bound = [self.variable]
        scope = self.term
        while scope.__class__ == self.__class__:
            bound.append(scope.variable)
            scope = scope.term
        prefix = self.getQuantifier() + ' ' + ' '.join(["%s" % v for v in bound])
        return prefix + Tokens.DOT + "%s" % scope
1760
1761
class ExistsExpression(QuantifiedExpression):
    """Existential quantification, printed with the ``Tokens.EXISTS`` keyword."""

    def getQuantifier(self):
        keyword = Tokens.EXISTS
        return keyword
1765
1766
class AllExpression(QuantifiedExpression):
    """Universal quantification, printed with the ``Tokens.ALL`` keyword."""

    def getQuantifier(self):
        keyword = Tokens.ALL
        return keyword
1770
1771
@python_2_unicode_compatible
class NegatedExpression(Expression):
    """Logical negation ``-M`` of a single truth-valued term."""

    def __init__(self, term):
        assert isinstance(term, Expression), "%s is not an Expression" % term
        self.term = term

    @property
    def type(self):
        return TRUTH_TYPE

    def _set_type(self, other_type=ANY_TYPE, signature=None):
        """:see Expression._set_type()"""
        assert isinstance(other_type, Type)

        signature = defaultdict(list) if signature is None else signature

        # A negation is a truth value, and so is its operand.
        if other_type.matches(TRUTH_TYPE):
            self.term._set_type(TRUTH_TYPE, signature)
            return
        raise IllegalTypeException(self, other_type, TRUTH_TYPE)

    def findtype(self, variable):
        assert isinstance(variable, Variable), "%s is not a Variable" % variable
        # Negation binds nothing, so the question passes straight through.
        inner = self.term
        return inner.findtype(variable)

    def visit(self, function, combinator):
        """:see: Expression.visit()"""
        mapped_children = [function(self.term)]
        return combinator(mapped_children)

    def negate(self):
        """:see: Expression.negate()"""
        # Double negation is removed rather than stacked.
        return self.term

    def __eq__(self, other):
        if not isinstance(other, NegatedExpression):
            return False
        return self.term == other.term

    def __ne__(self, other):
        equal = self == other
        return not equal

    # Restore hashing, which defining ``__eq__`` disables on Python 3.
    __hash__ = Expression.__hash__

    def __str__(self):
        body = "%s" % self.term
        return Tokens.NOT + body
1815
1816
@python_2_unicode_compatible
class BinaryExpression(Expression):
    """Base for two-operand expressions printed as ``(first OP second)``.

    Subclasses provide the operator token through ``getOp()``.
    """

    def __init__(self, first, second):
        assert isinstance(first, Expression), "%s is not an Expression" % first
        assert isinstance(second, Expression), "%s is not an Expression" % second
        self.first = first
        self.second = second

    @property
    def type(self):
        return TRUTH_TYPE

    def findtype(self, variable):
        """:see Expression.findtype()"""
        assert isinstance(variable, Variable), "%s is not a Variable" % variable
        left_type = self.first.findtype(variable)
        right_type = self.second.findtype(variable)
        if left_type == right_type or right_type == ANY_TYPE:
            return left_type
        if left_type == ANY_TYPE:
            return right_type
        # The operands disagree about the variable's type.
        return ANY_TYPE

    def visit(self, function, combinator):
        """:see: Expression.visit()"""
        mapped_children = [function(self.first), function(self.second)]
        return combinator(mapped_children)

    def __eq__(self, other):
        # Either side may be a subclass of the other's class.
        related = isinstance(self, other.__class__) or isinstance(
            other, self.__class__
        )
        return related and self.first == other.first and self.second == other.second

    def __ne__(self, other):
        equal = self == other
        return not equal

    # Restore hashing, which defining ``__eq__`` disables on Python 3.
    __hash__ = Expression.__hash__

    def __str__(self):
        left = self._str_subex(self.first)
        right = self._str_subex(self.second)
        return Tokens.OPEN + ' '.join([left, self.getOp(), right]) + Tokens.CLOSE

    def _str_subex(self, subex):
        # Hook for subclasses that want to print an operand differently.
        return "%s" % subex
1864
1865
class BooleanExpression(BinaryExpression):
    """A binary connective whose operands and result are all truth values."""

    def _set_type(self, other_type=ANY_TYPE, signature=None):
        """:see Expression._set_type()"""
        assert isinstance(other_type, Type)

        signature = defaultdict(list) if signature is None else signature

        if not other_type.matches(TRUTH_TYPE):
            raise IllegalTypeException(self, other_type, TRUTH_TYPE)
        for operand in (self.first, self.second):
            operand._set_type(TRUTH_TYPE, signature)
1878
1879
class AndExpression(BooleanExpression):
    """A conjunction ``(first & second)``."""

    def getOp(self):
        return Tokens.AND

    def _str_subex(self, subex):
        text = "%s" % subex
        # A nested conjunction prints wrapped in parentheses; strip them so
        # chains render flat, e.g. (a & b & c).
        return text[1:-1] if isinstance(subex, AndExpression) else text
1891
1892
class OrExpression(BooleanExpression):
    """A disjunction ``(first | second)``."""

    def getOp(self):
        return Tokens.OR

    def _str_subex(self, subex):
        text = "%s" % subex
        # A nested disjunction prints wrapped in parentheses; strip them so
        # chains render flat, e.g. (a | b | c).
        return text[1:-1] if isinstance(subex, OrExpression) else text
1904
1905
class ImpExpression(BooleanExpression):
    """An implication ``(first -> second)``."""

    def getOp(self):
        operator_token = Tokens.IMP
        return operator_token
1911
1912
class IffExpression(BooleanExpression):
    """A biconditional ``(first <-> second)``."""

    def getOp(self):
        operator_token = Tokens.IFF
        return operator_token
1918
1919
class EqualityExpression(BinaryExpression):
    """An equality such as ``(x = y)``: entity-typed operands, truth-valued result."""

    def _set_type(self, other_type=ANY_TYPE, signature=None):
        """:see Expression._set_type()"""
        assert isinstance(other_type, Type)

        signature = defaultdict(list) if signature is None else signature

        if not other_type.matches(TRUTH_TYPE):
            raise IllegalTypeException(self, other_type, TRUTH_TYPE)
        # Both sides of '=' denote entities.
        for operand in (self.first, self.second):
            operand._set_type(ENTITY_TYPE, signature)

    def getOp(self):
        return Tokens.EQ
1937
1938
1939### Utilities
1940
1941
class LogicalExpressionException(Exception):
    """Base class for errors raised while reading a logical expression.

    :param index: index supplied by the raiser (presumably a position in the
        input being read -- confirm with callers); stored as ``self.index``
    :param message: human-readable description, used as the exception text
    """

    def __init__(self, index, message):
        self.index = index
        Exception.__init__(self, message)


class UnexpectedTokenException(LogicalExpressionException):
    """Raised when the reader meets a token it did not expect.

    :param index: see ``LogicalExpressionException``
    :param unexpected: the offending token, if known
    :param expected: the token (or collection of tokens) that was expected
    :param message: extra explanation, appended to the generated text
    """

    def __init__(self, index, unexpected=None, expected=None, message=None):
        # Values are wrapped in 1-tuples so that a tuple-valued token
        # cannot be mistaken for a tuple of format arguments.
        if unexpected and expected:
            msg = "Unexpected token: '%s'.  Expected token '%s'." % (
                unexpected,
                expected,
            )
        elif unexpected:
            msg = "Unexpected token: '%s'." % (unexpected,)
        elif expected or not message:
            msg = "Expected token '%s'." % (expected,)
        else:
            # Neither token given: the caller's message says it all.
            msg = ''
        # Previously ``message`` was dropped unless only ``unexpected`` was
        # given; now it is appended in every case.
        if message:
            msg = msg + '  ' + message if msg else message
        LogicalExpressionException.__init__(self, index, msg)


class ExpectedMoreTokensException(LogicalExpressionException):
    """Raised when the input ends while more tokens are still required.

    :param index: see ``LogicalExpressionException``
    :param message: what was still expected; defaults to a generic note
    """

    def __init__(self, index, message=None):
        if not message:
            message = 'More tokens expected.'
        LogicalExpressionException.__init__(
            self, index, 'End of input found.  ' + message
        )
1971
1972
def is_indvar(expr):
    """
    An individual variable must be a single lowercase character other than 'e',
    followed by zero or more digits.

    :param expr: str
    :return: bool True if expr is of the correct form
    """
    assert isinstance(expr, string_types), "%s is not a string" % expr
    # Anchor with ``\Z``, not ``$``: ``$`` also matches just before a trailing
    # newline, which would wrongly accept e.g. 'x\n'.
    return re.match(r'^[a-df-z]\d*\Z', expr) is not None
1983
1984
def is_funcvar(expr):
    """
    A function variable must be a single uppercase character followed by
    zero or more digits.

    :param expr: str
    :return: bool True if expr is of the correct form
    """
    assert isinstance(expr, string_types), "%s is not a string" % expr
    # Anchor with ``\Z``, not ``$``: ``$`` also matches just before a trailing
    # newline, which would wrongly accept e.g. 'P\n'.
    return re.match(r'^[A-Z]\d*\Z', expr) is not None
1995
1996
def is_eventvar(expr):
    """
    An event variable must be a single lowercase 'e' character followed by
    zero or more digits.

    :param expr: str
    :return: bool True if expr is of the correct form
    """
    assert isinstance(expr, string_types), "%s is not a string" % expr
    # Anchor with ``\Z``, not ``$``: ``$`` also matches just before a trailing
    # newline, which would wrongly accept e.g. 'e1\n'.
    return re.match(r'^e\d*\Z', expr) is not None
2007
2008
def demo():
    """Print sample parses, beta-reductions and an alpha-conversion check."""
    lexpr = Expression.fromstring
    rule = '=' * 20

    print(rule + 'Test reader' + rule)
    reader_samples = [
        r'john',
        r'man(x)',
        r'-man(x)',
        r'(man(x) & tall(x) & walks(x))',
        r'exists x.(man(x) & tall(x) & walks(x))',
        r'\x.man(x)',
        r'\x.man(x)(john)',
        r'\x y.sees(x,y)',
        r'\x y.sees(x,y)(a,b)',
        r'(\x.exists y.walks(x,y))(x)',
        r'exists x.x = y',
        r'exists x.(x = y)',
        'P(x) & x=y & P(y)',
        r'\P Q.exists x.(P(x) & Q(x))',
        r'man(x) <-> tall(x)',
    ]
    for sample in reader_samples:
        print(lexpr(sample))

    print(rule + 'Test simplify' + rule)
    simplify_samples = [
        r'\x.\y.sees(x,y)(john)(mary)',
        r'\x.\y.sees(x,y)(john, mary)',
        r'all x.(man(x) & (\x.exists y.walks(x,y))(x))',
        r'(\P.\Q.exists x.(P(x) & Q(x)))(\x.dog(x))(\x.bark(x))',
    ]
    for sample in simplify_samples:
        print(lexpr(sample).simplify())

    print(rule + 'Test alpha conversion and binder expression equality' + rule)
    original = lexpr('exists x.P(x)')
    print(original)
    renamed = original.alpha_convert(Variable('z'))
    print(renamed)
    print(original == renamed)
2040
2041
def demo_errors():
    """Feed malformed inputs to the reader and print each resulting error."""
    print('=' * 20 + 'Test reader errors' + '=' * 20)
    malformed_inputs = (
        '(P(x) & Q(x)',
        '((P(x) &) & Q(x))',
        'P(x) -> ',
        'P(x',
        'P(x,',
        'P(x,)',
        'exists',
        'exists x.',
        '\\',
        '\\ x y.',
        'P(x)Q(x)',
        '(P(x)Q(x)',
        'exists x -> y',
    )
    for bad_input in malformed_inputs:
        demoException(bad_input)
2057
2058
def demoException(s):
    """Parse ``s``; if the reader rejects it, print the error class and text."""
    try:
        Expression.fromstring(s)
    except LogicalExpressionException as err:
        error_name = type(err).__name__
        print("%s: %s" % (error_name, err))
2064
2065
def printtype(ex):
    """
    Print an expression followed by its type, as ``<expression> : <type>``.

    :param ex: the expression to print; its ``type`` attribute is shown
    """
    # Format through ``%s`` (the object's ``__str__``).  The previous
    # ``ex.str()`` relied on a ``str()`` method that the Expression classes
    # shown here do not define, so it raised AttributeError.
    print("%s : %s" % (ex, ex.type))
2068
2069
if __name__ == '__main__':
    # Only the reader/simplifier demo runs by default; call demo_errors()
    # as well to see the reader's error messages.
    demo()
2073