1###################################
2# Parse an !assert directive      #
3# By Scott Pakin <pakin@lanl.gov> #
4###################################
5
6import qmasm
7import re
8import sys
9
class AST(object):
    "Represent an abstract syntax tree."

    def __init__(self, qmasm_obj, type, value, kids=None):
        """Construct an AST node.

        qmasm_obj -- object whose abend() and apply_prefix() methods are
                     called by this class
        type      -- node type as a string ("ident", "int", "unary",
                     "factor", "power", "term", "expr", "rel", "conn",
                     or "if_expr")
        value     -- node payload: an operator string, an identifier
                     name, an integer, or None
        kids      -- list of child AST nodes (default: no children)
        """
        self.qmasm = qmasm_obj
        self.type = type
        self.value = value
        # Give each childless node its own list.  A literal [] default
        # would be one list object shared by every such node.
        self.kids = [] if kids is None else kids
        self.code = lambda isb: self.qmasm.abend("Internal error: Attempt to evaluate an AST without compiling it first")     # Function that evaluates the AST given a mapping from identifiers to bits
        self._str = None   # Memoized string representation
        self.pin_parser = qmasm.parse.PinParser()

    def _needs_parens(self):
        "Return True if an AST node should be parenthesized."
        return self.type == "factor" and self.kids[0].type == "conn"

    def _str_helper(self):
        "Do most of the work for the __str__ method."
        # Stringify every child, then conditionally parenthesize.
        plain = [str(k) for k in self.kids]
        shown = []
        for kid, text in zip(self.kids, plain):
            if kid._needs_parens():
                text = "(" + text + ")"
            shown.append(text)
        nkids = len(shown)

        # Return ourself as a string.
        if nkids == 0:
            return str(self.value)
        if nkids == 1:
            if self.type == "unary" and self.value != "id":
                return "%s%s" % (self.value, shown[0])
            return shown[0]
        if nkids == 2:
            # Tightly binding operators are written without spaces.
            if self.value in ["*", "/", "%", "&", "<<", ">>", "**"]:
                return "%s%s%s" % (shown[0], self.value, shown[1])
            return "%s %s %s" % (shown[0], self.value, shown[1])
        if nkids == 3 and self.type == "if_expr":
            return "if %s then %s else %s endif" % (plain[0], plain[1], plain[2])
        raise Exception("Internal error parsing (%s, %s)" % (repr(self.type), repr(self.value)))

    def __str__(self):
        if self._str is None:
            self._str = self._str_helper()
        return self._str

    def prefix_identifiers(self, prefix, next_prefix):
        "Prefix every identifier with a given string."
        # The identifiers below this node may change, so any string
        # memoized for this node is no longer trustworthy.
        self._str = None
        if self.type == "ident":
            self.value = self.qmasm.apply_prefix(self.value, prefix, next_prefix)
        else:
            for k in self.kids:
                k.prefix_identifiers(prefix, next_prefix)

    def replace_ident(self, old_ident, new_ident):
        "Replace every occurrence of one identifier with another."
        # Drop the memoized string for the same reason as in
        # prefix_identifiers.
        self._str = None
        if self.type == "ident":
            if self.value == old_ident:
                self.value = new_ident
        else:
            for k in self.kids:
                k.replace_ident(old_ident, new_ident)

    class EvaluationError(Exception):
        "Represent an exception thrown during AST evaluation."
        pass

    def _evaluate_ident(self, i2b):
        """Evaluate a variable, including array variables.  The bits named
        by the pin parser are combined most-significant first.  Raise
        EvaluationError if a bit is missing from i2b or is mapped to None."""
        val = 0
        for v in self.pin_parser.parse_lhs(self.value):
            try:
                bit = i2b[v]
            except KeyError:
                raise self.EvaluationError("Undefined variable %s" % v)
            if bit is None:
                raise self.EvaluationError("Unused variable %s" % v)
            val = val*2 + bit
        return val

    def _compile_unary(self, kvals):
        """Compile a unary expression.  kvals holds the single compiled
        operand.  Raise EvaluationError for an unknown operator."""
        operand = kvals[0]
        if self.value == "-":
            return lambda i2b: -operand(i2b)
        if self.value == "~":
            return lambda i2b: ~operand(i2b)
        if self.value == "!":
            return lambda i2b: int(operand(i2b) == 0)
        if self.value in ["+", "id"]:
            return lambda i2b: operand(i2b)
        raise self.EvaluationError('Internal error compiling unary "%s"' % self.value)

    def _evaluate_power(self, base, exp):
        "Raise one integer to the power of another."
        if exp < 0:
            raise self.EvaluationError("Negative powers (%d) are not allowed" % exp)
        return base**exp

    def _evaluate_division(self, op, num, den):
        """Return num//den for op "/" or num%den otherwise.  Raise
        EvaluationError rather than ZeroDivisionError on a zero divisor
        so that evaluate() can report the failing assertion."""
        if den == 0:
            raise self.EvaluationError("Division by zero")
        if op == "/":
            return num // den
        return num % den

    def _evaluate_shift(self, op, val, count):
        """Return val<<count for op "<<" or val>>count otherwise.  Raise
        EvaluationError rather than ValueError on a negative count so
        that evaluate() can report the failing assertion."""
        if count < 0:
            raise self.EvaluationError("Negative shift counts (%d) are not allowed" % count)
        if op == "<<":
            return val << count
        return val >> count

    def _compile_arith(self, kvals):
        """Compile an arithmetic expression.  kvals holds the compiled
        left and right operands.  Raise EvaluationError for an unknown
        operator."""
        lhs, rhs = kvals[0], kvals[1]
        op = self.value
        if op == "+":
            return lambda i2b: lhs(i2b) + rhs(i2b)
        if op == "-":
            return lambda i2b: lhs(i2b) - rhs(i2b)
        if op == "*":
            return lambda i2b: lhs(i2b) * rhs(i2b)
        if op in ["/", "%"]:
            return lambda i2b: self._evaluate_division(op, lhs(i2b), rhs(i2b))
        if op == "&":
            return lambda i2b: lhs(i2b) & rhs(i2b)
        if op == "|":
            return lambda i2b: lhs(i2b) | rhs(i2b)
        if op == "^":
            return lambda i2b: lhs(i2b) ^ rhs(i2b)
        if op in ["<<", ">>"]:
            return lambda i2b: self._evaluate_shift(op, lhs(i2b), rhs(i2b))
        if op == "**":
            return lambda i2b: self._evaluate_power(lhs(i2b), rhs(i2b))
        raise self.EvaluationError("Internal error compiling arithmetic operator %s" % self.value)

    def _compile_rel(self, kvals):
        """Compile a relational expression.  kvals holds the compiled left
        and right operands.  Raise EvaluationError for an unknown
        operator."""
        lhs, rhs = kvals[0], kvals[1]
        if self.value == "=":
            return lambda i2b: lhs(i2b) == rhs(i2b)
        if self.value == "/=":
            return lambda i2b: lhs(i2b) != rhs(i2b)
        if self.value == "<":
            return lambda i2b: lhs(i2b) < rhs(i2b)
        if self.value == "<=":
            return lambda i2b: lhs(i2b) <= rhs(i2b)
        if self.value == ">":
            return lambda i2b: lhs(i2b) > rhs(i2b)
        if self.value == ">=":
            return lambda i2b: lhs(i2b) >= rhs(i2b)
        raise self.EvaluationError("Internal error compiling relational operator %s" % self.value)

    def _compile_conn(self, kvals):
        """Compile a logical connective.  The right operand is evaluated
        only when the left one does not already decide the result."""
        lhs, rhs = kvals[0], kvals[1]
        if self.value == "&&":
            return lambda i2b: lhs(i2b) and rhs(i2b)
        if self.value == "||":
            return lambda i2b: lhs(i2b) or rhs(i2b)
        raise self.EvaluationError("Internal error compiling logical connective %s" % self.value)

    def _evaluate_if_expr(self, i2b, kvals):
        "Evaluate the condition and then only the selected branch."
        cond, then_code, else_code = kvals[0], kvals[1], kvals[2]
        if cond(i2b):
            return then_code(i2b)
        return else_code(i2b)

    def _compile_if_expr(self, kvals):
        "Compile an if...then...else expression."
        return lambda i2b: self._evaluate_if_expr(i2b, kvals)

    def _compile_node(self):
        """Compile the AST rooted at this node to a function that returns
        the node's value given a mapping from identifiers to bits.
        Raise EvaluationError for an unrecognized node."""
        # Children are compiled first so operators can close over them.
        kvals = [k._compile_node() for k in self.kids]
        if self.type == "ident":
            # Variable
            return lambda i2b: self._evaluate_ident(i2b)
        if self.type == "int":
            # Constant
            return lambda i2b: self.value
        if self.type == "unary":
            # Unary expression
            return self._compile_unary(kvals)
        if len(kvals) == 1:
            # All other single-child nodes return their child unmodified.
            return kvals[0]
        if self.type in ["power", "term", "expr"]:
            return self._compile_arith(kvals)
        if self.type == "rel":
            return self._compile_rel(kvals)
        if self.type == "conn":
            return self._compile_conn(kvals)
        if self.type == "if_expr":
            return self._compile_if_expr(kvals)
        raise self.EvaluationError("Internal error compiling AST node of type %s, value %s" % (repr(self.type), repr(self.value)))

    def compile(self):
        "Compile an AST for faster evaluation."
        self.code = self._compile_node()

    def evaluate(self, i2b):
        """Evaluate the AST to a value, given a mapping from identifiers to
        bits.  An EvaluationError is reported through the qmasm object's
        abend() method together with the text of the assertion."""
        try:
            return self.code(i2b)
        except self.EvaluationError as e:
            self.qmasm.abend("%s in assertion %s" % (e, self))
