1# This Source Code Form is subject to the terms of the Mozilla Public
2# License, v. 2.0. If a copy of the MPL was not distributed with this
3# file, You can obtain one at http://mozilla.org/MPL/2.0/.
4
5from .util import hash_str
6
7
# Nesting levels.  MessageDecl.nested() and Protocol.nestedUpTo() translate
# the `Nested` / `NestedUpTo` attribute values into one of these through
# NESTED_ATTR_MAP below.
NOT_NESTED, INSIDE_SYNC_NESTED, INSIDE_CPOW_NESTED = 1, 2, 3

# Message priorities.  MessageDecl.priority() translates the `Priority`
# attribute value into one of these through PRIORITY_ATTR_MAP below.
(
    NORMAL_PRIORITY,
    INPUT_PRIORITY,
    VSYNC_PRIORITY,
    MEDIUMHIGH_PRIORITY,
    CONTROL_PRIORITY,
) = (1, 2, 3, 4, 5)

# Attribute spelling -> nesting level.
NESTED_ATTR_MAP = {
    "not": NOT_NESTED,
    "inside_sync": INSIDE_SYNC_NESTED,
    "inside_cpow": INSIDE_CPOW_NESTED,
}

# Attribute spelling -> priority level.
PRIORITY_ATTR_MAP = {
    "normal": NORMAL_PRIORITY,
    "input": INPUT_PRIORITY,
    "vsync": VSYNC_PRIORITY,
    "mediumhigh": MEDIUMHIGH_PRIORITY,
    "control": CONTROL_PRIORITY,
}
31
32
class Visitor:
    """Base AST visitor that walks child nodes in declaration order.

    ``Node.accept`` dispatches a node to ``visit<ClassName>`` on the visitor,
    falling back to ``defaultVisit`` when no such method exists.  Subclasses
    override the hooks they care about; the defaults here just recurse.
    """

    def _acceptAll(self, children):
        # Dispatch each child, in iteration order, back to this visitor.
        for child in children:
            child.accept(self)

    def defaultVisit(self, node):
        # Node.accept lands here when the visitor has no visit<ClassName>.
        typeName = node.__class__.__name__
        raise Exception("INTERNAL ERROR: no visitor for node type `%s'" % (typeName))

    def visitTranslationUnit(self, tu):
        # C++ includes, IPDL includes, structs/unions, builtinUsing, using,
        # and finally the protocol when one is set.
        self._acceptAll(tu.cxxIncludes)
        self._acceptAll(tu.includes)
        self._acceptAll(tu.structsAndUnions)
        self._acceptAll(tu.builtinUsing)
        self._acceptAll(tu.using)
        if tu.protocol:
            tu.protocol.accept(self)

    def visitCxxInclude(self, inc):
        """No children to visit."""

    def visitInclude(self, inc):
        """Deliberately a no-op.

        The included file's AST is not visited here, because that needs
        delicate and pass-specific handling.
        """

    def visitStructDecl(self, struct):
        self._acceptAll(struct.fields)
        self._acceptAll(struct.attributes.values())

    def visitStructField(self, field):
        field.typespec.accept(self)

    def visitUnionDecl(self, union):
        self._acceptAll(union.components)
        self._acceptAll(union.attributes.values())

    def visitUsingStmt(self, using):
        # Only the attributes are visited; the C++ type spec is not.
        self._acceptAll(using.attributes.values())

    def visitProtocol(self, p):
        self._acceptAll(p.namespaces)
        self._acceptAll(p.managers)
        self._acceptAll(p.managesStmts)
        self._acceptAll(p.messageDecls)
        self._acceptAll(p.attributes.values())

    def visitNamespace(self, ns):
        """No children to visit."""

    def visitManager(self, mgr):
        """No children to visit."""

    def visitManagesStmt(self, mgs):
        """No children to visit."""

    def visitMessageDecl(self, md):
        self._acceptAll(md.inParams)
        self._acceptAll(md.outParams)
        self._acceptAll(md.attributes.values())

    def visitParam(self, decl):
        # NOTE: only the attributes are visited; unlike visitStructField,
        # the parameter's typespec is not.
        self._acceptAll(decl.attributes.values())

    def visitTypeSpec(self, ts):
        """No children to visit."""

    def visitAttribute(self, a):
        # An attribute value may be plain data (None or a str) or possibly an
        # AST node; only node values are visited.
        value = a.value
        if isinstance(value, Node):
            value.accept(self)

    def visitStringLiteral(self, sl):
        """No children to visit."""

    def visitDecl(self, d):
        self._acceptAll(d.attributes.values())
