1###################################
2# Parse a QMASM source file       #
3# By Scott Pakin <pakin@lanl.gov> #
4###################################
5
6import copy
7import os
8import qmasm
9import re
10import random
11import string
12import sys
13from collections import defaultdict
14from qmasm.assertions import AssertParser
15from qmasm.utils import RemainingNextException
16
17# Define a function that aborts the program, reporting an invalid
18# input line as part of the error message.
def error_in_line(filename, lineno, str):
    """Report an error against a source location and terminate.

    Writes "<filename>:<lineno>: error: <str>" to standard error, then
    exits the process with status 1.
    """
    message = '%s:%d: error: %s\n' % (filename, lineno, str)
    sys.stderr.write(message)
    sys.exit(1)
22
class Environment(object):
    "Maintain a variable environment as a stack of scopes."

    # Regex to split a symbol into tokens (cf. qmasm.ident_re with square
    # brackets and colons added)
    toks_re = re.compile(r'([^-+*/%&\|^~()<=>#,\s\[:\]]+)')

    def __init__(self):
        # The stack starts with a single, outermost scope.
        self.stack = [{}]
        # Cached deep copy of self; None means "regenerate on next copy".
        self.self_copy = None

    def __getitem__(self, key):
        "Look a key up, innermost scope first; raise KeyError if absent."
        for scope in reversed(self.stack):
            if key in scope:
                return scope[key]
        raise KeyError(key)

    def __setitem__(self, key, val):
        "Bind a key in the innermost scope and invalidate the cached copy."
        innermost = self.stack[-1]
        innermost[key] = val
        self.self_copy = None

    def push(self):
        "Open a new, empty innermost scope."
        self.stack.append({})
        self.self_copy = None

    def pop(self):
        "Drop the innermost scope."
        self.stack.pop()
        self.self_copy = None

    def keys(self):
        "Return a list of every key defined in any scope, without duplicates."
        names = set()
        for scope in self.stack:
            names.update(scope)
        return list(names)

    def __copy__(self):
        """Return a deep copy of the environment.  The copy is cached and
        handed out again until the environment is next modified."""
        if self.self_copy is None:
            self.self_copy = copy.deepcopy(self)
        return self.self_copy

    def sub_syms(self, sym):
        """Replace variables by their values within a symbol name.  Lists are
        processed element by element; anything that is neither a list nor
        a string is returned as is."""
        if isinstance(sym, list):
            return [self.sub_syms(elt) for elt in sym]
        if not isinstance(sym, str):
            return sym

        # Split the word into tokens and separators, substituting every
        # piece that names a known variable.
        replaced = []
        for piece in self.toks_re.split(sym):
            try:
                piece = str(self[piece])
            except KeyError:
                pass
            replaced.append(piece)
        return "".join(replaced)
86
87# I'm too lazy to write another parser so I'll simply define an
88# alternative entry point to the assertion parser.
class ExprParser(AssertParser):
    "Parse an arithmetic expression."

    def parse(self, filename, lineno, s):
        """Parse the arithmetic expression s into an AST.  A parse error is
        reported against filename:lineno on standard error and the program
        exits with status 1."""
        # Prime the inherited lexer/parser state.
        self.tokens = self.lex(s)
        self.tokidx = -1
        self.advance()
        try:
            ast = self.expression()
            # Anything left over after the expression is an error.
            at_eof = self.sym[0] == "EOF"
            if not at_eof:
                raise self.ParseError('Parse error at "%s"' % self.sym[1])
        except self.ParseError as e:
            report = '%s:%d: error: %s in "%s"\n' % (filename, lineno, e, s)
            sys.stderr.write(report)
            sys.exit(1)
        return ast
105
106# I'm too lazy to write another parser so I'll simply define an
107# alternative entry point to the assertion parser.
class RelationParser(AssertParser):
    "Parse a relational expression."

    def parse(self, filename, lineno, s):
        """Parse the relational expression s into an AST.  A parse error is
        reported against filename:lineno on standard error and the program
        exits with status 1."""
        # Prime the inherited lexer/parser state.
        self.tokens = self.lex(s)
        self.tokidx = -1
        self.advance()
        try:
            ast = self.conjunction()
            # Anything left over after the relation is an error.
            at_eof = self.sym[0] == "EOF"
            if not at_eof:
                raise self.ParseError('Parse error at "%s"' % self.sym[1])
        except self.ParseError as e:
            report = '%s:%d: error: %s in "%s"\n' % (filename, lineno, e, s)
            sys.stderr.write(report)
            sys.exit(1)
        return ast
