1"""An implementation of the Zephyr Abstract Syntax Definition Language.
2
3See http://asdl.sourceforge.net/ and
4http://www.cs.princeton.edu/research/techreps/TR-554-97
5
6Only supports top level module decl, not view.  I'm guessing that view
7is intended to support the browser and I'm not interested in the
8browser.
9
10Changes for Python: Add support for module versions
11"""
12
13import os
14import traceback
15
16import spark
17
class Token(object):
    """A lexical token produced by ASDLScanner.

    The parser matches grammar terminals against ``type`` (see
    ASDLParser.typestring), so for punctuation the type is the literal
    text itself.
    """

    def __init__(self, type, lineno):
        self.lineno = lineno
        self.type = type

    def __str__(self):
        return self.type

    def __repr__(self):
        # Go through str() so subclasses that override __str__ (e.g. Id)
        # get the same text from repr().
        text = str(self)
        return text
30
class Id(Token):
    """An identifier token whose ``value`` is the matched text."""

    def __init__(self, value, lineno):
        Token.__init__(self, 'Id', lineno)
        self.value = value

    def __str__(self):
        # Display the identifier itself rather than its token type.
        return self.value
39
class String(Token):
    """A string-literal token.

    ``value`` is the matched text; when produced by ASDLScanner.t_string it
    includes the surrounding double quotes.
    """

    def __init__(self, value, lineno):
        Token.__init__(self, 'String', lineno)
        self.value = value

    def __str__(self):
        # Like Id, show the literal itself, so messages such as
        # "Error at '...'" point at the offending text instead of the
        # generic token type 'String'.
        return self.value
45
class ASDLSyntaxError(Exception):
    """Raised when ASDL input cannot be parsed.

    lineno -- 1-based source line where the error was detected
    token  -- the offending token, if known
    msg    -- explicit message; when None, one is built from ``token``
    """

    def __init__(self, lineno, token=None, msg=None):
        self.lineno, self.token, self.msg = lineno, token, msg

    def __str__(self):
        if self.msg is not None:
            return "%s, line %d" % (self.msg, self.lineno)
        return "Error at '%s', line %d" % (self.token, self.lineno)
58
class ASDLScanner(spark.GenericScanner, object):
    """Tokenizer for ASDL source text.

    spark.GenericScanner uses the docstring of each t_* method as that
    token's regular expression, so those docstrings must stay pure regexes;
    explanations belong in '#' comments.
    """

    def tokenize(self, input):
        # Return the token list for *input*.  The t_* callbacks append to
        # self.rv and keep self.lineno at the current (1-based) line.
        self.rv = []
        self.lineno = 1
        super(ASDLScanner, self).tokenize(input)
        return self.rv

    def t_id(self, s):
        r"[\w\.]+"
        # XXX doesn't distinguish upper vs. lower, which is
        # significant for ASDL.
        self.rv.append(Id(s, self.lineno))

    def t_string(self, s):
        r'"[^"]*"'
        self.rv.append(String(s, self.lineno))
        # [^"]* also matches newlines; count them so tokens after a
        # multi-line string keep correct line numbers.
        self.lineno += s.count("\n")

    def t_xxx(self, s): # not sure what this production means
        r"<="
        self.rv.append(Token(s, self.lineno))

    def t_punctuation(self, s):
        r"[\{\}\*\=\|\(\)\,\?\:]"
        # The token type is the punctuation character itself.
        self.rv.append(Token(s, self.lineno))

    def t_comment(self, s):
        r"\-\-[^\n]*"
        # '--' comments run to end of line and produce no token.
        pass

    def t_newline(self, s):
        r"\n"
        self.lineno += 1

    def t_whitespace(self, s):
        r"[ \t]+"
        pass

    def t_default(self, s):
        r" . +"
        # Anything no other rule matched is a lexical error.
        raise ValueError("unmatched input: %s" % repr(s))
