###################################
# Parse an !assert directive      #
# By Scott Pakin <pakin@lanl.gov> #
###################################

import operator
import qmasm
import re
import sys

class AST(object):
    "Represent an abstract syntax tree."

    def __init__(self, qmasm_obj, type, value, kids=None):
        """Create an AST node.

        qmasm_obj -- object providing abend(), apply_prefix(), and ident_re
        type      -- node type ("ident", "int", "unary", "factor", "power",
                     "term", "expr", "rel", "conn", or "if_expr")
        value     -- operator string, integer literal, identifier name, or None
        kids      -- list of child AST nodes (a fresh empty list if omitted)
        """
        self.qmasm = qmasm_obj
        self.type = type
        self.value = value
        # Give every node its own list rather than sharing a mutable default.
        self.kids = [] if kids is None else kids
        # Function that evaluates the AST given a mapping from identifiers to
        # bits.  compile() replaces it; until then, evaluating aborts.
        self.code = lambda isb: self.qmasm.abend("Internal error: Attempt to evaluate an AST without compiling it first")
        self._str = None   # Memoized string representation
        self.pin_parser = qmasm.parse.PinParser()

    def _needs_parens(self):
        "Return True if an AST node should be parenthesized."
        # Only a factor wrapping a logical connective came from an explicit
        # parenthesized subexpression in the input.
        return self.type == "factor" and self.kids[0].type == "conn"

    def _str_helper(self):
        "Do most of the work for the __str__ method."
        # Conditionally parenthesize all child strings.
        kids_str = []
        for k in self.kids:
            k_str = str(k)
            if k._needs_parens():
                k_str = "(" + k_str + ")"
            kids_str.append(k_str)
        nkids = len(self.kids)

        # Return ourself as a string.
        if nkids == 0:
            return str(self.value)
        if nkids == 1:
            if self.type == "unary" and self.value != "id":
                return "%s%s" % (self.value, kids_str[0])
            return kids_str[0]
        if nkids == 2:
            # Tightly binding operators are printed without surrounding spaces.
            if self.value in ["*", "/", "%", "&", "<<", ">>", "**"]:
                return "%s%s%s" % (kids_str[0], self.value, kids_str[1])
            else:
                return "%s %s %s" % (kids_str[0], self.value, kids_str[1])
        if nkids == 3:
            if self.type == "if_expr":
                return "if %s then %s else %s endif" % (str(self.kids[0]), str(self.kids[1]), str(self.kids[2]))
        raise Exception("Internal error parsing (%s, %s)" % (repr(self.type), repr(self.value)))

    def __str__(self):
        "Return the (memoized) textual form of the AST."
        if self._str is None:
            self._str = self._str_helper()
        return self._str

    def prefix_identifiers(self, prefix, next_prefix):
        "Prefix every identifier with a given string."
        # Identifiers are about to change, so discard any memoized string.
        self._str = None
        if self.type == "ident":
            self.value = self.qmasm.apply_prefix(self.value, prefix, next_prefix)
        else:
            for k in self.kids:
                k.prefix_identifiers(prefix, next_prefix)

    def replace_ident(self, old_ident, new_ident):
        "Replace every occurrence of one identifier with another."
        # Identifiers may change, so discard any memoized string.
        self._str = None
        if self.type == "ident":
            if self.value == old_ident:
                self.value = new_ident
        else:
            for k in self.kids:
                k.replace_ident(old_ident, new_ident)

    class EvaluationError(Exception):
        "Represent an exception thrown during AST evaluation."
        pass

    def _evaluate_ident(self, i2b):
        """Evaluate a variable, including array variables.

        The bits of every variable named by the identifier are concatenated,
        first variable most significant, into a single integer."""
        val = 0
        for v in self.pin_parser.parse_lhs(self.value):
            try:
                bit = i2b[v]
            except KeyError:
                raise self.EvaluationError("Undefined variable %s" % v)
            if bit is None:
                raise self.EvaluationError("Unused variable %s" % v)
            val = val*2 + bit
        return val

    def _compile_unary(self, kvals):
        "Compile a unary expression."
        unops = {
            "-": operator.neg,
            "~": operator.invert,
            "!": lambda x: int(x == 0),   # Logical NOT yields 0 or 1
            "+": lambda x: x,
            "id": lambda x: x,            # No operator was present
        }
        if self.value not in unops:
            raise self.EvaluationError('Internal error compiling unary "%s"' % self.value)
        fn = unops[self.value]
        arg = kvals[0]
        return lambda i2b: fn(arg(i2b))

    def _evaluate_power(self, base, exp):
        "Raise one integer to the power of another."
        if exp < 0:
            raise self.EvaluationError("Negative powers (%d) are not allowed" % exp)
        return base**exp

    def _evaluate_div(self, op, num, den):
        """Apply floor division ("/") or modulus ("%"), reporting a zero
        divisor as an evaluation error instead of crashing."""
        if den == 0:
            raise self.EvaluationError("Division by zero")
        if op == "/":
            return num // den
        return num % den

    def _evaluate_shift(self, op, val, amt):
        """Shift val left ("<<") or right (">>") by amt bits, reporting a
        negative shift amount as an evaluation error instead of crashing."""
        if amt < 0:
            raise self.EvaluationError("Negative shift amounts (%d) are not allowed" % amt)
        if op == "<<":
            return val << amt
        return val >> amt

    def _compile_arith(self, kvals):
        "Compile a binary arithmetic expression."
        binops = {
            "+": operator.add,
            "-": operator.sub,
            "*": operator.mul,
            "/": lambda a, b: self._evaluate_div("/", a, b),
            "%": lambda a, b: self._evaluate_div("%", a, b),
            "&": operator.and_,
            "|": operator.or_,
            "^": operator.xor,
            "<<": lambda a, b: self._evaluate_shift("<<", a, b),
            ">>": lambda a, b: self._evaluate_shift(">>", a, b),
            "**": self._evaluate_power,
        }
        if self.value not in binops:
            raise self.EvaluationError("Internal error compiling arithmetic operator %s" % self.value)
        fn = binops[self.value]
        lhs = kvals[0]
        rhs = kvals[1]
        # Operands are evaluated left to right.
        return lambda i2b: fn(lhs(i2b), rhs(i2b))

    def _compile_rel(self, kvals):
        "Compile a relational expression."
        relops = {
            "=": operator.eq,
            "/=": operator.ne,
            "<": operator.lt,
            "<=": operator.le,
            ">": operator.gt,
            ">=": operator.ge,
        }
        if self.value not in relops:
            raise self.EvaluationError("Internal error compiling relational operator %s" % self.value)
        fn = relops[self.value]
        lhs = kvals[0]
        rhs = kvals[1]
        return lambda i2b: fn(lhs(i2b), rhs(i2b))

    def _compile_conn(self, kvals):
        """Compile a logical connective.  Evaluation short-circuits and, like
        Python's and/or, returns the operand that decided the result."""
        if self.value == "&&":
            return lambda i2b: kvals[0](i2b) and kvals[1](i2b)
        elif self.value == "||":
            return lambda i2b: kvals[0](i2b) or kvals[1](i2b)
        else:
            raise self.EvaluationError("Internal error compiling logical connective %s" % self.value)

    def _evaluate_if_expr(self, i2b, kvals):
        "Evaluate only the branch selected by the condition."
        if kvals[0](i2b):
            return kvals[1](i2b)
        else:
            return kvals[2](i2b)

    def _compile_if_expr(self, kvals):
        "Compile an if...then...else expression."
        return lambda i2b: self._evaluate_if_expr(i2b, kvals)

    def _compile_node(self):
        """Compile the AST to a function that, given a mapping from
        identifiers to bits, returns the node's value (an int, or a bool for
        relational operators)."""
        kvals = [k._compile_node() for k in self.kids]
        if self.type == "ident":
            # Variable
            return lambda i2b: self._evaluate_ident(i2b)
        elif self.type == "int":
            # Constant
            return lambda i2b: self.value
        elif self.type == "unary":
            # Unary expression
            return self._compile_unary(kvals)
        elif len(kvals) == 1:
            # All other single-child nodes return their child unmodified.
            return kvals[0]
        elif self.type in ["power", "term", "expr"]:
            return self._compile_arith(kvals)
        elif self.type == "rel":
            return self._compile_rel(kvals)
        elif self.type == "conn":
            return self._compile_conn(kvals)
        elif self.type == "if_expr":
            return self._compile_if_expr(kvals)
        else:
            raise self.EvaluationError("Internal error compiling AST node of type %s, value %s" % (repr(self.type), repr(self.value)))

    def compile(self):
        "Compile an AST for faster evaluation."
        self.code = self._compile_node()

    def evaluate(self, i2b):
        """Evaluate the AST to a value, given a mapping from identifiers to
        bits.  Evaluation errors abort via the QMASM object's abend()."""
        try:
            return self.code(i2b)
        except self.EvaluationError as e:
            self.qmasm.abend("%s in assertion %s" % (e, self))