124
class LoopIterator(object):
    '''Iterate over arithmetic and geometric integer sequences.  It is
    assumed that the penultimate element of the input list is the string
    "...".

    The input list has the form [x_0, ..., x_k, "...", x_n].  With a single
    leading term the sequence steps by +/- 1 toward x_n.  With several
    leading terms the step is inferred from them: a common difference gives
    an arithmetic progression; otherwise a common exact integer ratio gives
    a geometric progression.  Anything else is reported as an error.'''

    def __init__(self, filename, lineno, rhs):
        # Prepare the iterations we intend to perform.
        self.first_val = rhs[0]
        last_val = rhs[-1]
        if len(rhs) == 3:
            # "<x_0> ... <x_n>" indicates an arithmetic progression with a
            # delta of +/- 1.
            if self.first_val <= last_val:
                self.increment = lambda x: x + 1
            else:
                self.increment = lambda x: x - 1
        else:
            # For "<x_0> <x_1> <x_2> ... <x_n>", compute the progression
            # type and increment from consecutive pairs of leading terms.
            terms = rhs[:-2]
            pairs = list(zip(terms, terms[1:]))

            def bad_progression():
                # Built lazily so the message is formatted only on failure.
                error_in_line(filename, lineno,
                              'Failed to interpret "%s, %d" as either an arithmetic or geometric progression' %
                              (", ".join([str(r) for r in terms]), last_val))

            adeltas = set([b - a for a, b in pairs])
            if len(adeltas) == 1:
                # Arithmetic progression
                delta = list(adeltas)[0]
                self.increment = lambda x: x + delta
            else:
                # Geometric progression
                try:
                    mdeltas = set([b // a for a, b in pairs])
                except ZeroDivisionError:
                    mdeltas = set(["multiple", "values"])  # Force the next test to fail.
                if len(mdeltas) != 1:
                    bad_progression()
                delta = list(mdeltas)[0]
                if delta == 0:
                    error_in_line(filename, lineno, "Decreasing geometric progressions are not currently supported")
                # Floor division alone accepts sequences such as 1, 2, 5
                # (2//1 == 5//2 == 2), so insist that the ratio is exact.
                if any([b != a * delta for a, b in pairs]):
                    bad_progression()
                self.increment = lambda x: x * delta

        # Determine when to stop: the direction of the first step decides
        # whether the final value is an upper or a lower bound.
        second_val = self.increment(self.first_val)
        if self.first_val < second_val:
            self.finished = lambda x: x > last_val
        else:
            self.finished = lambda x: x < last_val

    def __iter__(self):
        "Restart the sequence at its first value."
        self.next_val = self.first_val
        return self

    def __next__(self):
        "Return the next value, or stop once the final value is passed."
        if self.finished(self.next_val):
            raise StopIteration
        head = self.next_val
        self.next_val = self.increment(head)
        return head

    def next(self):
        # Python 2 wrapper for __next__
        return self.__next__()
185
class Statement(object):
    "One statement in a QMASM source file."

    def __init__(self, qmasm, filename, lineno, as_qubo):
        # Remember the owning QMASM object, where the statement came from,
        # and the BQM mode that was passed in with it.
        self.qmasm, self.as_qubo = qmasm, as_qubo
        self.filename, self.lineno = filename, lineno

    def error_in_line(self, msg):
        """Report an error and exit with status 1.  Without a line number
        the message goes to the QMASM object's abend method; otherwise it
        is written to standard error with a file:line prefix."""
        if self.lineno is None:
            self.qmasm.abend(msg)
        else:
            where = (self.filename, self.lineno, msg)
            sys.stderr.write('%s:%d: error: %s\n' % where)
        sys.exit(1)

    def validate_ident(self, ident):
        """Return ident unchanged if the whole of it matches the QMASM
        identifier regex; otherwise report an invalid identifier."""
        found = self.qmasm.ident_re.match(ident)
        whole_match = found is not None and found.group(0) == ident
        if not whole_match:
            self.error_in_line('Invalid identifier "%s"' % ident)
        return ident
209
class Weight(Statement):
    "Represent a point weight on a qubit."
    def __init__(self, qmasm, filename, lineno, as_qubo, sym, weight):
        super(Weight, self).__init__(qmasm, filename, lineno, as_qubo)
        self.sym = self.validate_ident(sym)
        self.weight = weight

    def as_str(self, prefix=""):
        "Render the statement as QMASM source with prefix applied to the symbol."
        parts = (prefix, self.sym, self.weight)
        return "%s%s %s" % parts

    def update_qmi(self, prefix, next_prefix, problem):
        "Add this weight to the problem's weight for the prefixed symbol."
        qubit = self.qmasm.symbol_to_number(prefix + self.sym, prefix, next_prefix)
        # In QUBO mode only half of the stated weight is accumulated.
        amount = self.weight/2.0 if self.as_qubo else self.weight
        problem.weights[qubit] += amount
226
class Chain(Statement):
    "Chain between qubits."
    def __init__(self, qmasm, filename, lineno, as_qubo, sym1, sym2):
        super(Chain, self).__init__(qmasm, filename, lineno, as_qubo)
        self.sym1 = self.validate_ident(sym1)
        self.sym2 = self.validate_ident(sym2)

    def as_str(self, prefix=""):
        "Render the chain as QMASM source with prefix applied to both symbols."
        parts = (prefix, self.sym1, prefix, self.sym2)
        return "%s%s = %s%s" % parts

    def update_qmi(self, prefix, next_prefix, problem):
        """Record the chain in the problem and queue an equality assertion
        between the two symbols."""
        full1 = prefix + self.sym1
        full2 = prefix + self.sym2
        num1 = self.qmasm.symbol_to_number(full1, prefix, next_prefix)
        num2 = self.qmasm.symbol_to_number(full2, prefix, next_prefix)
        if num1 == num2:
            self.error_in_line("A chain cannot connect a spin to itself")
        # Chains are stored with the smaller number first.
        pair = (num2, num1) if num1 > num2 else (num1, num2)
        problem.chains.add(pair)
        lhs = self.qmasm.apply_prefix(full1, None, next_prefix)
        rhs = self.qmasm.apply_prefix(full2, None, next_prefix)
        problem.pending_asserts.append((lhs, "=", rhs))
248
class AntiChain(Statement):
    "AntiChain between qubits."
    def __init__(self, qmasm, filename, lineno, as_qubo, sym1, sym2):
        super(AntiChain, self).__init__(qmasm, filename, lineno, as_qubo)
        self.sym1 = self.validate_ident(sym1)
        self.sym2 = self.validate_ident(sym2)

    def as_str(self, prefix=""):
        "Render the anti-chain as QMASM source with prefix applied to both symbols."
        parts = (prefix, self.sym1, prefix, self.sym2)
        return "%s%s /= %s%s" % parts

    def update_qmi(self, prefix, next_prefix, problem):
        """Record the anti-chain in the problem and queue an inequality
        assertion between the two symbols."""
        full1 = prefix + self.sym1
        full2 = prefix + self.sym2
        num1 = self.qmasm.symbol_to_number(full1, prefix, next_prefix)
        num2 = self.qmasm.symbol_to_number(full2, prefix, next_prefix)
        if num1 == num2:
            self.error_in_line("An anti-chain cannot connect a spin to itself")
        # Anti-chains are stored with the smaller number first.
        pair = (num2, num1) if num1 > num2 else (num1, num2)
        problem.antichains.add(pair)
        lhs = self.qmasm.apply_prefix(full1, None, next_prefix)
        rhs = self.qmasm.apply_prefix(full2, None, next_prefix)
        problem.pending_asserts.append((lhs, "/=", rhs))
270
class Pin(Statement):
    "Pinning of a qubit to true or false."
    def __init__(self, qmasm, filename, lineno, as_qubo, sym, goal):
        super(Pin, self).__init__(qmasm, filename, lineno, as_qubo)
        self.sym = self.validate_ident(sym)
        self.goal = goal

    def as_str(self, prefix=""):
        "Render the pin as QMASM source with prefix applied to the symbol."
        parts = (prefix, self.sym, self.goal)
        return "%s%s := %s" % parts

    def update_qmi(self, prefix, next_prefix, problem):
        """Record the pin in the problem and queue an assertion that the
        symbol equals the goal (written as an integer)."""
        full_sym = prefix + self.sym
        qubit = self.qmasm.symbol_to_number(full_sym, prefix, next_prefix)
        problem.pinned.append((qubit, self.goal))
        shown = self.qmasm.apply_prefix(full_sym, None, next_prefix)
        goal_text = str(int(self.goal))
        problem.pending_asserts.append((shown, "=", goal_text))
286
class Alias(Statement):
    "Alias one symbol to another."
    def __init__(self, qmasm, filename, lineno, as_qubo, sym1, sym2):
        super(Alias, self).__init__(qmasm, filename, lineno, as_qubo)
        self.sym1 = self.validate_ident(sym1)
        self.sym2 = self.validate_ident(sym2)

    def as_str(self, prefix=""):
        "Render the alias as QMASM source with prefix applied to both symbols."
        parts = (prefix, self.sym1, prefix, self.sym2)
        return "%s%s <-> %s%s" % parts

    def update_qmi(self, prefix, next_prefix, problem):
        "Register the two prefixed symbols as aliases in the symbol map."
        names = [prefix + self.sym1, prefix + self.sym2]
        if next_prefix is not None:
            # Redirect "!next." references to the following macro instance.
            marker = prefix + "!next."
            names = [n.replace(marker, next_prefix) for n in names]
        self.qmasm.sym_map.alias(names[0], names[1])
304
class BQMType(Statement):
    "Set the BQM mode to either Ising or QUBO."
    def __init__(self, qmasm, filename, lineno, as_qubo):
        super(BQMType, self).__init__(qmasm, filename, lineno, as_qubo)

    def as_str(self, prefix=""):
        "Render the directive as QMASM source; the prefix is not used."
        return "!bqm_type qubo" if self.as_qubo else "!bqm_type ising"

    def update_qmi(self, prefix, next_prefix, problem):
        "Do nothing: this statement makes no change to the problem."
        pass
318
class Rename(Statement):
    "Rename one set of symbols to another."
    def __init__(self, qmasm, filename, lineno, as_qubo, syms1, syms2):
        super(Rename, self).__init__(qmasm, filename, lineno, as_qubo)
        self.syms1 = [self.validate_ident(s) for s in syms1]
        self.syms2 = [self.validate_ident(s) for s in syms2]

    def as_str(self, prefix=""):
        "Render the rename as QMASM source with prefix applied to every symbol."
        lhs = " ".join([prefix + s for s in self.syms1])
        rhs = " ".join([prefix + s for s in self.syms2])
        return lhs + " -> " + rhs

    def update_qmi(self, prefix, next_prefix, problem):
        """Apply the rename to the symbol map and to everything recorded in
        the problem so far: weights, strengths, chains, anti-chains,
        assertions and pending assertions."""
        # Update the symbol map.
        old_names = [prefix + s for s in self.syms1]
        new_names = [prefix + s for s in self.syms2]
        if next_prefix is not None:
            marker = prefix + "!next."
            old_names = [s.replace(marker, next_prefix) for s in old_names]
            new_names = [s.replace(marker, next_prefix) for s in new_names]
        self.qmasm.sym_map.replace_all(old_names, new_names)
        sym2sym = dict(zip(old_names, new_names))

        def renamed(q):
            # Keys that are not being renamed pass through untouched.
            return sym2sym.get(q, q)

        # Update all weights.
        problem.weights = defaultdict(lambda: 0.0,
                                      [(renamed(q), wt) for q, wt in problem.weights.items()])

        # Update all strengths.
        problem.strengths = defaultdict(lambda: 0.0,
                                        [((renamed(q1), renamed(q2)), wt)
                                         for (q1, q2), wt in problem.strengths.items()])

        # Update all chains.
        problem.chains = set([(renamed(q1), renamed(q2)) for (q1, q2) in problem.chains])

        # Update all anti-chains.
        problem.antichains = set([(renamed(q1), renamed(q2)) for (q1, q2) in problem.antichains])

        # Update all assertions.  These need to go through an intermediary in
        # case we rename both X to Y and Y to X.
        letters = string.ascii_lowercase

        def random_word():
            return "".join([random.choice(letters) for _ in range(5)])

        renames = []
        for old, new in sym2sym.items():
            # The embedded space can't currently appear in a symbol name.
            placeholder = random_word() + " " + random_word()
            renames.append((old, placeholder, new))
        for old, placeholder, _ in renames:
            for ast in problem.assertions:
                ast.replace_ident(old, placeholder)
        for _, placeholder, new in renames:
            for ast in problem.assertions:
                ast.replace_ident(placeholder, new)

        # Update all pending assertions.
        problem.pending_asserts = [(renamed(s1), op, renamed(s2))
                                   for s1, op, s2 in problem.pending_asserts]
419
class Strength(Statement):
    "Coupler strength between two qubits."
    def __init__(self, qmasm, filename, lineno, as_qubo, sym1, sym2, strength):
        super(Strength, self).__init__(qmasm, filename, lineno, as_qubo)
        self.sym1 = self.validate_ident(sym1)
        self.sym2 = self.validate_ident(sym2)
        self.strength = strength

    def as_str(self, prefix=""):
        "Render the coupler as QMASM source with prefix applied to both symbols."
        parts = (prefix, self.sym1, prefix, self.sym2, self.strength)
        return "%s%s %s%s %s" % parts

    def update_qmi(self, prefix, next_prefix, problem):
        "Add this coupler strength to the problem."
        num1 = self.qmasm.symbol_to_number(prefix + self.sym1, prefix, next_prefix)
        num2 = self.qmasm.symbol_to_number(prefix + self.sym2, prefix, next_prefix)
        if num1 == num2:
            self.error_in_line("A coupler cannot connect a spin to itself")
        # Couplers are keyed with the smaller number first.
        coupler = (num2, num1) if num1 > num2 else (num1, num2)
        if not self.as_qubo:
            problem.strengths[coupler] += self.strength
            return
        # In QUBO mode a quarter of the strength goes to the coupler and a
        # quarter to each endpoint's weight.
        quarter = self.strength/4.0
        problem.strengths[coupler] += quarter
        problem.weights[coupler[0]] += quarter
        problem.weights[coupler[1]] += quarter
445
class Assert(Statement):
    "Instantiation of a run-time assertion."

    def __init__(self, qmasm, filename, lineno, as_qubo, expr):
        super(Assert, self).__init__(qmasm, filename, lineno, as_qubo)
        self.parser = AssertParser(qmasm)
        self.expr = expr
        # Parse and compile the assertion once, up front.
        self.ast = self.parser.parse(expr)
        self.ast.compile()

    def _ast_for(self, prefix, next_prefix):
        """Return the AST to use under the given prefix: the stored AST for
        an empty prefix, otherwise a recompiled deep copy with its
        identifiers prefixed."""
        if prefix == "":
            return self.ast
        scoped = copy.deepcopy(self.ast)
        scoped.prefix_identifiers(prefix, next_prefix)
        scoped.compile()
        return scoped

    def as_str(self, prefix=""):
        "Render the assertion as QMASM source with prefix applied to its identifiers."
        return "!assert " + str(self._ast_for(prefix, None))

    def update_qmi(self, prefix, next_prefix, problem):
        "Append the (possibly prefixed) assertion AST to the problem's assertions."
        problem.assertions.append(self._ast_for(prefix, next_prefix))
473
class MacroUse(Statement):
    "Instantiation of a macro definition."
    def __init__(self, qmasm, filename, lineno, as_qubo, name, body, prefixes):
        super(MacroUse, self).__init__(qmasm, filename, lineno, as_qubo)
        self.name = self.validate_ident(name)
        self.body = body
        # Each instance name is stored with a trailing "." so it can be
        # prepended directly to symbol names.
        self.prefixes = [self.validate_ident(p) + "." for p in prefixes]

    def as_str(self, prefix=""):
        """Render the expanded macro body as QMASM source, one statement per
        line.  Statements that still mention "!next." when there is no
        following instance are omitted."""
        stmt_strs = []
        nprefixes = len(self.prefixes)
        if nprefixes == 0:
            # No prefixes -- display the macro body in the current scope (i.e.,
            # using the given prefix).
            for stmt in self.body:
                sstr = stmt.as_str(prefix)
                if "!next." not in sstr:
                    stmt_strs.append(sstr)
        else:
            # At least one prefix -- display the macro body in a new scope
            # (i.e., by augmenting the given prefix with each new prefix in
            # turn).
            for p in range(nprefixes):
                pfx = self.prefixes[p]
                for stmt in self.body:
                    sstr = stmt.as_str(prefix + pfx)
                    if "!next." in sstr:
                        if p == nprefixes - 1:
                            # Drop statements that use "!next." if there's
                            # no next prefix.
                            continue
                        next_pfx = self.prefixes[p + 1]
                        sstr = sstr.replace(prefix + pfx + "!next.", prefix + next_pfx)
                    stmt_strs.append(sstr)
        return "\n".join(stmt_strs)

    def update_qmi(self, prefix, next_prefix, problem):
        """Apply every statement of the macro body to the problem, once per
        instance prefix (or once in the current scope if there are none).
        A statement that raises RemainingNextException is skipped."""
        nprefixes = len(self.prefixes)
        if nprefixes == 0:
            # No prefixes -- import the macro body into the current scope
            # (i.e., using the given prefix).
            for stmt in self.body:
                try:
                    stmt.update_qmi(prefix, None, problem)
                except RemainingNextException:
                    # Use the imported exception class, exactly as the
                    # prefixed branch below does; self.qmasm is the object
                    # handed to the constructor, not the qmasm package.
                    pass
        else:
            # At least one prefix -- import the macro body into a new scope
            # (i.e., by augmenting the given prefix with each new prefix in
            # turn).
            for p in range(nprefixes):
                pfx = prefix + self.prefixes[p]
                if p == nprefixes - 1:
                    next_pfx = None
                else:
                    next_pfx = prefix + self.prefixes[p + 1]
                for stmt in self.body:
                    try:
                        stmt.update_qmi(pfx, next_pfx, problem)
                    except RemainingNextException:
                        pass
535
536class FileParser(object):
537    "Parse a QMASM file."
538
539    def __init__(self, qmasm):
540        self.qmasm = qmasm      # Reference to the object we're mixed into
541        self.macros = {}        # Map from a macro name to a list of Statement objects
542        self.current_macro = (None, [])   # Macro currently being defined (name and statements)
543        self.target = qmasm.program   # Reference to either the program or the current macro
544        self.env = Environment()      # Stack of maps from compile-time variable names to values
545        self.expr_parser = ExprParser(qmasm)     # Expression parser
546        self.rel_parser = RelationParser(qmasm)  # Relation parser
547        self._as_qubo = False   # Current BQM mode (QUBO or Ising)
548
549        # Establish a mapping from a first-field directive to a parsing function.
550        self.dir_to_func = {
551            "!include":     self.parse_line_include,
552            "!assert":      self.parse_line_assert,
553            "!let":         self.parse_line_let,
554            "!begin_macro": self.parse_line_begin_macro,
555            "!end_macro":   self.parse_line_end_macro,
556            "!use_macro":   self.parse_line_use_macro,
557            "!alias":       self.parse_line_sym_alias,
558            "!bqm_type":    self.parse_line_bqm_type
559        }
560
561    def is_float(self, str):
562        "Return True if a string can be treated as a float."
563        try:
564            float(str)
565            return True
566        except ValueError:
567            return False
568
569    def split_line_into_fields(self, line):
570        "Discard comments, then split the line into fields."
571        try:
572            line = line[:line.index("#")]
573        except ValueError:
574            pass
575        return line.split()
576
577    def find_file_in_path(self, pathnames, filename):
578        "Search a list of directories for a file."
579        for pname in pathnames:
580            fname = os.path.join(pname, filename)
581            if os.path.exists(fname):
582                return fname
583            fname_qmasm = fname + ".qmasm"
584            if os.path.exists(fname_qmasm):
585                return fname_qmasm
586        return None
587
588    def parse_line_include(self, filename, lineno, fields):
589        "Parse an !include directive."
590        # "!include" "<filename>" -- process a named auxiliary file.
591        if len(fields) != 2:
592            error_in_line(filename, lineno, "Expected a filename to follow !include")
593        incname = " ".join(fields[1:])
594        if len(incname) >= 2 and incname[0] == "<" and incname[-1] == ">":
595            # Search QMASMPATH for the filename.
596            incname = incname[1:-1]
597            try:
598                qmasmpath = os.environ["QMASMPATH"].split(":")
599                qmasmpath.append(".")
600            except KeyError:
601                qmasmpath = ["."]
602            found_incname = self.find_file_in_path(qmasmpath, incname)
603            if found_incname != None:
604                incname = found_incname
605        elif len(incname) >= 2 and incname[0] == '"' and incname[-1] == '"':
606            # Search only the current directory for the filename.
607            incname = incname[1:-1]
608            found_incname = self.find_file_in_path(["."], incname)
609            if found_incname != None:
610                incname = found_incname
611        else:
612            error_in_line(filename, lineno, 'Failed to parse "%s"' % (" ".join(fields)))
613        try:
614            incfile = open(incname)
615        except IOError:
616            error_in_line(filename, lineno, 'Failed to open %s for input' % incname)
617        self.process_file(incname, incfile)
618        incfile.close()
619
620    def parse_line_assert(self, filename, lineno, fields):
621        "Parse an !assert directive."
622        # "!assert" <expr> -- assert a property that must be true at run time.
623        if len(fields) < 2:
624            error_in_line(filename, lineno, "Expected an expression to follow !assert")
625        self.target.append(Assert(self.qmasm, filename, lineno, self._as_qubo, " ".join(self.env.sub_syms(fields[1:]))))
626
627    def parse_line_let(self, filename, lineno, fields):
628        "Parse a !let directive."
629        # "!let" <name> := <expr> -- evaluate <expr> and assign the result to <name>.
630        if len(fields) < 4 or fields[2] != ":=":
631            error_in_line(filename, lineno, 'Expected a variable name, ":=", and an expression to follow !let')
632        lhs = fields[1]
633        if len(fields) == 4:
634            # Handle the case of "!let" <name> := <symbol>.
635            is_sym = self.env.toks_re.match(fields[3])
636            if is_sym != None and is_sym.group(0) == fields[3]:
637                self.env[lhs] = self.env.sub_syms(fields[3])
638                return
639        ast = self.expr_parser.parse(filename, lineno, " ".join(self.env.sub_syms(fields[3:])))
640        ast.compile()
641        rhs = ast.evaluate(dict(self.env))
642        self.env[lhs] = rhs
643
644    def parse_line_begin_macro(self, filename, lineno, fields):
645        "Parse a !begin_macro directive."
646        # "!begin_macro" <name> -- begin a macro definition.
647        if len(fields) != 2:
648            error_in_line(filename, lineno, "Expected a macro name to follow !begin_macro")
649        name = fields[1]
650        if name in self.macros:
651            error_in_line(self, filename, lineno, "Macro %s is multiply defined" % name)
652        if self.current_macro[0] != None:
653            error_in_line(filename, lineno, "Nested macros are not supported")
654        self.current_macro = (name, [])
655        self.target = self.current_macro[1]
656        self.env.push()
657
658    def parse_line_end_macro(self, filename, lineno, fields):
659        "Parse an !end_macro directive."
660        # "!end_macro" <name> -- end a macro definition.
661        if len(fields) != 2:
662            error_in_line(filename, lineno, "Expected a macro name to follow !end_macro")
663        name = fields[1]
664        if self.current_macro[0] == None:
665            error_in_line(filename, lineno, "Ended macro %s with no corresponding begin" % name)
666        if self.current_macro[0] != name:
667            error_in_line(filename, lineno, "Ended macro %s after beginning macro %s" % (name, self.current_macro[0]))
668        self.macros[name] = self.current_macro[1]
669        self.target = self.qmasm.program
670        self.current_macro = (None, [])
671        self.env.pop()
672
673    def parse_line_weight(self, filename, lineno, fields):
674        "Parse a qubit weight."
675        # <symbol> <weight> -- increment a symbol's point weight.
676        if len(fields) != 2:
677            error_in_line(filename, lineno, "Internal error in parse_line_weight")
678        try:
679            val = float(self.env.sub_syms(fields[1]))
680        except ValueError:
681            error_in_line(filename, lineno, 'Failed to parse "%s %s" as a symbol followed by a numerical weight' % (fields[0], fields[1]))
682        self.target.append(Weight(self.qmasm, filename, lineno, self._as_qubo, self.env.sub_syms(fields[0]), val))
683
684    def parse_line_chain(self, filename, lineno, fields):
685        "Parse a qubit chain."
686        # <symbol_1> = <symbol_2> -- create a chain between <symbol_1>
687        # and <symbol_2>.
688        if len(fields) != 3 or fields[1] != "=":
689            error_in_line(filename, lineno, "Internal error in parse_line_chain")
690        code = "%s = %s" % (self.env.sub_syms(fields[0]), self.env.sub_syms(fields[2]))
691        self.target.extend(self.process_chain(filename, lineno, code))
692
693    def parse_line_antichain(self, filename, lineno, fields):
694        "Parse a qubit anti-chain."
695        # <symbol_1> /= <symbol_2> -- create an anti-chain between <symbol_1>
696        # and <symbol_2>.
697        if len(fields) != 3 or fields[1] != "/=":
698            error_in_line(filename, lineno, "Internal error in parse_line_antichain")
699        code = "%s /= %s" % (self.env.sub_syms(fields[0]), self.env.sub_syms(fields[2]))
700        self.target.extend(self.process_antichain(filename, lineno, code))
701
702    def parse_line_pin(self, filename, lineno, fields):
703        "Parse a qubit pin."
704        # <symbol> := <value> -- force symbol <symbol> to have value <value>.
705        if len(fields) != 3 or fields[1] != ":=":
706            error_in_line(filename, lineno, "Internal error in parse_line_pin")
707        code = "%s := %s" % (self.env.sub_syms(fields[0]), self.env.sub_syms(fields[2]))
708        self.target.extend(self.process_pin(filename, lineno, code))
709
710    def parse_line_alias(self, filename, lineno, fields):
711        "Parse a qubit alias."
712        # <symbol_1> <-> <symbol_2> -- make <symbol_1> an alias of <symbol_2>.
713        if len(fields) != 3 or fields[1] != "<->":
714            error_in_line(filename, lineno, "Internal error in parse_line_alias")
715        code = "%s <-> %s" % (self.env.sub_syms(fields[0]), self.env.sub_syms(fields[2]))
716        self.target.extend(self.process_alias(filename, lineno, code))
717
718    def parse_line_rename(self, filename, lineno, fields):
719        "Parse a qubit rename."
720        # <symbol_1> ... -> <symbol_2> ... -- make <symbol_1> an alias of <symbol_2>.
721        if len(fields) < 3 or len(fields)%2 == 0:
722            error_in_line(filename, lineno, 'Failed to parse "%s" as a symbol rename' % (" ".join(fields)))
723
724        # Split the fields into a left-hand side and a right-hand side.
725        pin_parser = PinParser()
726        tokens = []
727        num_arrows = 0
728        for sym in fields:
729            if sym == "->":
730                tokens.append(sym)
731                num_arrows += 1
732            else:
733                sym_list = pin_parser.parse_lhs(self.env.sub_syms(sym))
734                tokens.extend(sym_list)
735        lhs = tokens[:len(tokens)//2]
736        rhs = tokens[len(tokens)//2:]
737        if num_arrows != 1 or rhs[0] != "->":
738            error_in_line(filename, lineno, 'Failed to parse "%s" as a symbol rename' % (" ".join(fields)))
739        rhs = rhs[1:]  # Drop the "->".
740        self.target.append(Rename(self.qmasm, filename, lineno, self._as_qubo, lhs, rhs))
741
742    def parse_line_strength(self, filename, lineno, fields):
743        "Parse a coupler strength."
744        # <symbol_1> <symbol_2> <strength> -- increment a coupler strength.
745        if len(fields) != 3:
746            error_in_line(filename, lineno, "Internal error in parse_line_strength")
747        try:
748            strength = float(self.env.sub_syms(fields[2]))
749        except ValueError:
750            error_in_line(filename, lineno, 'Failed to parse "%s" as a number' % fields[2])
751        self.target.append(Strength(self.qmasm, filename, lineno, self._as_qubo, self.env.sub_syms(fields[0]), self.env.sub_syms(fields[1]), strength))
752
753    def parse_line_use_macro(self, filename, lineno, fields):
754        "Parse a !use_macro directive."
755        # "!use_macro" <macro_name> [<instance_name> ...] -- instantiate a
756        # macro using <instance_name> as each variable's prefix.
757        if len(fields) < 2:
758            error_in_line(filename, lineno, "Expected a macro name to follow !use_macro")
759        name = self.env.sub_syms(fields[1])
760        prefixes = [self.env.sub_syms(p) for p in fields[2:]]
761        try:
762            self.target.append(MacroUse(self.qmasm, filename, lineno, self._as_qubo, name, self.macros[name], prefixes))
763        except KeyError:
764            error_in_line(filename, lineno, "Unknown macro %s" % name)
765
766    def parse_line_sym_alias(self, filename, lineno, fields):
767        "Parse an !alias directive."
768        sys.stderr.write('%s:%d: warning: !alias is deprecated; use "!let %s := %s" instead\n' % (filename, lineno, fields[1], fields[2]))
769        if len(fields) != 3:
770            error_in_line(filename, lineno, "Expected a symbol name and replacement to follow !alias")
771        self.env[fields[1]] = self.env.sub_syms(fields[2])
772
773    def parse_line_bqm_type(self, filename, lineno, fields):
774        "Parse a !bqm_type directive."
775        if len(fields) != 2:
776            error_in_line(filename, lineno, 'Expected either "qubo" or "ising" to follow !bqm_type')
777        if fields[1] == "qubo":
778            as_qubo = True
779        elif fields[1] == "ising":
780            as_qubo = False
781        else:
782            error_in_line(filename, lineno, 'Expected either "qubo" or "ising" to follow !bqm_type')
783        self.target.append(BQMType(self.qmasm, filename, lineno, as_qubo))
784
785    def process_if(self, filename, lineno, fields, all_lines):
786        """Parse and process an !if directive.  Recursively parse the remaining
787        file contents."""
788        if len(fields) < 2:
789            error_in_line(filename, lineno, "Expected a relational expression to follow !if")
790
791        # Scan ahead for the matching !end_if.
792        ends_needed = 1
793        else_idx = -1
794        for idx in range(1, len(all_lines)):
795            # Split the line into fields and apply text aliases.
796            lno, line = all_lines[idx]
797            if line == "":
798                continue
799            flds = self.split_line_into_fields(line)
800            if flds[0] == "!if":
801                ends_needed += 1
802            elif flds[0] == "!else":
803                if len(flds) != 1:
804                    error_in_line(filename, lno, "Unexpected text after !else")
805                if else_idx != -1:
806                    error_in_line(filename, lno, "An !else matching line %d's !if already appeared in line %d" % (lineno, all_lines[else_idx][0]))
807                else_idx = idx
808            elif flds[0] == "!end_if":
809                if len(flds) != 1:
810                    error_in_line(filename, lno, "Unexpected text after !end_if")
811                ends_needed -= 1
812                if ends_needed == 0:
813                    break
814
815        # If the condition is true, parse the body.  Otherwise do nothing.
816        if ends_needed != 0:
817            error_in_line(filename, lineno, "Failed to find a matching !end_if directive")
818        end_idx = idx
819        ast = self.rel_parser.parse(filename, lineno, " ".join(self.env.sub_syms(fields[1:])))
820        ast.compile()
821        rhs = ast.evaluate(dict(self.env))
822        if rhs:
823            # Evaluate the then clause in a new scope.
824            self.env.push()
825            if else_idx == -1:
826                # No else clause
827                self.process_file_contents(filename, all_lines[1:end_idx])
828            else:
829                # else clause
830                self.process_file_contents(filename, all_lines[1:else_idx])
831            self.env.pop()
832        elif else_idx != -1:
833            # Evaluate the else clause in a new scope.
834            self.env.push()
835            self.process_file_contents(filename, all_lines[else_idx+1:end_idx])
836            self.env.pop()
837
838        # Process the rest of the file.
839        self.process_file_contents(filename, all_lines[end_idx+1:])
840
841    def process_for(self, filename, lineno, fields, all_lines):
842        """Parse and process a !for directive.  Recursively parse the remaining
843        file contents."""
844        # Parse the !for line.
845        if len(fields) < 4 or fields[2] != ":=":
846            error_in_line(filename, lineno, 'Expected a variable name, an ":=", and a comma-separated list to follow !for')
847        seq = [s.strip() for s in " ".join(self.env.sub_syms(fields[3:])).split(",")]
848        for i in range(len(seq)):
849            if seq[i] == "..." and i != len(seq) - 2:
850                error_in_line(filename, lineno, '"..." can appear only in the penultimate position in a sequence')
851
852        # Construct an iterator based on the given sequence.
853        if "..." in seq:
854            seq_ints = []
855            for i in range(len(seq)):
856                if seq[i] == "...":
857                    seq_ints.append("...")
858                else:
859                    ast = self.expr_parser.parse(filename, lineno, seq[i])
860                    ast.compile()
861                    val = ast.evaluate(dict(self.env))
862                    seq_ints.append(val)
863            iter = LoopIterator(filename, lineno, seq_ints)
864        else:
865            iter = seq
866
867        # Scan ahead for the matching !end_for.
868        ends_needed = 1
869        for idx in range(1, len(all_lines)):
870            # Split the line into fields and apply text aliases.
871            lno, line = all_lines[idx]
872            if line == "":
873                continue
874            flds = self.split_line_into_fields(line)
875            if flds[0] == "!for":
876                ends_needed += 1
877            elif flds[0] == "!end_for":
878                if len(flds) != 1:
879                    error_in_line(filename, lno, "Unexpected text after !end_for")
880                ends_needed -= 1
881                if ends_needed == 0:
882                    break
883
884        # Parse the body once per iterator element.
885        if ends_needed != 0:
886            error_in_line(filename, lineno, "Failed to find a matching !end_for directive")
887        end_idx = idx
888        for val in iter:
889            self.env.push()
890            self.env[fields[1]] = val
891            self.process_file_contents(filename, all_lines[1:end_idx])
892            self.env.pop()
893
894        # Process the rest of the file.
895        self.process_file_contents(filename, all_lines[end_idx+1:])
896
897
898    def process_pin(self, filename, lineno, pin_str):
899        "Parse a pin statement into one or more Pin objects and add these to the program."
900        lhs_rhs = pin_str.split(":=")
901        if len(lhs_rhs) != 2:
902            self.qmasm.abend('Failed to parse pin statement "%s"' % pin_str)
903        pin_parser = PinParser()
904        lhs_list = pin_parser.parse_lhs(self.env.sub_syms(lhs_rhs[0]))
905        rhs_list = pin_parser.parse_rhs(self.env.sub_syms(lhs_rhs[1]))
906        if len(lhs_list) != len(rhs_list):
907            self.qmasm.abend('Different number of left- and right-hand-side values in "%s" (%d vs. %d)' % (pin_str, len(lhs_list), len(rhs_list)))
908        return [Pin(self.qmasm, filename, lineno, self._as_qubo, l, r) for l, r in zip(lhs_list, rhs_list)]
909
910    def process_chain(self, filename, lineno, chain_str):
911        "Parse a chain statement into one or more Chain objects and add these to the program."
912        # We use the LHS parser from PinParser to parse both sides of the chain.
913        lhs_rhs = chain_str.split("=")
914        if len(lhs_rhs) != 2:
915            self.qmasm.abend('Failed to parse chain statement "%s"' % chain_str)
916        pin_parser = PinParser()
917        lhs_list = pin_parser.parse_lhs(self.env.sub_syms(lhs_rhs[0]))
918        rhs_list = pin_parser.parse_lhs(self.env.sub_syms(lhs_rhs[1]))  # Note use of parse_lhs to parse the RHS.
919        if len(lhs_list) != len(rhs_list):
920            self.qmasm.abend('Different number of left- and right-hand-side values in "%s" (%d vs. %d)' % (chain_str, len(lhs_list), len(rhs_list)))
921        return [Chain(self.qmasm, filename, lineno, self._as_qubo, l, r) for l, r in zip(lhs_list, rhs_list)]
922
923    def process_antichain(self, filename, lineno, antichain_str):
924        "Parse an anti-chain statement into one or more AntiChain objects and add these to the program."
925        # We use the LHS parser from PinParser to parse both sides of the anti-chain.
926        lhs_rhs = antichain_str.split("/=")
927        if len(lhs_rhs) != 2:
928            self.qmasm.abend('Failed to parse anti-chain statement "%s"' % antichain_str)
929        pin_parser = PinParser()
930        lhs_list = pin_parser.parse_lhs(self.env.sub_syms(lhs_rhs[0]))
931        rhs_list = pin_parser.parse_lhs(self.env.sub_syms(lhs_rhs[1]))  # Note use of parse_lhs to parse the RHS.
932        if len(lhs_list) != len(rhs_list):
933            self.qmasm.abend('Different number of left- and right-hand-side values in "%s" (%d vs. %d)' % (antichain_str, len(lhs_list), len(rhs_list)))
934        return [AntiChain(self.qmasm, filename, lineno, self._as_qubo, l, r) for l, r in zip(lhs_list, rhs_list)]
935
936    def process_alias(self, filename, lineno, alias_str):
937        "Parse an alias statement into one or more Alias objects and add these to the program."
938        # We use the LHS parser from PinParser to parse both sides of the alias.
939        lhs_rhs = alias_str.split("<->")
940        if len(lhs_rhs) != 2:
941            self.qmasm.abend('Failed to parse alias statement "%s"' % alias_str)
942        pin_parser = PinParser()
943        lhs_list = pin_parser.parse_lhs(self.env.sub_syms(lhs_rhs[0]))
944        rhs_list = pin_parser.parse_lhs(self.env.sub_syms(lhs_rhs[1]))  # Note use of parse_lhs to parse the RHS.
945        if len(lhs_list) != len(rhs_list):
946            self.qmasm.abend('Different number of left- and right-hand-side values in "%s" (%d vs. %d)' % (alias_str, len(lhs_list), len(rhs_list)))
947        return [Alias(self.qmasm, filename, lineno, self._as_qubo, l, r) for l, r in zip(lhs_list, rhs_list)]
948
949    def process_file_contents(self, filename, all_lines):
950        """Parse the contents of a file.  Contents are passed as a list plus an
951        initial line number."""
952        for idx in range(len(all_lines)):
953            # Split the line into fields and apply text aliases.
954            lineno, line = all_lines[idx]
955            if line == "":
956                continue
957            fields = self.split_line_into_fields(line)
958            nfields = len(fields)
959            if nfields == 0:
960                continue
961
962            # Process the line.
963            if fields[0] == "!if":
964                # Special case for !if directives
965                return self.process_if(filename, lineno, fields, all_lines[idx:])
966            elif fields[0] == "!for":
967                # Special case for !for directives
968                return self.process_for(filename, lineno, fields, all_lines[idx:])
969            try:
970                # Parse first-field directives.
971                func = self.dir_to_func[fields[0]]
972            except KeyError:
973                # Reject unrecognized directives.
974                if fields[0][0] == "!":
975                    if fields[0] in ["!else", "!end_if"]:
976                        error_in_line(filename, lineno, "Encountered an %s without a matching !if" % fields[0])
977                    elif fields[0] == "!end_for":
978                        error_in_line(filename, lineno, "Encountered an !end_for without a matching !for")
979                    else:
980                        error_in_line(filename, lineno, "Unrecognized directive %s" % fields[0])
981
982                # Prohibit "!next." outside of macros.
983                if self.current_macro[0] == None:
984                    for f in fields:
985                        if "!next." in f:
986                            error_in_line(filename, lineno, '"!next." is allowed only within !begin_macro...!end_macro blocks')
987
988                # Parse all lines not containing a directive in the first field.
989                if nfields == 2:
990                    func = self.parse_line_weight
991                elif nfields == 3 and fields[1] == "=":
992                    func = self.parse_line_chain
993                elif nfields == 3 and fields[1] == "/=":
994                    func = self.parse_line_antichain
995                elif nfields == 3 and fields[1] == ":=":
996                    func = self.parse_line_pin
997                elif nfields == 3 and fields[1] == "<->":
998                    func = self.parse_line_alias
999                elif nfields >= 3 and "->" in fields:
1000                    func = self.parse_line_rename
1001                elif nfields == 3 and self.is_float(fields[2]):
1002                    func = self.parse_line_strength
1003                else:
1004                    # None of the above
1005                    error_in_line(filename, lineno, 'Failed to parse "%s"' % line)
1006            func(filename, lineno, fields)
1007            if func == self.parse_line_bqm_type:
1008                self._as_qubo = self.target[-1].as_qubo
1009
1010    def process_file(self, filename, infile):
1011        """Define a function that parses an input file into an internal
1012        representation.  This function can be called recursively (due to !include
1013        directives)."""
1014
1015        # Read the entire file into a list.
1016        all_lines = []
1017        lineno = 1
1018        for line in infile:
1019            all_lines.append((lineno, line.strip()))
1020            lineno += 1
1021        self.process_file_contents(filename, all_lines)
1022
1023    def process_files(self, file_list):
1024        "Parse a list of file(s) into an internal representation."
1025        if file_list == []:
1026            # No files were specified: Read from standard input.
1027            self.process_file("<stdin>", sys.stdin)
1028            if self.current_macro[0] != None:
1029                error_in_line(filename, lineno, "Unterminated definition of macro %s" % self.current_macro[0])
1030        else:
1031            # Files were specified: Process each in turn.
1032            for infilename in file_list:
1033                try:
1034                    infile = open(infilename)
1035                except IOError:
1036                    self.qmasm.abend('Failed to open %s for input' % infilename)
1037                self.process_file(infilename, infile)
1038                if self.current_macro[0] != None:
1039                    error_in_line(filename, lineno, "Unterminated definition of macro %s" % self.current_macro[0])
1040                infile.close()
1041
class PinParser(object):
    "Provide methods for parsing a pin statement."

    def __init__(self, qmasm=None):
        """Prepare the regular expressions and tables used for parsing.  If
        given, qmasm is an object whose abend method is used to report
        errors; without it, errors are written to standard error and the
        program exits."""
        # The error paths below need somewhere to report to.  Previously
        # they referenced an attribute that was never assigned, so every
        # parse error surfaced as an AttributeError.
        self.qmasm = qmasm

        # Match either a lone number or a range of the form "<m> .. <n>" or
        # "<m> : <n>".
        self.bracket_re = re.compile(r'^\s*(\d+)(\s*(?:\.\.|:)\s*(\d+))?\s*$')
        self.bool_re = re.compile(r'TRUE|FALSE|T|F|0|[-+]?1', re.IGNORECASE)

        # Define synonyms for "true" and "false".
        self.str2bool = {s: True for s in ["1", "+1", "T", "TRUE"]}
        self.str2bool.update({s: False for s in ["0", "-1", "F", "FALSE"]})

    def _abend(self, msg):
        "Report a parse error and abort."
        if self.qmasm is not None:
            self.qmasm.abend(msg)
        else:
            sys.stderr.write('error: %s\n' % msg)
            sys.exit(1)

    def expand_brackets(self, vars, expr):
        """Repeat one or more variables for each bracketed expression.  For
        example, expanding ("hello", "1 .. 3") should produce
        ("hello[1]", "hello[2]", "hello[3]")."""
        # Determine the starting and ending numbers and the step.  Anything
        # that is not a number or a numeric range is kept verbatim.
        bmatch = self.bracket_re.search(expr)
        if bmatch is None:
            return ["%s[%s]" % (v, expr) for v in vars]
        first, _, last = bmatch.groups()
        num1 = int(first)
        if last is None:
            num2 = num1
        else:
            num2 = int(last)
        if num1 <= num2:
            step = 1     # Ascending range
        else:
            step = -1    # Descending range

        # Append the same bracketed constant to each variable.
        new_vars = []
        for v in vars:
            for i in range(num1, num2 + step, step):
                new_vars.append("%s[%d]" % (v, i))
        return new_vars

    def parse_lhs(self, lhs):
        """Parse the left-hand side of a pin statement.  Return a list of
        variable names with each bracketed range expanded."""
        variables = [""]
        group_len = 1    # Number of variables produced from the same bracketed expression
        bracket_expr = ""
        in_bracket = False
        for c in lhs:
            if c == "[":
                if in_bracket:
                    self._abend("Nested brackets are not allowed")
                in_bracket = True
            elif c == "]":
                if not in_bracket:
                    self._abend('Encountered "]" before seeing a "["')
                # Replace the variables in the current group with their
                # bracket-expanded forms.
                old_vars = variables[:-group_len]
                current_vars = variables[-group_len:]
                new_vars = self.expand_brackets(current_vars, bracket_expr)
                variables = old_vars + new_vars
                group_len = len(new_vars)
                in_bracket = False
                bracket_expr = ""
            elif in_bracket:
                # Accumulate the bracketed expression (including any
                # whitespace it contains).
                bracket_expr += c
            elif c == " " or c == "\t":
                # Whitespace outside of brackets begins a new variable.
                if variables[-1] != "":
                    variables.append("")
                group_len = 1
            else:
                # Append the character to every variable in the current group.
                for i in range(1, group_len + 1):
                    variables[-i] += c
        if in_bracket:
            self._abend("Unterminated bracketed expression")
        if variables[-1] == "":
            variables.pop()
        return variables

    def parse_rhs(self, rhs):
        """Parse the right-hand side of a pin statement.  Return a list of
        Booleans, one per truth value encountered."""
        # Anything other than whitespace between truth values is an error.
        for inter in self.bool_re.split(rhs):
            inter = inter.strip()
            if inter != "":
                self._abend('Unexpected "%s" in pin right-hand side "%s"' % (inter, rhs))
        return [self.str2bool[t.upper()] for t in self.bool_re.findall(rhs)]