100
class ASDLParser(spark.GenericParser, object):
    """Parser turning ASDLScanner's token list into a Module tree.

    spark.GenericParser reads the grammar from the docstrings of the p_*
    methods, so each of those docstrings is a grammar rule and must contain
    nothing else; explanations live in '#' comments.  Each p_* method is
    called with the sequence of values matched by its rule's right-hand
    side and returns the value for the left-hand side.  Arguments are
    unpacked in the body rather than via tuple parameters, which Python 3
    removed (PEP 3113).
    """
    def __init__(self):
        super(ASDLParser, self).__init__("module")

    def typestring(self, tok):
        # Grammar terminals are matched against the token's type.
        return tok.type

    def error(self, tok):
        # NOTE(review): spark appears to pass None when the token list is
        # empty; report that as a syntax error instead of failing with
        # AttributeError on None.lineno.
        if tok is None:
            raise ASDLSyntaxError(1, msg="unexpected end of input")
        raise ASDLSyntaxError(tok.lineno, tok)

    def p_module_0(self, args):
        " module ::= Id Id version { } "
        module, name, version, _0, _1 = args
        if module.value != "module":
            raise ASDLSyntaxError(module.lineno,
                                  msg="expected 'module', found %s" % module)
        # Empty body: an empty definition list (not None) so the result
        # can be iterated like any other module.
        return Module(name, [], version)

    def p_module(self, args):
        " module ::= Id Id version { definitions } "
        module, name, version, _0, definitions, _1 = args
        if module.value != "module":
            raise ASDLSyntaxError(module.lineno,
                                  msg="expected 'module', found %s" % module)
        return Module(name, definitions, version)

    def p_version(self, args):
        "version ::= Id String"
        keyword, value = args
        if keyword.value != "version":
            raise ASDLSyntaxError(keyword.lineno,
                                msg="expected 'version', found %s" % keyword)
        # The String token itself becomes Module.version.
        return value

    def p_definition_0(self, args):
        " definitions ::= definition "
        (definition,) = args
        return definition

    def p_definition_1(self, args):
        " definitions ::= definition definitions "
        # Both are lists: p_definition yields a one-element list.
        definition, definitions = args
        return definition + definitions

    def p_definition(self, args):
        " definition ::= Id = type "
        name, _, value = args
        return [Type(name, value)]

    def p_type_0(self, args):
        " type ::= product "
        (product,) = args
        return product

    def p_type_1(self, args):
        " type ::= sum "
        (constructors,) = args
        return Sum(constructors)

    def p_type_2(self, args):
        " type ::= sum Id ( fields ) "
        constructors, keyword, _0, attributes, _1 = args
        if keyword.value != "attributes":
            raise ASDLSyntaxError(keyword.lineno,
                                  msg="expected attributes, found %s" % keyword)
        return Sum(constructors, attributes)

    def p_product(self, args):
        " product ::= ( fields ) "
        _0, fields, _1 = args
        return Product(fields)

    def p_sum_0(self, args):
        " sum ::= constructor "
        (constructor,) = args
        return [constructor]

    def p_sum_1(self, args):
        " sum ::= constructor | sum "
        constructor, _, rest = args
        return [constructor] + rest

    def p_constructor_0(self, args):
        " constructor ::= Id "
        (name,) = args
        return Constructor(name)

    def p_constructor_1(self, args):
        " constructor ::= Id ( fields ) "
        name, _0, fields, _1 = args
        return Constructor(name, fields)

    def p_fields_0(self, args):
        " fields ::= field "
        (field,) = args
        return [field]

    def p_fields_1(self, args):
        " fields ::= field , fields "
        # Build a fresh list already in source order, so no consumer has
        # to reverse() a list in place.
        field, _, rest = args
        return [field] + rest

    def p_field_0(self, args):
        " field ::= Id "
        (ftype,) = args
        return Field(ftype)

    def p_field_1(self, args):
        " field ::= Id Id "
        ftype, name = args
        return Field(ftype, name)

    def p_field_2(self, args):
        " field ::= Id * Id "
        ftype, _, name = args
        return Field(ftype, name, seq=True)

    def p_field_3(self, args):
        " field ::= Id ? Id "
        ftype, _, name = args
        return Field(ftype, name, opt=True)

    def p_field_4(self, args):
        " field ::= Id * "
        ftype, _ = args
        return Field(ftype, seq=True)

    def p_field_5(self, args):
        " field ::= Id ? "
        ftype, _ = args
        return Field(ftype, opt=True)