class AssertParser(object):
    "Recursive-descent parser that turns an assertion string into an AST."
    int_re = re.compile(r'\d+')
    conn_re = re.compile(r'\|\||&&')
    rel_re = re.compile(r'/?=|[<>]=?')
    arith_re = re.compile(r'[-+/%&\|^~!]|>>|<<|\*\*?')
    keyword_re = re.compile(r'\b(if|then|else|endif)\b')

    def __init__(self, qmasm):
        # Note: this parameter shadows the qmasm module inside __init__ only;
        # other methods must use self.qmasm to reach the QMASM object.
        self.qmasm = qmasm

    class ParseError(Exception):
        pass

    def lex(self, s):
        "Split a string into tokens (tuples of type and value)."
        tokens = []
        s = s.lstrip()
        while len(s) > 0:
            # Match parentheses.
            if s[0] == "(":
                tokens.append(("lparen", "("))
                s = s[1:].lstrip()
                continue
            if s[0] == ")":
                tokens.append(("rparen", ")"))
                s = s[1:].lstrip()
                continue

            # Match keywords.  A keyword's token type is the keyword itself.
            mo = self.keyword_re.match(s)
            if mo is not None:
                match = mo.group(0)
                tokens.append((match, match))
                s = s[len(match):].lstrip()
                continue

            # Match positive integers.
            mo = self.int_re.match(s)
            if mo is not None:
                match = mo.group(0)
                tokens.append(("int", int(match)))
                s = s[len(match):].lstrip()
                continue

            # Match connectives.
            mo = self.conn_re.match(s)
            if mo is not None:
                match = mo.group(0)
                tokens.append(("conn", match))
                s = s[len(match):].lstrip()
                continue

            # Match "<<" and ">>" before we match "<" and ">".
            if len(s) >= 2 and (s[:2] == "<<" or s[:2] == ">>"):
                tokens.append(("arith", s[:2]))
                s = s[2:].lstrip()
                continue

            # Match relational operators.
            mo = self.rel_re.match(s)
            if mo is not None:
                match = mo.group(0)
                tokens.append(("rel", match))
                s = s[len(match):].lstrip()
                continue

            # Match "**" before we match "*".
            if len(s) >= 2 and s[:2] == "**":
                tokens.append(("power", s[:2]))
                s = s[2:].lstrip()
                continue

            # Match arithmetic operators.
            mo = self.arith_re.match(s)
            if mo is not None:
                match = mo.group(0)
                tokens.append(("arith", match))
                s = s[len(match):].lstrip()
                continue

            # Everything else is an identifier.
            mo = self.qmasm.ident_re.match(s)
            if mo is not None:
                match = mo.group(0)
                tokens.append(("ident", match))
                s = s[len(match):].lstrip()
                continue
            raise self.ParseError("Failed to parse %s" % s)
        tokens.append(("EOF", "EOF"))
        return tokens

    def advance(self):
        "Advance to the next symbol."
        self.tokidx += 1
        self.sym = self.tokens[self.tokidx]

    def accept(self, ty):
        """Advance to the next token if the current token matches a given
        token type and return True.  Otherwise, return False."""
        if self.sym[0] == ty:
            self.advance()
            return True
        return False

    def expect(self, ty):
        """Advance to the next token if the current token matches a given
        token.  Otherwise, fail."""
        if not self.accept(ty):
            raise self.ParseError("Expected %s but saw %s" % (ty, repr(self.sym[1])))

    def generic_operator(self, return_type, child_method, sym_type, valid_ops):
        """Parse one or more child_method() results separated by tokens of
        type sym_type whose value is in valid_ops, and combine them
        left-associatively into AST nodes of type return_type."""
        # Produce a list of ASTs representing children.  After each child,
        # record the value of the token that follows it.
        c = child_method()
        ops = [self.sym[1]]
        asts = [c]
        while self.sym[0] == sym_type and ops[-1] in valid_ops:
            self.advance()
            c = child_method()
            ops.append(self.sym[1])
            asts.append(c)

        # Handle the trivial case of the identity operation.
        if len(asts) == 1:
            return AST(self.qmasm, return_type, None, asts)

        # The final recorded token ended the loop; it is not an operator.
        ops.pop()

        # Merge the ASTs in a left-associative fashion into a single AST.
        tree = asts[0]
        for op, rhs in zip(ops, asts[1:]):
            tree = AST(self.qmasm, return_type, op, [tree, rhs])
        return tree

    def if_expr(self):
        "Return an if...then...else expression."
        self.expect("if")
        # Accept the full logical grammar (including "||") in the condition,
        # matching how the condition is printed back out by AST.__str__.
        cond = self.disjunction()
        self.expect("then")
        then_expr = self.expression()
        self.expect("else")
        else_expr = self.expression()
        self.expect("endif")
        return AST(self.qmasm, "if_expr", None, [cond, then_expr, else_expr])

    def factor(self):
        "Return a factor (variable, integer, or expression)."
        val = self.sym[1]
        if self.accept("ident"):
            child = AST(self.qmasm, "ident", val)
        elif self.accept("int"):
            child = AST(self.qmasm, "int", val)
        elif self.accept("lparen"):
            child = self.disjunction()
            self.expect("rparen")
        elif self.sym[0] == "arith":
            child = self.unary()
        elif self.sym[0] == "if":
            child = self.if_expr()
        elif val == "EOF":
            raise self.ParseError("Parse error at end of expression")
        else:
            raise self.ParseError('Parse error at "%s"' % val)
        return AST(self.qmasm, "factor", None, [child])

    def power(self):
        """Return a factor or a factor raised to the power of a second factor.
        "**" is right-associative."""
        f1 = self.factor()
        op = self.sym[1]
        if self.sym[0] == "power" and op == "**":
            self.advance()
            f2 = self.power()
            return AST(self.qmasm, "power", op, [f1, f2])
        return AST(self.qmasm, "power", None, [f1])

    def unary(self):
        "Return a unary operator applied to a power."
        op = self.sym[1]
        if op in ["+", "-", "~", "!"]:
            self.advance()
        else:
            op = "id"
        return AST(self.qmasm, "unary", op, [self.power()])

    def term(self):
        "Return a term (product of one or more unaries)."
        return self.generic_operator("term", self.unary, "arith", ["*", "/", "%", "&", "<<", ">>"])

    def expression(self):
        "Return an expression (sum of one or more terms)."
        return self.generic_operator("expr", self.term, "arith", ["+", "-", "|", "^"])

    def comparison(self):
        "Return a comparison of exactly two expressions."
        e1 = self.expression()
        op = self.sym[1]
        if self.sym[0] != "rel":
            return AST(self.qmasm, "rel", None, [e1])
        self.advance()
        e2 = self.expression()
        return AST(self.qmasm, "rel", op, [e1, e2])

    def conjunction(self):
        "Return a conjunction (logical AND of one or more comparisons)."
        return self.generic_operator("conn", self.comparison, "conn", ["&&"])

    def disjunction(self):
        "Return a disjunction (logical OR of one or more conjunctions)."
        return self.generic_operator("conn", self.conjunction, "conn", ["||"])

    def parse(self, s):
        """Parse a relational expression into an AST.  Lexical and syntax
        errors abort via the QMASM object's abend()."""
        try:
            # Lexing is inside the try so "Failed to parse" errors are
            # reported the same way as syntax errors.
            self.tokens = self.lex(s)
            self.tokidx = -1
            self.advance()
            ast = self.disjunction()
            if self.sym[0] != "EOF":
                raise self.ParseError('Parse error at "%s"' % self.sym[1])
        except self.ParseError as e:
            # Use the QMASM object (not the qmasm module) to report the error.
            self.qmasm.abend('%s in "%s"' % (e, s))
            # abend() is not expected to return; if it does, re-raise rather
            # than returning an unbound AST.
            raise
        return ast