126
127
class Loc:
    """A source position: a file name and a line number."""

    def __init__(self, filename="<??>", lineno=0):
        # The filename must be non-empty; "<??>" is the "unknown" placeholder.
        assert filename
        self.filename, self.lineno = filename, lineno

    def __repr__(self):
        return "{!r}:{!r}".format(self.filename, self.lineno)

    def __str__(self):
        return "{!s}:{!s}".format(self.filename, self.lineno)


# Shared placeholder location; Node uses it as its default `loc`.
Loc.NONE = Loc(filename="<??>", lineno=0)
142
143
144class _struct:
145    pass
146
147
class Node:
    """Base class for the AST nodes in this module; records a source location."""

    def __init__(self, loc=Loc.NONE):
        self.loc = loc

    def accept(self, visitor):
        """Dispatch to ``visitor.visit<ClassName>(self)`` and return its result.

        Falls back to ``visitor.defaultVisit(self)`` when the visitor has no
        method named after this node's class.
        """
        methodName = "visit%s" % self.__class__.__name__
        visit = getattr(visitor, methodName, None)
        if visit is not None:
            return visit(self)
        return visitor.defaultVisit(self)

    def addAttrs(self, attrsName):
        """Make sure ``self.<attrsName>`` exists, creating an empty _struct.

        An attribute that already exists is left untouched.
        """
        if hasattr(self, attrsName):
            return
        setattr(self, attrsName, _struct())
161
162
class NamespacedNode(Node):
    """A named node with an ordered list of enclosing Namespace nodes.

    ``namespaces`` is kept outermost-first: each call to addOuterNamespace
    puts the new namespace in front of the existing ones.
    """

    def __init__(self, loc=Loc.NONE, name=None):
        Node.__init__(self, loc)
        self.name = name
        self.namespaces = []

    def addOuterNamespace(self, namespace):
        # A newly added namespace encloses everything already recorded.
        self.namespaces[:0] = [namespace]

    def qname(self):
        """Return a QualifiedId built from the namespace names and this name."""
        enclosing = [ns.name for ns in self.namespaces]
        return QualifiedId(self.loc, self.name, enclosing)
174
175
class TranslationUnit(NamespacedNode):
    """Top-level node for one file.

    It holds the file's C++ and IPDL includes, struct/union declarations,
    builtin and declared ``using`` statements, and at most one protocol.
    The ``type`` argument is stored as ``filetype``; ``filename`` starts
    out as None.
    """

    def __init__(self, type, name):
        NamespacedNode.__init__(self, name=name)
        self.filetype = type
        self.filename = None
        self.cxxIncludes = []
        self.includes = []
        self.builtinUsing = []
        self.using = []
        self.structsAndUnions = []
        self.protocol = None

    def addCxxInclude(self, cxxInclude):
        self.cxxIncludes.append(cxxInclude)

    def addInclude(self, inc):
        self.includes.append(inc)

    def addStructDecl(self, struct):
        # Structs and unions share one list, preserving declaration order.
        self.structsAndUnions.append(struct)

    def addUnionDecl(self, union):
        self.structsAndUnions.append(union)

    def addUsingStmt(self, using):
        self.using.append(using)

    def setProtocol(self, protocol):
        self.protocol = protocol
205
206
class CxxInclude(Node):
    """An include of a C++ file; ``file`` holds ``cxxFile`` unchanged."""

    def __init__(self, loc, cxxFile):
        Node.__init__(self, loc)
        self.file = cxxFile
211
212
class Include(Node):
    """An include of another IPDL file.

    ``file`` becomes ``<name>.ipdlh`` when ``type == "header"`` and
    ``<name>.ipdl`` otherwise.
    """

    def __init__(self, loc, type, name):
        Node.__init__(self, loc)
        extension = "ipdlh" if type == "header" else "ipdl"
        self.file = "%s.%s" % (name, extension)
220
221
class UsingStmt(Node):
    """A ``using`` statement naming a C++ type for use from IPDL.

    Args:
        loc: source location.
        cxxTypeSpec: the type being named (must not be a plain str).
        cxxHeader: optional header name, as a str.
        kind: None, "class" or "struct"; only the latter two allow a
            forward declaration (see canBeForwardDeclared).
        attributes: mapping of attribute name -> Attribute.  Defaults to a
            fresh empty dict per statement.
    """

    def __init__(
        self,
        loc,
        cxxTypeSpec,
        cxxHeader=None,
        kind=None,
        attributes=None,
    ):
        Node.__init__(self, loc)
        assert not isinstance(cxxTypeSpec, str)
        assert cxxHeader is None or isinstance(cxxHeader, str)
        assert kind is None or kind == "class" or kind == "struct"
        self.type = cxxTypeSpec
        self.header = cxxHeader
        self.kind = kind
        # A fresh dict per instance: a shared `{}` default would leak any
        # attribute added to one statement into every other default-built one.
        self.attributes = {} if attributes is None else attributes

    def canBeForwardDeclared(self):
        return self.isClass() or self.isStruct()

    def isClass(self):
        return self.kind == "class"

    def isStruct(self):
        return self.kind == "struct"

    def isRefcounted(self):
        return "RefCounted" in self.attributes

    def _moveOnlyApplies(self, mode):
        # A bare [MoveOnly] (value None) applies to every mode; otherwise the
        # attribute's value must name `mode` exactly.  Always returns a bool.
        moveonly = self.attributes.get("MoveOnly")
        return moveonly is not None and moveonly.value in (None, mode)

    def isSendMoveOnly(self):
        """True if the type carries a [MoveOnly] or [MoveOnly=send] attribute."""
        return self._moveOnlyApplies("send")

    def isDataMoveOnly(self):
        """True if the type carries a [MoveOnly] or [MoveOnly=data] attribute."""
        return self._moveOnlyApplies("data")