220
# Field type names that need no definition in the module; check() does not
# report them as undefined.
builtin_types = (
    "identifier",
    "string",
    "int",
    "bool",
    "object",
)
222
223# below is a collection of classes to capture the AST of an AST :-)
224# not sure if any of the methods are useful yet, but I'm adding them
225# piecemeal as they seem helpful
226
class AST(object):
    """Marker base class for every node of the ASDL syntax tree.

    VisitorBase._dispatch asserts that the objects it visits derive from it.
    """
229
class Module(AST):
    """Root node: a named, versioned collection of type definitions.

    name    -- the module's name token
    dfns    -- list of Type definitions (empty when the module has none)
    version -- the version value (a String token when built by ASDLParser)
    types   -- dict mapping each definition's name to its value
    """
    def __init__(self, name, dfns, version):
        # Treat None as "no definitions" so iteration here and in callers
        # (check(), len(mod.dfns), ...) does not fail for an empty module.
        if dfns is None:
            dfns = []
        self.name = name
        self.dfns = dfns
        self.version = version
        # Later definitions with the same name overwrite earlier ones.
        self.types = dict((dfn.name.value, dfn.value) for dfn in dfns)

    def __repr__(self):
        return "Module(%s, %s)" % (self.name, self.dfns)
241
class Type(AST):
    """A named top-level definition.

    name is the defining name token; value is the defined type (a Product
    or Sum when built by ASDLParser).
    """

    def __init__(self, name, value):
        self.name, self.value = name, value

    def __repr__(self):
        parts = (self.name, self.value)
        return "Type(%s, %s)" % parts
249
class Constructor(AST):
    """One alternative of a Sum: a name plus its (possibly empty) fields."""

    def __init__(self, name, fields=None):
        self.name = name
        # Any falsy argument (None, empty list) becomes a fresh empty list.
        if not fields:
            fields = []
        self.fields = fields

    def __repr__(self):
        parts = (self.name, self.fields)
        return "Constructor(%s, %s)" % parts
257
class Field(AST):
    """A field of a Constructor or Product.

    type -- the field's type name token
    name -- optional field name (None when omitted)
    seq  -- True for a sequence field ('*' in the grammar)
    opt  -- True for an optional field ('?' in the grammar)
    """

    def __init__(self, type, name=None, seq=False, opt=False):
        self.type = type
        self.name = name
        self.seq = seq
        self.opt = opt

    def __repr__(self):
        # seq takes precedence over opt when deciding the suffix.
        suffix = ""
        if self.seq:
            suffix = ", seq=True"
        elif self.opt:
            suffix = ", opt=True"

        if self.name is not None:
            head = "%s, %s" % (self.type, self.name)
        else:
            head = "%s" % (self.type,)
        return "Field(" + head + suffix + ")"
276
class Sum(AST):
    """A sum type: a list of alternative Constructors plus shared attributes.

    attributes is always a list (empty when none were given).
    """
    def __init__(self, types, attributes=None):
        self.types = types
        self.attributes = attributes or []

    def __repr__(self):
        # __init__ normalises attributes to a list, so test for emptiness:
        # the old "is None" test could never be true, leaving the short
        # form unreachable.
        if not self.attributes:
            return "Sum(%s)" % (self.types,)
        return "Sum(%s, %s)" % (self.types, self.attributes)
287
class Product(AST):
    """A product type: a single record whose members are ``fields``."""

    def __init__(self, fields):
        self.fields = fields

    def __repr__(self):
        template = "Product(%s)"
        return template % self.fields