211
class AssertParser(object):
    """Lex and parse the text of an assertion into an AST using
    recursive descent."""

    int_re = re.compile(r'\d+')
    conn_re = re.compile(r'\|\||&&')
    rel_re = re.compile(r'/?=|[<>]=?')
    arith_re = re.compile(r'[-+/%&\|^~!]|>>|<<|\*\*?')
    keyword_re = re.compile(r'\b(if|then|else|endif)\b')

    # "arith" operators that unary() is able to consume.
    _unary_ops = ("+", "-", "~", "!")

    def __init__(self, qmasm):
        # qmasm must provide ident_re (a compiled regular expression
        # matching an identifier) and an abend() method.
        self.qmasm = qmasm

    class ParseError(Exception):
        "Represent a failure to lex or parse an assertion."
        pass

    def _lex_one(self, s):
        """Recognize the token at the start of a nonempty string.  Return a
        (token, length) pair, where token is a (type, value) tuple and
        length is the number of characters consumed, or None if nothing
        matches.  The order of the tests matters: longer operators must
        be tried before their one-character prefixes."""
        # Match parentheses.
        if s[0] == "(":
            return ("lparen", "("), 1
        if s[0] == ")":
            return ("rparen", ")"), 1

        # Match keywords.  A keyword is its own token type.
        mo = self.keyword_re.match(s)
        if mo is not None:
            word = mo.group(0)
            return (word, word), len(word)

        # Match positive integers.
        mo = self.int_re.match(s)
        if mo is not None:
            digits = mo.group(0)
            return ("int", int(digits)), len(digits)

        # Match connectives.
        mo = self.conn_re.match(s)
        if mo is not None:
            return ("conn", mo.group(0)), len(mo.group(0))

        # Match "<<" and ">>" before we match "<" and ">".
        if s[:2] in ("<<", ">>"):
            return ("arith", s[:2]), 2

        # Match relational operators.
        mo = self.rel_re.match(s)
        if mo is not None:
            return ("rel", mo.group(0)), len(mo.group(0))

        # Match "**" before we match "*".
        if s[:2] == "**":
            return ("power", "**"), 2

        # Match arithmetic operators.
        mo = self.arith_re.match(s)
        if mo is not None:
            return ("arith", mo.group(0)), len(mo.group(0))

        # Everything else is an identifier.  A zero-length match is
        # treated as a failure because it would never consume input.
        mo = self.qmasm.ident_re.match(s)
        if mo is not None and mo.group(0) != "":
            return ("ident", mo.group(0)), len(mo.group(0))
        return None

    def lex(self, s):
        """Split a string into tokens (tuples of type and value).  The
        returned list always ends with an ("EOF", "EOF") token.  Raise
        ParseError if some part of the string cannot be tokenized."""
        tokens = []
        s = s.lstrip()
        while s:
            found = self._lex_one(s)
            if found is None:
                raise self.ParseError("Failed to parse %s" % s)
            token, length = found
            tokens.append(token)
            s = s[length:].lstrip()
        tokens.append(("EOF", "EOF"))
        return tokens

    def advance(self):
        "Advance to the next symbol."
        self.tokidx += 1
        self.sym = self.tokens[self.tokidx]

    def accept(self, ty):
        """Advance to the next token if the current token matches a given
        token type and return True.  Otherwise, return False."""
        if self.sym[0] != ty:
            return False
        self.advance()
        return True

    def expect(self, ty):
        """Advance to the next token if the current token matches a given
        token.  Otherwise, fail."""
        if not self.accept(ty):
            raise self.ParseError("Expected %s but saw %s" % (ty, repr(self.sym[1])))

    def generic_operator(self, return_type, child_method, sym_type, valid_ops):
        """Match one or more somethings to produce something else.

        return_type  -- type given to every AST node created here
        child_method -- parser method that returns one operand
        sym_type     -- token type of the operators that join operands
        valid_ops    -- operator values accepted at this precedence level

        A lone operand is wrapped in a single-child node whose value is
        None.  Multiple operands are merged left-associatively."""
        node = child_method()
        merged = False
        while self.sym[0] == sym_type and self.sym[1] in valid_ops:
            op = self.sym[1]
            self.advance()
            rhs = child_method()
            node = AST(self.qmasm, return_type, op, [node, rhs])
            merged = True

        # Handle the trivial case of the identity operation.
        if not merged:
            return AST(self.qmasm, return_type, None, [node])
        return node

    def if_expr(self):
        "Return an if...then...else expression."
        self.expect("if")
        # NOTE(review): the condition is parsed as a conjunction, so a
        # top-level "||" must be parenthesized -- confirm this is intended.
        cond = self.conjunction()
        self.expect("then")
        then_expr = self.expression()
        self.expect("else")
        else_expr = self.expression()
        self.expect("endif")
        return AST(self.qmasm, "if_expr", None, [cond, then_expr, else_expr])

    def factor(self):
        """Return a factor (variable, integer, parenthesized expression,
        unary expression, or if expression).  Raise ParseError if the
        current token cannot begin a factor."""
        val = self.sym[1]
        if self.accept("ident"):
            child = AST(self.qmasm, "ident", val)
        elif self.accept("int"):
            child = AST(self.qmasm, "int", val)
        elif self.accept("lparen"):
            child = self.disjunction()
            self.expect("rparen")
        elif self.sym[0] == "arith" and val in self._unary_ops:
            # Only hand off operators that unary() will consume.  For
            # any other operator, unary() would come straight back here
            # without advancing and recurse without bound.
            child = self.unary()
        elif self.sym[0] == "if":
            child = self.if_expr()
        elif val == "EOF":
            raise self.ParseError("Parse error at end of expression")
        else:
            raise self.ParseError('Parse error at "%s"' % val)
        return AST(self.qmasm, "factor", None, [child])

    def power(self):
        """Return a factor or a factor raised to the power of a second
        factor.  The recursion on the right operand makes "**"
        right-associative."""
        base = self.factor()
        op = self.sym[1]
        if self.sym[0] == "power" and op == "**":
            self.advance()
            exponent = self.power()
            return AST(self.qmasm, "power", op, [base, exponent])
        return AST(self.qmasm, "power", None, [base])

    def unary(self):
        """Return a unary operator applied to a power.  If no unary
        operator is present, the node's value is "id" (identity)."""
        op = self.sym[1]
        if self.sym[0] == "arith" and op in self._unary_ops:
            self.advance()
        else:
            op = "id"
        operand = self.power()
        return AST(self.qmasm, "unary", op, [operand])

    def term(self):
        "Return a term (product of one or more unaries)."
        return self.generic_operator("term", self.unary, "arith", ["*", "/", "%", "&", "<<", ">>"])

    def expression(self):
        "Return an expression (sum of one or more terms)."
        return self.generic_operator("expr", self.term, "arith", ["+", "-", "|", "^"])

    def comparison(self):
        """Return a comparison of two expressions or, if no relational
        operator follows the first expression, that expression alone
        wrapped in a single-child node."""
        lhs = self.expression()
        op = self.sym[1]
        if self.sym[0] != "rel":
            return AST(self.qmasm, "rel", None, [lhs])
        self.advance()
        rhs = self.expression()
        return AST(self.qmasm, "rel", op, [lhs, rhs])

    def conjunction(self):
        "Return a conjunction (logical AND of one or more comparisons)."
        return self.generic_operator("conn", self.comparison, "conn", ["&&"])

    def disjunction(self):
        "Return a disjunction (logical OR of one or more conjunctions)."
        return self.generic_operator("conn", self.conjunction, "conn", ["||"])

    def parse(self, s):
        """Parse a relational expression into an AST.  Lexing and parsing
        errors alike are reported through the qmasm object's abend()
        method together with the offending input."""
        ast = None
        try:
            # Lex inside the try block so that lexer failures are
            # reported the same way as parser failures.
            self.tokens = self.lex(s)
            self.tokidx = -1
            self.advance()
            ast = self.disjunction()
            if self.sym[0] != "EOF":
                raise self.ParseError('Parse error at "%s"' % self.sym[1])
        except self.ParseError as e:
            self.qmasm.abend('%s in "%s"' % (e, s))
        return ast
433