259
260
261# "singletons"
262
263
class PrettyPrinted:
    """Mixin that gives a class a printable name via its ``pretty`` attribute.

    Subclasses set ``pretty`` to a string.  Both hooks are classmethods, so
    they read ``cls.pretty`` and can be called on the class itself, e.g.
    ``ASYNC.__str__()``.

    NOTE(review): Python looks special methods up on the *type*.  On a bare
    class object, ``str(ASYNC)`` / ``hash(ASYNC)`` therefore use ``type``'s
    implementations, not these; these apply to instances or explicit calls.
    """

    @classmethod
    def __hash__(cls):
        # Hash of the pretty name, via the project helper hash_str.
        return hash_str(cls.pretty)

    @classmethod
    def __str__(cls):
        return cls.pretty
272
273
class ASYNC(PrettyPrinted):
    # Used as a value: the default sendSemantics of Protocol and MessageDecl.
    pretty = "async"


class INTR(PrettyPrinted):
    # NOTE(review): presumably an alternative sendSemantics value — confirm.
    pretty = "intr"


class SYNC(PrettyPrinted):
    # NOTE(review): presumably an alternative sendSemantics value — confirm.
    pretty = "sync"


class INOUT(PrettyPrinted):
    # NOTE(review): presumably a MessageDecl.direction value — confirm.
    pretty = "inout"


class IN(PrettyPrinted):
    # NOTE(review): presumably a MessageDecl.direction value — confirm.
    pretty = "in"


class OUT(PrettyPrinted):
    # NOTE(review): presumably a MessageDecl.direction value — confirm.
    pretty = "out"
296
297
class Namespace(Node):
    """One namespace component; the identifier passed in is kept as ``name``."""

    def __init__(self, loc, namespace):
        Node.__init__(self, loc)
        self.name = namespace
302
303
class Protocol(NamespacedNode):
    """A protocol: its managers, managed protocols, messages and attributes."""

    def __init__(self, loc):
        NamespacedNode.__init__(self, loc)
        self.attributes = {}
        self.sendSemantics = ASYNC
        self.managers = []
        self.managesStmts = []
        self.messageDecls = []

    def nestedUpTo(self):
        """Return the nesting level named by the ``NestedUpTo`` attribute.

        NOT_NESTED when the attribute is missing or its value is not a key of
        NESTED_ATTR_MAP.
        """
        try:
            attr = self.attributes["NestedUpTo"]
        except KeyError:
            return NOT_NESTED
        return NESTED_ATTR_MAP.get(attr.value, NOT_NESTED)

    def implAttribute(self, side):
        """Return the ``ParentImpl`` / ``ChildImpl`` attribute value, or None.

        ``side`` must be "parent" or "child".
        """
        assert side in ("parent", "child")
        attr = self.attributes.get(side.capitalize() + "Impl")
        return None if attr is None else attr.value
325
326
class StructField(Node):
    """One struct field: its type spec (``typespec``) and its ``name``."""

    def __init__(self, loc, type, name):
        Node.__init__(self, loc)
        self.typespec, self.name = type, name
332
333
class StructDecl(NamespacedNode):
    """A struct declaration: its fields, in declaration order, and attributes."""

    def __init__(self, loc, name, fields, attributes):
        NamespacedNode.__init__(self, loc, name)
        self.fields, self.attributes = fields, attributes
        # Indices into `fields` giving the order fields are laid out in
        # memory.  `fields` itself is not reordered, so that e.g. C++
        # constructors stay reasonably stable when new fields are added.
        self.packed_field_ordering = []
344
345
class UnionDecl(NamespacedNode):
    """A union declaration: its component types and its attributes."""

    def __init__(self, loc, name, components, attributes):
        NamespacedNode.__init__(self, loc, name)
        self.components, self.attributes = components, attributes
351
352
class Manager(Node):
    """Names a managing protocol; ``managerName`` is stored as ``name``."""

    def __init__(self, loc, managerName):
        Node.__init__(self, loc)
        self.name = managerName