294
class VisitorBase(object):
    """Base class for visitors over the AST node classes in this module.

    visit(node, *args) looks up the method named "visit" + the node's class
    name (e.g. visitModule) and calls it with the node and *args.  With
    skip=True, node classes lacking such a method are ignored; otherwise
    the lookup raises AttributeError.  Lookups are cached per class in
    self.cache.  Any exception raised by a visit method is reported and
    the process is terminated with exit status 1.
    """

    def __init__(self, skip=False):
        self.cache = {}
        self.skip = skip

    def visit(self, object, *args):
        meth = self._dispatch(object)
        if meth is None:
            return
        try:
            meth(object, *args)
        except Exception as err:
            print("Error visiting %s" % repr(object))
            print(err)
            traceback.print_exc()
            # XXX hack: os._exit() terminates without flushing Python's
            # stdio buffers, so flush everything written so far (including
            # the diagnostics above) or it may be lost when output is
            # redirected.
            import sys
            sys.stdout.flush()
            sys.stderr.flush()
            if hasattr(self, 'file'):
                self.file.flush()
            os._exit(1)

    def _dispatch(self, object):
        # Return the visit method for object's class, or None when skip is
        # set and no such method exists.
        assert isinstance(object, AST), repr(object)
        klass = object.__class__
        meth = self.cache.get(klass)
        if meth is None:
            methname = "visit" + klass.__name__
            if self.skip:
                meth = getattr(self, methname, None)
            else:
                meth = getattr(self, methname)
            self.cache[klass] = meth
        return meth
328
class Check(VisitorBase):
    """Collect consistency information from a parsed Module.

    After visiting a Module:
      cons   -- maps each constructor name to the name of the type defining it
      types  -- maps each field type name to the list of constructor/type
                names whose fields use it
      errors -- incremented once per constructor redefinition (printed)
    Node classes without a visit method are skipped.
    """

    def __init__(self):
        super(Check, self).__init__(skip=True)
        self.cons = {}
        self.errors = 0
        self.types = {}

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type):
        self.visit(type.value, str(type.name))

    def visitSum(self, sum, name):
        for alternative in sum.types:
            self.visit(alternative, name)

    def visitConstructor(self, cons, name):
        key = str(cons.name)
        previous = self.cons.get(key)
        if previous is not None:
            print("Redefinition of constructor %s" % key)
            print("Defined in %s and %s" % (previous, name))
            self.errors += 1
        else:
            self.cons[key] = name
        for field in cons.fields:
            self.visit(field, key)

    def visitField(self, field, name):
        self.types.setdefault(str(field.type), []).append(name)

    def visitProduct(self, prod, name):
        for field in prod.fields:
            self.visit(field, name)
368
def check(mod):
    """Validate a parsed Module; return True when no errors were found.

    Reports (by printing) constructor redefinitions found by Check and
    field types that are neither defined in the module nor listed in
    builtin_types.
    """
    checker = Check()
    checker.visit(mod)

    undefined = [name for name in checker.types
                 if name not in mod.types and name not in builtin_types]
    for name in undefined:
        users = ", ".join(checker.types[name])
        print("Undefined type %s, used in %s" % (name, users))

    return checker.errors + len(undefined) == 0
380
def parse(file):
    """Parse the ASDL file at path *file* and return its Module.

    On an ASDLSyntaxError the error and the offending source line are
    printed and None is returned.  Other exceptions (e.g. the ValueError
    ASDLScanner raises for unmatched input) propagate to the caller.
    """
    scanner = ASDLScanner()
    parser = ASDLParser()

    # Close the file deterministically instead of leaking the handle.
    with open(file) as f:
        buf = f.read()
    tokens = scanner.tokenize(buf)
    try:
        return parser.parse(tokens)
    except ASDLSyntaxError as err:
        print(err)
        lines = buf.split("\n")
        # lines starts at 0, files at 1; skip the echo if the reported
        # line number does not name an actual source line.
        if 1 <= err.lineno <= len(lines):
            print(lines[err.lineno - 1])
        return None
393
if __name__ == "__main__":
    # Parse and check each ASDL file named on the command line, or every
    # tests/*.asdl file when no arguments are given.
    import glob
    import sys

    if len(sys.argv) > 1:
        files = sys.argv[1:]
    else:
        testdir = "tests"
        files = glob.glob(os.path.join(testdir, "*.asdl"))

    for path in files:
        print(path)
        mod = parse(path)
        if mod is None:
            # parse() has already reported the syntax error.
            continue
        print("module %s" % mod.name)
        print("%d definitions" % len(mod.dfns))
        if not check(mod):
            print("Check failed")
        else:
            for dfn in mod.dfns:
                # Type nodes carry .name and .value; there is no .type.
                print(dfn.name)
414