357
358
class ManagesStmt(Node):
    """Names a managed protocol; ``managedName`` is stored as ``name``."""

    def __init__(self, loc, managedName):
        Node.__init__(self, loc)
        self.name = managedName
363
364
class MessageDecl(Node):
    """A protocol message: name, send semantics, direction, params, attributes."""

    def __init__(self, loc):
        Node.__init__(self, loc)
        self.name = None
        self.attributes = {}
        self.sendSemantics = ASYNC
        self.direction = None
        self.inParams = []
        self.outParams = []

    def addInParams(self, inParamsList):
        """Append every entry of ``inParamsList`` to ``inParams``."""
        self.inParams.extend(inParamsList)

    def addOutParams(self, outParamsList):
        """Append every entry of ``outParamsList`` to ``outParams``."""
        self.outParams.extend(outParamsList)

    def nested(self):
        """Return the nesting level named by the ``Nested`` attribute.

        NOT_NESTED when the attribute is missing or its value is not a key of
        NESTED_ATTR_MAP.
        """
        try:
            attr = self.attributes["Nested"]
        except KeyError:
            return NOT_NESTED
        return NESTED_ATTR_MAP.get(attr.value, NOT_NESTED)

    def priority(self):
        """Return the priority named by the ``Priority`` attribute.

        NORMAL_PRIORITY when the attribute is missing or its value is not a
        key of PRIORITY_ATTR_MAP.
        """
        try:
            attr = self.attributes["Priority"]
        except KeyError:
            return NORMAL_PRIORITY
        return PRIORITY_ATTR_MAP.get(attr.value, NORMAL_PRIORITY)
392
393
class Param(Node):
    """A message parameter: a type spec, a name and optional attributes.

    ``attributes`` maps attribute name -> Attribute.  It defaults to a fresh
    empty dict per parameter.
    """

    def __init__(self, loc, typespec, name, attributes=None):
        Node.__init__(self, loc)
        self.name = name
        self.typespec = typespec
        # Avoid a shared mutable `{}` default: it would be aliased by every
        # Param constructed without explicit attributes.
        self.attributes = {} if attributes is None else attributes
400
401
class TypeSpec(Node):
    """A reference to a type by QualifiedId, plus modifier flags.

    The ``array``, ``maybe``, ``nullable`` and ``uniqueptr`` flags all start
    out False.
    """

    def __init__(self, loc, spec):
        Node.__init__(self, loc)
        self.spec = spec  # QualifiedId
        self.array = self.maybe = self.nullable = self.uniqueptr = False

    def basename(self):
        """Return the last component of the spec (its ``baseid``)."""
        return self.spec.baseid

    def __str__(self):
        return format(self.spec)
416
417
class Attribute(Node):
    """A named attribute; ``value`` is stored exactly as given (may be None)."""

    def __init__(self, loc, name, value):
        Node.__init__(self, loc)
        self.name, self.value = name, value
423
424
class StringLiteral(Node):
    """A string literal node; str() renders the value wrapped in double quotes."""

    def __init__(self, loc, value):
        Node.__init__(self, loc)
        self.value = value

    def __str__(self):
        quoted = '"%s"' % self.value
        return quoted
432
433
class QualifiedId:  # FIXME inherit from node?
    """A possibly qualified identifier such as ``a::b::C``.

    Attributes:
        loc: source location.
        baseid: the final (innermost) component, a str.
        quals: the preceding qualifier strings, outermost first.
    """

    def __init__(self, loc, baseid, quals=None):
        assert isinstance(baseid, str)
        # Take a private copy.  qualify() appends in place, so sharing a
        # mutable default (or the caller's list) would leak qualifiers into
        # unrelated ids.
        quals = [] if quals is None else list(quals)
        for qual in quals:
            assert isinstance(qual, str)

        self.loc = loc
        self.baseid = baseid
        self.quals = quals

    def qualify(self, id):
        """Append ``id`` as the new final component.

        The previous ``baseid`` becomes the innermost qualifier.
        """
        self.quals.append(self.baseid)
        self.baseid = id

    def __str__(self):
        return "::".join([*self.quals, self.baseid])
452
453
454# added by type checking passes
455
456
class Decl(Node):
    """A declaration record, as added by the type-checking passes.

    The naming, type and scope fields all start out as None, and
    ``attributes`` as an empty dict.
    """

    def __init__(self, loc):
        # Node.__init__ already stores `loc`; no need to assign it again.
        Node.__init__(self, loc)
        self.progname = None  # what the programmer typed, if relevant
        self.shortname = None  # shortest way to refer to this decl
        self.fullname = None  # full way to refer to this decl
        self.type = None
        self.scope = None
        self.attributes = {}
467