1# This Source Code Form is subject to the terms of the Mozilla Public
2# License, v. 2.0. If a copy of the MPL was not distributed with this
3# file, You can obtain one at http://mozilla.org/MPL/2.0/.
4
5import re
6from copy import deepcopy
7from collections import OrderedDict
8import itertools
9
10import ipdl.ast
11import ipdl.builtin
12from ipdl.cxx.ast import *
13from ipdl.cxx.code import *
14from ipdl.direct_call import VIRTUAL_CALL_CLASSES, DIRECT_CALL_OVERRIDES
15from ipdl.type import ActorType, UnionType, TypeVisitor, builtinHeaderIncludes
16from ipdl.util import hash_str
17
18
19# -----------------------------------------------------------------------------
20# "Public" interface to lowering
21##
22
23
class LowerToCxx:
    def lower(self, tu, segmentcapacitydict):
        """Lower the translation unit |tu| to C++.

        Returns |([header: File], [cpp: File])|.  Both lists start with the
        file pair named after |tu|; when |tu| declares a protocol, the
        Parent and then the Child actor file pairs follow.
        """
        # Attach the IPDL/C++ IR-type annotations that the generators use.
        tu.accept(_DecorateWithCxxStuff())

        # NB: the ipdl.py driver script mirrors this filename scheme, so any
        # change here must be made there as well.
        basename = tu.name
        mainheader = File(basename + ".h")
        maincpp = File(basename + ".cpp")
        _GenerateProtocolCode().lower(tu, mainheader, maincpp, segmentcapacitydict)

        headers = [mainheader]
        cpps = [maincpp]
        if not tu.protocol:
            return headers, cpps

        protoname = tu.protocol.name
        for sidename, generator in (
            ("Parent", _GenerateProtocolParentCode),
            ("Child", _GenerateProtocolChildCode),
        ):
            actorname = protoname + sidename
            sideheader = File(actorname + ".h")
            sidecpp = File(actorname + ".cpp")
            generator().lower(tu, actorname, sideheader, sidecpp)
            headers.append(sideheader)
            cpps.append(sidecpp)

        return headers, cpps
60
61
62# -----------------------------------------------------------------------------
63# Helper code
64##
65
66
def hashfunc(value):
    """Return |hash_str(value)| reduced to an unsigned 32-bit value.

    Python's ``%`` with a positive right operand always yields a result in
    ``[0, 2**32)``, so (unlike C) no sign fix-up is needed after the
    reduction.
    """
    return hash_str(value) % 2 ** 32
72
73
# Literal 0; by its name, the actor id representing a null actor.
_NULL_ACTOR_ID = ExprLiteral.ZERO
# Literal 1; by its name, the actor id marking a freed actor (confirm
# against the code that reads it).
_FREED_ACTOR_ID = ExprLiteral.ONE

# C++ comment banner ("Automatically generated by ipdlc") wrapped as a
# Whitespace node, presumably prepended to generated files.
_DISCLAIMER = Whitespace(
    """//
// Automatically generated by ipdlc.
// Edit at your own risk
//

"""
)
85
86
87class _struct:
88    pass
89
90
91def _namespacedHeaderName(name, namespaces):
92    pfx = "/".join([ns.name for ns in namespaces])
93    if pfx:
94        return pfx + "/" + name
95    else:
96        return name
97
98
def _ipdlhHeaderName(tu):
    """Namespaced header path for a translation unit whose filetype is "header"."""
    assert tu.filetype == "header"
    return _namespacedHeaderName(name=tu.name, namespaces=tu.namespaces)
102
103
def _protocolHeaderName(p, side=""):
    """Namespaced header path for protocol |p|, with |side| title-cased and
    appended to the protocol name when given."""
    suffix = side.title() if side else ""
    return _namespacedHeaderName(p.name + suffix, p.namespaces)
109
110
111def _includeGuardMacroName(headerfile):
112    return re.sub(r"[./]", "_", headerfile.name)
113
114
def _includeGuardStart(headerfile):
    """The `#ifndef GUARD` and `#define GUARD` directives that open |headerfile|."""
    guard = _includeGuardMacroName(headerfile)
    return [CppDirective(directive, guard) for directive in ("ifndef", "define")]
118
119
def _includeGuardEnd(headerfile):
    """The `#endif // ifndef GUARD` directive that closes |headerfile|."""
    trailer = "// ifndef " + _includeGuardMacroName(headerfile)
    return [CppDirective("endif", trailer)]
123
124
125def _messageStartName(ptype):
126    return ptype.name() + "MsgStart"
127
128
def _protocolId(ptype):
    """Expression naming |ptype|'s message-start constant."""
    msgstart = _messageStartName(ptype)
    return ExprVar(msgstart)
131
132
def _protocolIdType():
    """C++ type used for protocol ids (32-bit signed int)."""
    int32 = Type.INT32
    return int32
135
136
137def _actorName(pname, side):
138    """|pname| is the protocol name. |side| is 'Parent' or 'Child'."""
139    tag = side
140    if not tag[0].isupper():
141        tag = side.title()
142    return pname + tag
143
144
def _actorIdType():
    """C++ type used for actor ids (32-bit signed int)."""
    int32 = Type.INT32
    return int32
147
148
def _actorTypeTagType():
    """C++ type used for actor type tags (32-bit signed int)."""
    int32 = Type.INT32
    return int32
151
152
def _actorId(actor=None):
    """Call expression for `Id()`: `actor->Id()` when |actor| is given,
    otherwise an unqualified `Id()`."""
    callee = ExprVar("Id") if actor is None else ExprSelect(actor, "->", "Id")
    return ExprCall(callee)
157
158
def _actorHId(actorhandle):
    """Expression `actorhandle.mId`."""
    field = "mId"
    return ExprSelect(actorhandle, ".", field)
161
162
def _backstagePass():
    """Expression constructing a `mozilla::ipc::PrivateIPDLInterface()` token."""
    ctor = ExprVar("mozilla::ipc::PrivateIPDLInterface")
    return ExprCall(ctor)
165
166
def _iterType(ptr):
    """C++ `PickleIterator` type, as a pointer when |ptr| is true."""
    iterator = Type("PickleIterator", ptr=ptr)
    return iterator
169
170
def _deleteId():
    """Expression naming the __delete__ message id."""
    name = "Msg___delete____ID"
    return ExprVar(name)
173
174
def _deleteReplyId():
    """Expression naming the __delete__ reply id."""
    name = "Reply___delete____ID"
    return ExprVar(name)
177
178
def _lookupListener(idexpr):
    """Call expression `Lookup(idexpr)`."""
    callee = ExprVar("Lookup")
    return ExprCall(callee, args=[idexpr])
181
182
def _makeForwardDeclForQClass(clsname, quals, cls=True, struct=False):
    """Forward-declare |clsname| inside the namespaces named by |quals|.

    |quals| is ordered outermost first.  Returns the bare ForwardDecl when
    |quals| is empty, otherwise the outermost Namespace, which contains the
    nested chain ending in the declaration.
    """
    decl = ForwardDecl(clsname, cls=cls, struct=struct)
    if not quals:
        return decl

    chain = [Namespace(q) for q in quals]
    for outer, inner in zip(chain, chain[1:]):
        outer.addstmt(inner)
    chain[-1].addstmt(decl)
    return chain[0]
197
198
def _makeForwardDeclForActor(ptype, side):
    """Namespaced forward declaration of |ptype|'s actor class on |side|."""
    qname = ptype.qname
    actorclass = _actorName(qname.baseid, side)
    return _makeForwardDeclForQClass(actorclass, qname.quals)
203
204
def _makeForwardDecl(type):
    """Namespaced forward declaration of the class for IPDL |type|."""
    classname = type.name()
    return _makeForwardDeclForQClass(classname, type.qname.quals)
207
208
def _putInNamespaces(cxxthing, namespaces):
    """Wrap |cxxthing| in nested C++ namespaces.

    |namespaces| is in order [ outer, ..., inner ].  Returns |cxxthing|
    itself when the list is empty, otherwise the outermost Namespace.
    """
    if not namespaces:
        return cxxthing

    wrappers = [Namespace(ns.name) for ns in namespaces]
    for parent, child in zip(wrappers, wrappers[1:]):
        parent.addstmt(child)
    wrappers[-1].addstmt(cxxthing)
    return wrappers[0]
222
223
224def _sendPrefix(msgtype):
225    """Prefix of the name of the C++ method that sends |msgtype|."""
226    if msgtype.isInterrupt():
227        return "Call"
228    return "Send"
229
230
231def _recvPrefix(msgtype):
232    """Prefix of the name of the C++ method that handles |msgtype|."""
233    if msgtype.isInterrupt():
234        return "Answer"
235    return "Recv"
236
237
238def _flatTypeName(ipdltype):
239    """Return a 'flattened' IPDL type name that can be used as an
240    identifier.
241    E.g., |Foo[]| --> |ArrayOfFoo|."""
242    # NB: this logic depends heavily on what IPDL types are allowed to
243    # be constructed; e.g., Foo[][] is disallowed.  needs to be kept in
244    # sync with grammar.
245    if ipdltype.isIPDL() and ipdltype.isArray():
246        return "ArrayOf" + ipdltype.basetype.name()
247    if ipdltype.isIPDL() and ipdltype.isMaybe():
248        return "Maybe" + ipdltype.basetype.name()
249    return ipdltype.name()
250
251
252def _hasVisibleActor(ipdltype):
253    """Return true iff a C++ decl of |ipdltype| would have an Actor* type.
254    For example: |Actor[]| would turn into |Array<ActorParent*>|, so this
255    function would return true for |Actor[]|."""
256    return ipdltype.isIPDL() and (
257        ipdltype.isActor()
258        or (ipdltype.hasBaseType() and _hasVisibleActor(ipdltype.basetype))
259    )
260
261
def _abortIfFalse(cond, msg):
    """Statement `MOZ_RELEASE_ASSERT(cond, "msg");`."""
    assertion = ExprCall(ExprVar("MOZ_RELEASE_ASSERT"), [cond, ExprLiteral.String(msg)])
    return StmtExpr(assertion)
266
267
def _refptr(T):
    """C++ type `RefPtr<T>`."""
    wrapped = Type("RefPtr", T=T)
    return wrapped
270
271
def _uniqueptr(T):
    """C++ type `UniquePtr<T>`."""
    wrapped = Type("UniquePtr", T=T)
    return wrapped
274
275
def _alreadyaddrefed(T):
    """C++ type `already_AddRefed<T>`."""
    wrapped = Type("already_AddRefed", T=T)
    return wrapped
278
279
def _tuple(types, const=False, ref=False):
    """C++ type `Tuple<types...>`, optionally const and/or a reference."""
    tupletype = Type("Tuple", T=types, const=const, ref=ref)
    return tupletype
282
283
def _promise(resolvetype, rejecttype, tail, resolver=False):
    """C++ type `MozPromise<resolvetype, rejecttype, tail>`, or its nested
    `::Private` type when |resolver| is true."""
    params = [resolvetype, rejecttype, tail]
    if resolver:
        return Type("MozPromise", T=params, inner=Type("Private"))
    return Type("MozPromise", T=params, inner=None)
287
288
def _makePromise(returns, side, resolver=False):
    """MozPromise type for a message whose return decls are |returns|.

    Several returns resolve as a Tuple of their bare types; a single return
    resolves as its own bare type.  Rejections use ResponseRejectReason.
    """
    if len(returns) <= 1:
        resolvetype = returns[0].bareType(side)
    else:
        resolvetype = _tuple([decl.bareType(side) for decl in returns])

    # MozPromise is purposefully made to be exclusive only (tail = true).
    # Really, we mean it.
    rejecttype = _ResponseRejectReason.Type()
    return _promise(resolvetype, rejecttype, ExprLiteral.TRUE, resolver=resolver)
299
300
def _resolveType(returns, side):
    """Move-type of the resolve value: a Tuple for several returns, else the
    single return's move-type."""
    if len(returns) <= 1:
        return returns[0].moveType(side)
    return _tuple([decl.moveType(side) for decl in returns])
305
306
def _makeResolver(returns, side):
    """Function type taking the resolve value (see _resolveType) as one
    unnamed parameter."""
    param = Decl(_resolveType(returns, side), "")
    return TypeFunction([param])
309
310
def _cxxArrayType(basetype, const=False, ref=False):
    """C++ type `nsTArray<basetype>`, flagged as having no implicit copy ctor."""
    arraytype = Type(
        "nsTArray", T=basetype, const=const, ref=ref, hasimplicitcopyctor=False
    )
    return arraytype
313
314
def _cxxMaybeType(basetype, const=False, ref=False):
    """C++ type `mozilla::Maybe<basetype>`; implicitly copyable exactly when
    |basetype| is."""
    copyable = basetype.hasimplicitcopyctor
    return Type(
        "mozilla::Maybe", T=basetype, const=const, ref=ref, hasimplicitcopyctor=copyable
    )
323
324
def _cxxManagedContainerType(basetype, const=False, ref=False):
    """C++ type `ManagedContainer<basetype>`, flagged as having no implicit
    copy ctor."""
    container = Type(
        "ManagedContainer", T=basetype, const=const, ref=ref, hasimplicitcopyctor=False
    )
    return container
329
330
def _cxxLifecycleProxyType(ptr=False):
    """C++ type `mozilla::ipc::ActorLifecycleProxy`, optionally as a pointer."""
    proxy = Type("mozilla::ipc::ActorLifecycleProxy", ptr=ptr)
    return proxy
333
334
335def _otherSide(side):
336    if side == "child":
337        return "parent"
338    if side == "parent":
339        return "child"
340    assert 0
341
342
def _ifLogging(topLevelProtocol, stmts):
    """Wrap |stmts| in `if (mozilla::ipc::LoggingEnabledFor(proto)) { ... }`."""
    template = """
        if (mozilla::ipc::LoggingEnabledFor(${proto})) {
            $*{stmts}
        }
        """
    return StmtCode(template, proto=topLevelProtocol, stmts=stmts)
353
354
355# XXX we need to remove these and install proper error handling
356
357
def _printErrorMessage(msg):
    """Statement `NS_ERROR(msg);`.  A Python str |msg| becomes a C++ string
    literal; any other value is passed through as an expression."""
    text = ExprLiteral.String(msg) if isinstance(msg, str) else msg
    return StmtExpr(ExprCall(ExprVar("NS_ERROR"), args=[text]))
362
363
def _protocolErrorBreakpoint(msg):
    """Statement `mozilla::ipc::ProtocolErrorBreakpoint(msg);`.  A Python
    str |msg| becomes a C++ string literal; any other value is passed
    through as an expression."""
    text = ExprLiteral.String(msg) if isinstance(msg, str) else msg
    callee = ExprVar("mozilla::ipc::ProtocolErrorBreakpoint")
    return StmtExpr(ExprCall(callee, args=[text]))
370
371
def _printWarningMessage(msg):
    """Statement `NS_WARNING(msg);`.  A Python str |msg| becomes a C++
    string literal; any other value is passed through as an expression."""
    text = ExprLiteral.String(msg) if isinstance(msg, str) else msg
    return StmtExpr(ExprCall(ExprVar("NS_WARNING"), args=[text]))
376
377
def _fatalError(msg):
    """Statement `FatalError("msg");`."""
    call = ExprCall(ExprVar("FatalError"), args=[ExprLiteral.String(msg)])
    return StmtExpr(call)
380
381
def _logicError(msg):
    """Statement `mozilla::ipc::LogicError("msg");`."""
    callee = ExprVar("mozilla::ipc::LogicError")
    return StmtExpr(ExprCall(callee, args=[ExprLiteral.String(msg)]))
386
387
def _sentinelReadError(classname):
    """Statement `mozilla::ipc::SentinelReadError("classname");`."""
    callee = ExprVar("mozilla::ipc::SentinelReadError")
    literal = ExprLiteral.String(classname)
    return StmtExpr(ExprCall(callee, args=[literal]))
395
396
# Result codes that IPDL-generated code returns to the *Channel code.
# Users never see these.
399
400
class _Result:
    """Result codes handed back from generated code to the *Channel code.

    Each attribute is a C++ expression naming a value of the C++ |Result|
    type (presumably its enumerators).
    """

    @staticmethod
    def Type():
        """The C++ |Result| type."""
        return Type("Result")

    Processed = ExprVar("MsgProcessed")
    NotKnown = ExprVar("MsgNotKnown")
    NotAllowed = ExprVar("MsgNotAllowed")
    PayloadError = ExprVar("MsgPayloadError")
    ProcessingError = ExprVar("MsgProcessingError")
    RouteError = ExprVar("MsgRouteError")
    # The Python attribute is spelled "ValuError" although the C++ name is
    # MsgValueError; existing references (e.g. errfnRecv's default) use this
    # spelling, so it stays.
    ValuError = ExprVar("MsgValueError")  # [sic]
413
414
# These |errfn*| functions generate the code to be executed on an
# error, such as "bad actor ID".  Each is given a Python string
# containing a description of the error.
418
419# used in user-facing Send*() methods
420
421
def errfnSend(msg, errcode=ExprLiteral.FALSE):
    """Error statements for a Send*() method: call FatalError(msg), then
    return |errcode| (false by default)."""
    report = _fatalError(msg)
    return [report, StmtReturn(errcode)]
424
425
def errfnSendCtor(msg):
    """Like errfnSend, but the failing constructor send returns null."""
    nullresult = ExprLiteral.NULL
    return errfnSend(msg, errcode=nullresult)
428
429
430# TODO should this error handling be strengthened for dtors?
431
432
def errfnSendDtor(msg):
    """Error statements for a destructor send: NS_ERROR(msg), then return
    false (no FatalError, unlike errfnSend)."""
    report = _printErrorMessage(msg)
    return [report, StmtReturn.FALSE]
435
436
437# used in |OnMessage*()| handlers that hand in-messages off to Recv*()
438# interface methods
439
440
def errfnRecv(msg, errcode=_Result.ValuError):
    """Error statements for an OnMessage*() handler: call FatalError(msg),
    then return |errcode| (MsgValueError by default)."""
    report = _fatalError(msg)
    return [report, StmtReturn(errcode)]
443
444
def errfnSentinel(rvalue=ExprLiteral.FALSE):
    """Return an error callback that, given a message, produces statements
    calling SentinelReadError(message) and then returning |rvalue|."""

    def onSentinelError(msg):
        return [_sentinelReadError(msg), StmtReturn(rvalue)]

    return onSentinelError
450
451
def _destroyMethod():
    """Expression naming the `ActorDestroy` method."""
    name = "ActorDestroy"
    return ExprVar(name)
454
455
def errfnUnreachable(msg):
    """Error statements for a path that should be unreachable: a single
    mozilla::ipc::LogicError(msg) call."""
    stmt = _logicError(msg)
    return [stmt]
458
459
class _DestroyReason:
    """C++ expressions naming values of the C++ |ActorDestroyReason| type."""

    @staticmethod
    def Type():
        """The C++ |ActorDestroyReason| type."""
        return Type("ActorDestroyReason")

    Deletion = ExprVar("Deletion")
    AncestorDeletion = ExprVar("AncestorDeletion")
    NormalShutdown = ExprVar("NormalShutdown")
    AbnormalShutdown = ExprVar("AbnormalShutdown")
    FailedConstructor = ExprVar("FailedConstructor")
470
471
class _ResponseRejectReason:
    """C++ expressions naming values of |ResponseRejectReason|, the reject
    type of the MozPromises built by _makePromise."""

    @staticmethod
    def Type():
        """The C++ |ResponseRejectReason| type."""
        return Type("ResponseRejectReason")

    SendError = ExprVar("ResponseRejectReason::SendError")
    ChannelClosed = ExprVar("ResponseRejectReason::ChannelClosed")
    HandlerRejected = ExprVar("ResponseRejectReason::HandlerRejected")
    ActorDestroyed = ExprVar("ResponseRejectReason::ActorDestroyed")
481
482
483# -----------------------------------------------------------------------------
484# Intermediate representation (IR) nodes used during lowering
485
486
class _ConvertToCxxType(TypeVisitor):
    """TypeVisitor mapping an IPDL type to the C++ |Type| used to declare it.

    |side| selects the actor class emitted for actor types (via
    _actorName); |fq| selects fully-qualified names over bare names.
    """

    def __init__(self, side, fq):
        self.side = side
        self.fq = fq

    def typename(self, thing):
        """|thing|'s full name when |self.fq| is set, else its short name."""
        return thing.fullname() if self.fq else thing.name()

    def visitImportedCxxType(self, t):
        # Refcounted imported C++ types are held through RefPtr<T>.
        plain = Type(self.typename(t))
        return _refptr(plain) if t.isRefcounted() else plain

    def visitActorType(self, a):
        actorclass = _actorName(self.typename(a.protocol), self.side)
        return Type(actorclass, ptr=True)

    def visitStructType(self, s):
        return Type(self.typename(s))

    def visitUnionType(self, u):
        return Type(self.typename(u))

    def visitArrayType(self, a):
        return _cxxArrayType(a.basetype.accept(self))

    def visitMaybeType(self, m):
        return _cxxMaybeType(m.basetype.accept(self))

    def visitShmemType(self, s):
        return Type(self.typename(s))

    def visitByteBufType(self, s):
        return Type(self.typename(s))

    def visitFDType(self, s):
        return Type(self.typename(s))

    def visitEndpointType(self, s):
        return Type(self.typename(s))

    def visitManagedEndpointType(self, s):
        return Type(self.typename(s))

    def visitUniquePtrType(self, s):
        return Type(self.typename(s))

    # The remaining kinds have no C++ declaration type in this visitor and
    # are not expected to reach it.
    def visitProtocolType(self, p):
        assert 0

    def visitMessageType(self, m):
        assert 0

    def visitVoidType(self, v):
        assert 0
546
547
def _cxxBareType(ipdltype, side, fq=False):
    """Plain C++ type of |ipdltype| on |side| (fully qualified when |fq|)."""
    converter = _ConvertToCxxType(side, fq)
    return ipdltype.accept(converter)
550
551
def _cxxRefType(ipdltype, side):
    """C++ type of |ipdltype| on |side| marked as a (non-const) reference."""
    reftype = _cxxBareType(ipdltype, side)
    reftype.ref = True
    return reftype
556
557
def _cxxConstRefType(ipdltype, side):
    """C++ type used to pass |ipdltype| read-only on |side|.

    Mostly `const T&`, with these exceptions: actors stay bare pointers;
    Shmem and ByteBuf are non-const references; containers (types with a
    base type) copy their base type's constness; refcounted C++ types
    become a raw `T*`; UniquePtr is a non-const reference.
    """
    t = _cxxBareType(ipdltype, side)

    if ipdltype.isIPDL():
        if ipdltype.isActor():
            return t
        if ipdltype.isShmem() or ipdltype.isByteBuf():
            t.ref = True
            return t
        if ipdltype.hasBaseType():
            # Keep the same constness as the inner type.
            inner = _cxxConstRefType(ipdltype.basetype, side)
            t.const = inner.const or not inner.ref
            t.ref = True
            return t

    if ipdltype.isCxx():
        if ipdltype.isMoveonly():
            t.const = True
            t.ref = True
            return t
        if ipdltype.isRefcounted():
            # Use T* instead of const RefPtr<T>&.
            t = t.T
            t.ptr = True
            return t

    if ipdltype.isUniquePtr():
        t.ref = True
        return t

    t.const = True
    t.ref = True
    return t
589
590
591def _cxxTypeCanMoveSend(ipdltype):
592    return ipdltype.isUniquePtr()
593
594
def _cxxTypeNeedsMove(ipdltype):
    """True when |ipdltype| must be moved: everything that needs a move to
    be sent, plus IPDL arrays."""
    if _cxxTypeNeedsMoveForSend(ipdltype):
        return True
    return ipdltype.isIPDL() and ipdltype.isArray()
603
604
605def _cxxTypeNeedsMoveForSend(ipdltype):
606    if ipdltype.isUniquePtr():
607        return True
608
609    if ipdltype.isCxx():
610        return ipdltype.isMoveonly()
611
612    if ipdltype.isIPDL():
613        if ipdltype.hasBaseType():
614            return _cxxTypeNeedsMove(ipdltype.basetype)
615        return (
616            ipdltype.isShmem()
617            or ipdltype.isByteBuf()
618            or ipdltype.isEndpoint()
619            or ipdltype.isManagedEndpoint()
620        )
621
622    return False
623
624
625# FIXME Bug 1547019 This should be the same as _cxxTypeNeedsMoveForSend, but
626#                   a lot of existing code needs to be updated and fixed before
627#                   we can do that.
628def _cxxTypeCanOnlyMove(ipdltype, visited=None):
629    if visited is None:
630        visited = set()
631
632    visited.add(ipdltype)
633
634    if ipdltype.isCxx():
635        return ipdltype.isMoveonly()
636
637    if ipdltype.isIPDL():
638        if ipdltype.isMaybe() or ipdltype.isArray():
639            return _cxxTypeCanOnlyMove(ipdltype.basetype, visited)
640        if ipdltype.isStruct() or ipdltype.isUnion():
641            return any(
642                _cxxTypeCanOnlyMove(t, visited)
643                for t in ipdltype.itercomponents()
644                if t not in visited
645            )
646        return ipdltype.isManagedEndpoint()
647
648    return False
649
650
651def _cxxTypeCanMove(ipdltype):
652    return not (ipdltype.isIPDL() and ipdltype.isActor())
653
654
def _cxxMoveRefType(ipdltype, side):
    """Rvalue-reference type for types that need a move, otherwise the
    const-reference type from _cxxConstRefType."""
    if not _cxxTypeNeedsMove(ipdltype):
        return _cxxConstRefType(ipdltype, side)
    moved = _cxxBareType(ipdltype, side)
    moved.rvalref = True
    return moved
661
662
def _cxxForceMoveRefType(ipdltype, side):
    """C++ type of |ipdltype| marked as an rvalue reference.  |ipdltype|
    must be movable (i.e. not an actor)."""
    assert _cxxTypeCanMove(ipdltype)
    moved = _cxxBareType(ipdltype, side)
    moved.rvalref = True
    return moved
668
669
def _cxxPtrToType(ipdltype, side):
    """Pointer-to-|ipdltype| C++ type.  Actor types are already pointers,
    so they become pointer-to-pointer instead."""
    t = _cxxBareType(ipdltype, side)
    if ipdltype.isIPDL() and ipdltype.isActor():
        t.ptr, t.ptrptr = False, True
    else:
        t.ptr = True
    return t
678
679
def _cxxConstPtrToType(ipdltype, side):
    """Pointer-to-const |ipdltype| C++ type.  Actor types use the
    pointer-const-pointer form instead."""
    t = _cxxBareType(ipdltype, side)
    if ipdltype.isIPDL() and ipdltype.isActor():
        t.ptr, t.ptrconstptr = False, True
    else:
        t.const, t.ptr = True, True
    return t
689
690
691def _allocMethod(ptype, side):
692    return "Alloc" + ptype.name() + side.title()
693
694
695def _deallocMethod(ptype, side):
696    return "Dealloc" + ptype.name() + side.title()
697
698
699##
# A _HybridDecl straddles IPDL and C++ decls.  It knows which C++
# types correspond to which IPDL types, and it also knows how to
# serialize and deserialize "special" IPDL C++ types.
703##
704
705
706class _HybridDecl:
707    """A hybrid decl stores both an IPDL type and all the C++ type
708    info needed by later passes, along with a basic name for the decl."""
709
710    def __init__(self, ipdltype, name, attributes={}):
711        self.ipdltype = ipdltype
712        self.name = name
713        self.attributes = attributes
714
715    def var(self):
716        return ExprVar(self.name)
717
718    def bareType(self, side, fq=False):
719        """Return this decl's unqualified C++ type."""
720        return _cxxBareType(self.ipdltype, side, fq=fq)
721
722    def refType(self, side):
723        """Return this decl's C++ type as a 'reference' type, which is not
724        necessarily a C++ reference."""
725        return _cxxRefType(self.ipdltype, side)
726
727    def constRefType(self, side):
728        """Return this decl's C++ type as a const, 'reference' type."""
729        return _cxxConstRefType(self.ipdltype, side)
730
731    def rvalueRefType(self, side):
732        """Return this decl's C++ type as an r-value 'reference' type."""
733        return _cxxMoveRefType(self.ipdltype, side)
734
735    def ptrToType(self, side):
736        return _cxxPtrToType(self.ipdltype, side)
737
738    def constPtrToType(self, side):
739        return _cxxConstPtrToType(self.ipdltype, side)
740
741    def inType(self, side):
742        """Return this decl's C++ Type with inparam semantics."""
743        if self.ipdltype.isIPDL() and self.ipdltype.isActor():
744            return self.bareType(side)
745        elif _cxxTypeNeedsMoveForSend(self.ipdltype):
746            return self.rvalueRefType(side)
747        return self.constRefType(side)
748
749    def moveType(self, side):
750        """Return this decl's C++ Type with move semantics."""
751        if self.ipdltype.isIPDL() and self.ipdltype.isActor():
752            return self.bareType(side)
753        return self.rvalueRefType(side)
754
755    def outType(self, side):
756        """Return this decl's C++ Type with outparam semantics."""
757        t = self.bareType(side)
758        if self.ipdltype.isIPDL() and self.ipdltype.isActor():
759            t.ptr = False
760            t.ptrptr = True
761            return t
762        t.ptr = True
763        return t
764
765    def forceMoveType(self, side):
766        """Return this decl's C++ Type with forced move semantics."""
767        assert _cxxTypeCanMove(self.ipdltype)
768        return _cxxForceMoveRefType(self.ipdltype, side)
769
770
771# --------------------------------------------------
772
773
class HasFQName:
    """Mixin for decls whose |decl.type| can report a fully-qualified name."""

    def fqClassName(self):
        decltype = self.decl.type
        return decltype.fullname()
777
778
class _CompoundTypeComponent(_HybridDecl):
    """A _HybridDecl that belongs to a compound IPDL type (struct field or
    union member) and is bound to one actor side at construction.

    The type accessors keep the |side| parameter for signature
    compatibility with _HybridDecl but ignore it and use |self.side|, so
    callers need not remember which side the component represents.
    """

    def __init__(self, ipdltype, name, side, ct):
        # |ct| (the owning compound decl) is accepted but not used here.
        super().__init__(ipdltype, name)
        self.side = side
        # Whether the C++ type exposes an actor pointer (see _hasVisibleActor).
        self.special = _hasVisibleActor(ipdltype)

    def bareType(self, side=None, fq=False):
        return super().bareType(self.side, fq=fq)

    def refType(self, side=None):
        return super().refType(self.side)

    def constRefType(self, side=None):
        return super().constRefType(self.side)

    def ptrToType(self, side=None):
        return super().ptrToType(self.side)

    def constPtrToType(self, side=None):
        return super().constPtrToType(self.side)

    def inType(self, side=None):
        return super().inType(self.side)

    def forceMoveType(self, side=None):
        return super().forceMoveType(self.side)
808
809
class StructDecl(ipdl.ast.StructDecl, HasFQName):
    """ipdl.ast.StructDecl extended with field-ordering helpers."""

    def fields_ipdl_order(self):
        """Yield the fields in IPDL declaration order."""
        yield from self.fields

    def fields_member_order(self):
        """Yield the fields in C++ member order, as given by the indices in
        |packed_field_order|."""
        assert len(self.packed_field_order) == len(self.fields)
        yield from (self.fields[index] for index in self.packed_field_order)

    @staticmethod
    def upgrade(structDecl):
        """Re-class an ipdl.ast.StructDecl in place as this StructDecl."""
        assert isinstance(structDecl, ipdl.ast.StructDecl)
        structDecl.__class__ = StructDecl
825
826
class _StructField(_CompoundTypeComponent):
    """A field of an IPDL struct, bound to one actor side.

    Fields whose C++ type exposes an actor get the side appended to their
    name; |basename| keeps the name as declared.
    """

    def __init__(self, ipdltype, name, sd, side=None):
        self.basename = name
        suffix = side.title() if _hasVisibleActor(ipdltype) else ""
        _CompoundTypeComponent.__init__(self, ipdltype, name + suffix, side, sd)

    def getMethod(self, thisexpr=None, sel="."):
        """Accessor-method expression, selected from |thisexpr| via |sel| if given."""
        meth = self.var()
        if thisexpr is None:
            return meth
        return ExprSelect(thisexpr, sel, meth.name)

    def refExpr(self, thisexpr=None):
        """Expression for the backing member (`name_`), on |thisexpr| if given."""
        member = self.memberVar()
        if thisexpr is None:
            return member
        return ExprSelect(thisexpr, ".", member.name)

    def constRefExpr(self, thisexpr=None):
        """Like refExpr, but Shmem, ByteBuf and FileDescriptor members are
        explicitly cast to a const reference of their type."""
        refexpr = self.refExpr(thisexpr)
        typename = self.ipdltype.name()
        if typename in ("Shmem", "ByteBuf", "FileDescriptor"):
            refexpr = ExprCast(refexpr, Type(typename, ref=True), const=True)
        return refexpr

    def argVar(self):
        """Constructor-argument variable: the name with a leading underscore."""
        return ExprVar("_" + self.name)

    def memberVar(self):
        """Member variable: the name with a trailing underscore."""
        return ExprVar(self.name + "_")
865
866
class UnionDecl(ipdl.ast.UnionDecl, HasFQName):
    """ipdl.ast.UnionDecl extended with C++-lowering helpers."""

    def callType(self, var=None):
        """Call expression for the union's `type()` accessor, selected on
        |var| when given, otherwise unqualified."""
        accessor = ExprVar("type")
        if var is None:
            return ExprCall(accessor)
        return ExprCall(ExprSelect(var, ".", accessor.name))

    @staticmethod
    def upgrade(unionDecl):
        """Re-class an ipdl.ast.UnionDecl in place as this UnionDecl."""
        assert isinstance(unionDecl, ipdl.ast.UnionDecl)
        unionDecl.__class__ = UnionDecl
878
879
class _UnionMember(_CompoundTypeComponent):
    """Not in the AFL sense, but rather a member (e.g. |int;|) of an
    IPDL union type.

    Each member has a C++ enum tag (enum()), a slot in the union's C storage
    union |mValue| (unionValue()), and get_/ptr_/constptr_ accessor names.
    Members whose type is mutually recursive with the union are stored
    through a pointer allocated with `new`.  All other members live in
    AlignedStorage2 and are placement-constructed.
    """

    def __init__(self, ipdltype, ud, side=None, other=None):
        flatname = _flatTypeName(ipdltype)
        special = _hasVisibleActor(ipdltype)
        if special:
            # Actor-bearing members differ per side, so the side is part of
            # their flattened name.
            flatname += side.title()

        _CompoundTypeComponent.__init__(self, ipdltype, "V" + flatname, side, ud)
        self.flattypename = flatname
        if special:
            # Link to the twin member describing the other side.  When this
            # constructor builds the twin it passes itself as |other|, so the
            # twin reuses it instead of building yet another member.
            if other is not None:
                self.other = other
            else:
                self.other = _UnionMember(ipdltype, ud, _otherSide(side), self)

        # To create a finite object with a mutually recursive type, a union must
        # be present somewhere in the recursive loop. Because of that we only
        # need to care about introducing indirections inside unions.
        self.recursive = ud.decl.type.mutuallyRecursiveWith(ipdltype)

    def enum(self):
        """Name of this member's enum tag: "T" + flattened type name."""
        return "T" + self.flattypename

    def enumvar(self):
        """Expression naming this member's enum tag."""
        return ExprVar(self.enum())

    def internalType(self):
        """Stored type: a pointer for recursive members, else the bare type."""
        if self.recursive:
            return self.ptrToType()
        else:
            return self.bareType()

    def unionType(self):
        """Type used for storage in generated C union decl."""
        if self.recursive:
            return self.ptrToType()
        else:
            return Type("mozilla::AlignedStorage2", T=self.internalType())

    def unionValue(self):
        # NB: knows that Union's storage C union is named |mValue|
        return ExprSelect(ExprVar("mValue"), ".", self.name)

    def typedef(self):
        """Name of the typedef for this member's type (used by callDtor)."""
        return self.flattypename + "__tdef"

    def callGetConstPtr(self):
        """Return a call to this member's const-pointer accessor
        (named by getConstPtrName())."""
        return ExprCall(ExprVar(self.getConstPtrName()))

    def callGetPtr(self):
        """Return a call to this member's pointer accessor (named by
        getPtrName())."""
        return ExprCall(ExprVar(self.getPtrName()))

    def callOperatorEq(self, rhs):
        """Assignment `*ptr_X() = rhs`.

        Actor values are cast to the bare actor type first; arrays that are
        not being moved are copied with `.Clone()`.
        """
        if self.ipdltype.isIPDL() and self.ipdltype.isActor():
            rhs = ExprCast(rhs, self.bareType(), const=True)
        elif (
            self.ipdltype.isIPDL()
            and self.ipdltype.isArray()
            and not isinstance(rhs, ExprMove)
        ):
            rhs = ExprCall(ExprSelect(rhs, ".", "Clone"), args=[])
        return ExprAssn(ExprDeref(self.callGetPtr()), rhs)

    def callCtor(self, expr=None):
        """Expression constructing this member's value from |expr|.

        With no |expr| the value is default-constructed.  Actors are cast to
        the bare type; non-moved arrays are copied with `.Clone()`.
        Recursive members assign a `new`-allocated value to the pointer
        slot; other members use placement new (with mozilla::KnownNotNull)
        at ptr_X().
        """
        assert not isinstance(expr, list)

        if expr is None:
            args = None
        elif self.ipdltype.isIPDL() and self.ipdltype.isActor():
            args = [ExprCast(expr, self.bareType(), const=True)]
        elif (
            self.ipdltype.isIPDL()
            and self.ipdltype.isArray()
            and not isinstance(expr, ExprMove)
        ):
            args = [ExprCall(ExprSelect(expr, ".", "Clone"), args=[])]
        else:
            args = [expr]

        if self.recursive:
            return ExprAssn(
                self.callGetPtr(), ExprNew(self.bareType(self.side), args=args)
            )
        else:
            return ExprNew(
                self.bareType(self.side),
                args=args,
                newargs=[ExprVar("mozilla::KnownNotNull"), self.callGetPtr()],
            )

    def callDtor(self):
        """Expression destroying this member's value: `delete` for recursive
        members, an explicit `~typedef()` destructor call otherwise."""
        if self.recursive:
            return ExprDelete(self.callGetPtr())
        else:
            return ExprCall(ExprSelect(self.callGetPtr(), "->", "~" + self.typedef()))

    def getTypeName(self):
        """Name of the `get_X` accessor."""
        return "get_" + self.flattypename

    def getConstTypeName(self):
        """Name of the const `get_X` accessor (same name as getTypeName)."""
        return "get_" + self.flattypename

    def getOtherTypeName(self):
        # NOTE(review): |otherflattypename| is never assigned in this class;
        # confirm it is set elsewhere, otherwise calling this raises
        # AttributeError (perhaps self.other.flattypename was intended).
        return "get_" + self.otherflattypename

    def getPtrName(self):
        """Name of the `ptr_X` accessor."""
        return "ptr_" + self.flattypename

    def getConstPtrName(self):
        """Name of the `constptr_X` accessor."""
        return "constptr_" + self.flattypename

    def ptrToSelfExpr(self):
        """|*ptrToSelfExpr()| has type |self.bareType()|.

        Recursive members already store a pointer; otherwise the storage's
        `addr()` is called.
        """
        v = self.unionValue()
        if self.recursive:
            return v
        else:
            return ExprCall(ExprSelect(v, ".", "addr"))

    def constptrToSelfExpr(self):
        """Const counterpart of ptrToSelfExpr; builds the same expression
        (the stored pointer, or the storage's `addr()`)."""
        v = self.unionValue()
        if self.recursive:
            return v
        return ExprCall(ExprSelect(v, ".", "addr"))

    def ptrToInternalType(self):
        """Pointer to this member's type; a reference to that pointer for
        recursive members."""
        t = self.ptrToType()
        if self.recursive:
            t.ref = True
        return t

    def defaultValue(self, fq=False):
        """Default-value expression for this member, or None.

        Returns None for types without an implicit copy constructor; the
        comment below suggests callers use the default constructor then.
        Actors default to a null literal; other types to `T()`.
        """
        # Use the default constructor for any class that does not have an
        # implicit copy constructor.
        if not self.bareType().hasimplicitcopyctor:
            return None

        if self.ipdltype.isIPDL() and self.ipdltype.isActor():
            return ExprLiteral.NULL
        # XXX sneaky here, maybe need ExprCtor()?
        return ExprCall(self.bareType(fq=fq))

    def getConstValue(self):
        """Dereferenced const-pointer accessor call.  ByteBuf, Shmem and
        FileDescriptor values are additionally cast to a const reference of
        their type."""
        v = ExprDeref(self.callGetConstPtr())
        # sigh
        if "ByteBuf" == self.ipdltype.name():
            v = ExprCast(v, Type("ByteBuf", ref=True), const=True)
        if "Shmem" == self.ipdltype.name():
            v = ExprCast(v, Type("Shmem", ref=True), const=True)
        if "FileDescriptor" == self.ipdltype.name():
            v = ExprCast(v, Type("FileDescriptor", ref=True), const=True)
        return v
1038
1039
1040# --------------------------------------------------
1041
1042
class MessageDecl(ipdl.ast.MessageDecl):
    """C++-lowering view of an IPDL message declaration.

    |MessageDecl.upgrade()| switches the class of an existing
    |ipdl.ast.MessageDecl| to this one in place. Before upgrading, the
    decorating pass sets |namespace| (the protocol name), |params| and
    |returns| on the declaration. |params| may begin with a synthetic
    "actor" parameter (see |upgrade|).
    """

    def baseName(self):
        """Return the IPDL name of the message."""
        return self.name

    def recvMethod(self):
        """Return the C++ receive-method name ("Constructor"-suffixed for ctors)."""
        name = _recvPrefix(self.decl.type) + self.baseName()
        if self.decl.type.isCtor():
            name += "Constructor"
        return name

    def sendMethod(self):
        """Return the C++ send-method name ("Constructor"-suffixed for ctors)."""
        name = _sendPrefix(self.decl.type) + self.baseName()
        if self.decl.type.isCtor():
            name += "Constructor"
        return name

    def hasReply(self):
        """Return True if a reply message is generated for this message.

        Constructors and destructors always get one, in addition to messages
        whose type reports |hasReply()|.
        """
        return (
            self.decl.type.hasReply()
            or self.decl.type.isCtor()
            or self.decl.type.isDtor()
        )

    def hasAsyncReturns(self):
        """Return a truthy value if this is an async message with returns.

        Note: for async messages this yields |self.returns| itself (a list),
        not a bool.
        """
        return self.decl.type.isAsync() and self.returns

    def msgCtorFunc(self):
        """Return the message factory function name, |Msg_<progname>|."""
        return "Msg_%s" % (self.decl.progname)

    def prettyMsgName(self, pfx=""):
        """Return |msgCtorFunc()| with |pfx| prepended."""
        return pfx + self.msgCtorFunc()

    def pqMsgCtorFunc(self):
        """Return |msgCtorFunc()| qualified with |self.namespace|."""
        return "%s::%s" % (self.namespace, self.msgCtorFunc())

    def msgId(self):
        """Return the message-id enumerator name, |Msg_<progname>__ID|."""
        return self.msgCtorFunc() + "__ID"

    def pqMsgId(self):
        """Return |msgId()| qualified with |self.namespace|."""
        return "%s::%s" % (self.namespace, self.msgId())

    def replyCtorFunc(self):
        """Return the reply factory function name, |Reply_<progname>|."""
        return "Reply_%s" % (self.decl.progname)

    def pqReplyCtorFunc(self):
        """Return |replyCtorFunc()| qualified with |self.namespace|."""
        return "%s::%s" % (self.namespace, self.replyCtorFunc())

    def replyId(self):
        """Return the reply-id enumerator name, |Reply_<progname>__ID|."""
        return self.replyCtorFunc() + "__ID"

    def pqReplyId(self):
        """Return |replyId()| qualified with |self.namespace|."""
        return "%s::%s" % (self.namespace, self.replyId())

    def prettyReplyName(self, pfx=""):
        """Return |replyCtorFunc()| with |pfx| prepended."""
        return pfx + self.replyCtorFunc()

    def promiseName(self):
        """Return the promise type name: <base>[Constructor]Promise."""
        name = self.baseName()
        if self.decl.type.isCtor():
            name += "Constructor"
        name += "Promise"
        return name

    def resolverName(self):
        """Return the resolver type name: <base>Resolver."""
        return self.baseName() + "Resolver"

    def actorDecl(self):
        """Return |self.params[0]|.

        Presumably only meaningful when the message has an implicit actor
        parameter (inserted first by |upgrade|) -- callers should confirm.
        """
        return self.params[0]

    def makeCxxParams(
        self, paramsems="in", returnsems="out", side=None, implicit=True, direction=None
    ):
        """Return a list of C++ |Decl|s per the spec'd configuration.

        |paramsems| selects how |self.params| are declared: 'in', 'move',
        'out', or None to omit them entirely.

        |returnsems| selects how |self.returns| are declared:
          * 'in', 'move' or 'out': one Decl per return value;
          * 'promise': nothing is declared for the returns;
          * 'callback': an |aResolve| ResolveCallback<T> followed by an
            |aReject| RejectCallback (T is the single return's type, or a
            tuple of all return types);
          * 'resolver': a single |aResolve| parameter of the resolver type;
          * None: nothing is declared for the returns.
        If the message has no returns, nothing is added for them whatever
        |returnsems| is.

        |side| is forwarded to the per-decl type helpers. When |direction|
        is "recv" and the message type is tainted, every decl not carrying
        the "NoTaint" attribute is declared as a by-value |Tainted<T>|;
        'out' semantics are not allowed in that case.

        With |implicit=False| the leading implicit actor parameter is
        dropped, if the message type has one.
        """

        def makeDecl(d, sems):
            if (
                self.decl.type.tainted
                and "NoTaint" not in d.attributes
                and direction == "recv"
            ):
                # Tainted types are passed by-value, allowing the receiver to move them if desired.
                assert sems != "out"
                return Decl(Type("Tainted", T=d.bareType(side)), d.name)

            if sems == "in":
                return Decl(d.inType(side), d.name)
            elif sems == "move":
                return Decl(d.moveType(side), d.name)
            elif sems == "out":
                return Decl(d.outType(side), d.name)
            else:
                assert False, "unexpected decl semantics: %r" % (sems,)

        def makeResolverDecl():
            return Decl(Type(self.resolverName(), rvalref=True), "aResolve")

        def makeCallbackResolveDecl(returns):
            # A single return resolves with its own type; several returns
            # are bundled together via _tuple.
            if len(returns) > 1:
                resolvetype = _tuple([d.bareType(side) for d in returns])
            else:
                resolvetype = returns[0].bareType(side)

            return Decl(
                Type("mozilla::ipc::ResolveCallback", T=resolvetype, rvalref=True),
                "aResolve",
            )

        def makeCallbackRejectDecl():
            return Decl(Type("mozilla::ipc::RejectCallback", rvalref=True), "aReject")

        cxxparams = []
        if paramsems is not None:
            cxxparams.extend([makeDecl(d, paramsems) for d in self.params])

        if returnsems == "promise" and self.returns:
            pass
        elif returnsems == "callback" and self.returns:
            cxxparams.extend(
                [
                    makeCallbackResolveDecl(self.returns),
                    makeCallbackRejectDecl(),
                ]
            )
        elif returnsems == "resolver" and self.returns:
            cxxparams.extend([makeResolverDecl()])
        elif returnsems is not None:
            cxxparams.extend([makeDecl(r, returnsems) for r in self.returns])

        if not implicit and self.decl.type.hasImplicitActorParam():
            cxxparams = cxxparams[1:]

        return cxxparams

    def makeCxxArgs(
        self, paramsems="in", retsems="out", retcallsems="out", implicit=True
    ):
        """Return the C++ argument expressions for calling a generated method.

        |paramsems| is 'in' (pass each param's variable as-is) or 'move'
        (wrap non-refcounted params in std::move).

        |retsems| is how the return variables are held at the call site
        ('in', 'out' or 'resolver') and |retcallsems| is how the callee
        takes them ('in' or 'out'):
          * 'in'  -> 'out': pass the variable's address;
          * 'out' -> 'in' : pass the dereferenced variable;
          * matching semantics: pass the variable unchanged.
        With |retsems| == 'resolver', a single moved |resolver| variable is
        passed instead of the individual returns.

        With |implicit=False| the leading implicit actor argument is dropped;
        the message type must have one.
        """
        assert not retcallsems or retsems  # retcallsems => returnsems
        cxxargs = []

        if paramsems == "move":
            # We don't std::move() RefPtr<T> types because current Recv*()
            # implementors take these parameters as T*, and
            # std::move(RefPtr<T>) doesn't coerce to T*.
            cxxargs.extend(
                [
                    p.var() if p.ipdltype.isRefcounted() else ExprMove(p.var())
                    for p in self.params
                ]
            )
        elif paramsems == "in":
            cxxargs.extend([p.var() for p in self.params])
        else:
            assert False, "unexpected paramsems: %r" % (paramsems,)

        for ret in self.returns:
            if retsems == "in":
                if retcallsems == "in":
                    cxxargs.append(ret.var())
                elif retcallsems == "out":
                    cxxargs.append(ExprAddrOf(ret.var()))
                else:
                    assert False, "unexpected retcallsems: %r" % (retcallsems,)
            elif retsems == "out":
                if retcallsems == "in":
                    cxxargs.append(ExprDeref(ret.var()))
                elif retcallsems == "out":
                    cxxargs.append(ret.var())
                else:
                    assert False, "unexpected retcallsems: %r" % (retcallsems,)
            # retsems == "resolver": the returns travel through the resolver
            # argument appended below.
        if retsems == "resolver":
            cxxargs.append(ExprMove(ExprVar("resolver")))

        if not implicit:
            assert self.decl.type.hasImplicitActorParam()
            cxxargs = cxxargs[1:]

        return cxxargs

    @staticmethod
    def upgrade(messageDecl):
        """Convert |messageDecl| into a MessageDecl in place.

        If its message type has an implicit actor parameter, a synthetic
        |_HybridDecl| named "actor" (typed as an ActorType of the constructed
        protocol type) is first inserted at the front of |params|.
        """
        assert isinstance(messageDecl, ipdl.ast.MessageDecl)
        if messageDecl.decl.type.hasImplicitActorParam():
            messageDecl.params.insert(
                0,
                _HybridDecl(
                    ipdl.type.ActorType(messageDecl.decl.type.constructedType()),
                    "actor",
                ),
            )
        messageDecl.__class__ = MessageDecl
1236
1237
1238# --------------------------------------------------
def _usesShmem(p):
    """Return True if any message of protocol |p| has a parameter or return
    value whose type satisfies |ipdl.type.hasshmem|, otherwise False.

    Messages are checked in declaration order, parameters before returns,
    stopping at the first hit.
    """
    return any(
        ipdl.type.hasshmem(decl.type)
        for md in p.messageDecls
        for decl in itertools.chain(md.inParams, md.outParams)
    )
1248
1249
def _subtreeUsesShmem(p):
    """Return True if protocol |p|, or any protocol it (transitively)
    manages, uses Shmem per |_usesShmem|.

    A protocol that manages itself is not re-entered.
    """
    if _usesShmem(p):
        return True

    selftype = p.decl.type
    return any(
        _subtreeUsesShmem(managed._ast)
        for managed in selftype.manages
        if managed is not selftype
    )
1260
1261
class Protocol(ipdl.ast.Protocol):
    """C++-lowering view of an IPDL protocol declaration.

    |Protocol.upgrade()| switches the class of an existing
    |ipdl.ast.Protocol| to this one in place. The methods build the C++
    type/expression nodes that generated actor code refers to.
    """

    def cxxTypedefs(self):
        """Return the typedef list attached to this protocol's decl."""
        return self.decl.cxxtypedefs

    def managerInterfaceType(self, ptr=False):
        """Return the |mozilla::ipc::IProtocol| type (optionally a pointer)."""
        return Type("mozilla::ipc::IProtocol", ptr=ptr)

    def openedProtocolInterfaceType(self, ptr=False):
        """Return the |mozilla::ipc::IToplevelProtocol| type (optionally a pointer)."""
        return Type("mozilla::ipc::IToplevelProtocol", ptr=ptr)

    def _ipdlmgrtype(self):
        """Return the single manager type of this protocol."""
        managers = self.decl.type.managers
        assert 1 == len(managers)
        return next(iter(managers), None)

    def managerActorType(self, side, ptr=False):
        """Return the C++ actor type of the manager for |side|."""
        managername = self._ipdlmgrtype().name()
        return Type(_actorName(managername, side), ptr=ptr)

    def unregisterMethod(self, actorThis=None):
        """Return |Unregister|, or |actorThis->Unregister| when given an actor."""
        if actorThis is None:
            return ExprVar("Unregister")
        return ExprSelect(actorThis, "->", "Unregister")

    def removeManageeMethod(self):
        return ExprVar("RemoveManagee")

    def deallocManageeMethod(self):
        return ExprVar("DeallocManagee")

    def otherPidMethod(self):
        return ExprVar("OtherPid")

    @staticmethod
    def _callOnActor(method, actorThis):
        """Call |method()|, or |actorThis->method()| when an actor is given."""
        if actorThis is not None:
            method = ExprSelect(actorThis, "->", method.name)
        return ExprCall(method)

    def callOtherPid(self, actorThis=None):
        """Return a call to |OtherPid()|, optionally on |actorThis|."""
        return self._callOnActor(self.otherPidMethod(), actorThis)

    def getChannelMethod(self):
        return ExprVar("GetIPCChannel")

    def callGetChannel(self, actorThis=None):
        """Return a call to |GetIPCChannel()|, optionally on |actorThis|."""
        return self._callOnActor(self.getChannelMethod(), actorThis)

    def _toplevelVar(self, name):
        """Return |ExprVar(name)|; only valid on toplevel protocols."""
        assert self.decl.type.isToplevel()
        return ExprVar(name)

    def processingErrorVar(self):
        return self._toplevelVar("ProcessingError")

    def shouldContinueFromTimeoutVar(self):
        return self._toplevelVar("ShouldContinueFromReplyTimeout")

    def enteredCxxStackVar(self):
        return self._toplevelVar("EnteredCxxStack")

    def exitedCxxStackVar(self):
        return self._toplevelVar("ExitedCxxStack")

    def enteredCallVar(self):
        return self._toplevelVar("EnteredCall")

    def exitedCallVar(self):
        return self._toplevelVar("ExitedCall")

    def routingId(self, actorThis=None):
        """Return the routing-id expression for this protocol.

        Toplevel protocols use |MSG_ROUTING_CONTROL|; managed ones call
        |Id()|, on |actorThis| when given.
        """
        if self.decl.type.isToplevel():
            return ExprVar("MSG_ROUTING_CONTROL")
        idmethod = (
            ExprVar("Id") if actorThis is None else ExprSelect(actorThis, "->", "Id")
        )
        return ExprCall(idmethod)

    def managerVar(self, thisexpr=None):
        """Return a |Manager()| call, on |thisexpr| when given.

        Toplevel protocols must supply |thisexpr|.
        """
        assert thisexpr is not None or not self.decl.type.isToplevel()
        if thisexpr is None:
            return ExprCall(ExprVar("Manager"), args=[])
        return ExprCall(ExprSelect(thisexpr, "->", "Manager"), args=[])

    def _managedActorName(self, actortype, side):
        """Return the |side| actor class name of |actortype|, which we must manage."""
        assert self.decl.type.isManagerOf(actortype)
        return _actorName(actortype.name(), side)

    def managedCxxType(self, actortype, side):
        return Type(self._managedActorName(actortype, side), ptr=True)

    def managedMethod(self, actortype, side):
        return ExprVar("Managed" + self._managedActorName(actortype, side))

    def managedVar(self, actortype, side):
        return ExprVar("mManaged" + self._managedActorName(actortype, side))

    def managedVarType(self, actortype, side, const=False, ref=False):
        elemtype = Type(self._managedActorName(actortype, side))
        return _cxxManagedContainerType(elemtype, const=const, ref=ref)

    def subtreeUsesShmem(self):
        """Return True if this protocol or any managed descendant uses Shmem."""
        return _subtreeUsesShmem(self)

    @staticmethod
    def upgrade(protocol):
        """Switch |protocol|'s class to Protocol in place."""
        assert isinstance(protocol, ipdl.ast.Protocol)
        protocol.__class__ = Protocol
1372
1373
class TranslationUnit(ipdl.ast.TranslationUnit):
    """C++-lowering marker class for IPDL translation units.

    It defines nothing beyond |upgrade|; an instance of this class signals
    that the TU has already been upgraded.
    """

    @staticmethod
    def upgrade(tu):
        """Switch |tu|'s class to TranslationUnit in place."""
        assert isinstance(tu, ipdl.ast.TranslationUnit)
        tu.__class__ = TranslationUnit
1379
1380
1381# -----------------------------------------------------------------------------
1382
# Byte sizes of the plain-old-data C++ types recognized for struct packing,
# keyed by type name: the fixed-width integers (signed before unsigned at
# each width) followed by the two floating-point types.
pod_types = {
    **{
        "%sint%d_t" % (sign, bits): bits // 8
        for bits in (8, 16, 32, 64)
        for sign in ("", "u")
    },
    "float": 4,
    "double": 8,
}
max_pod_size = max(pod_types.values())
# We claim that all types we don't recognize are automatically "bigger"
# than pod types for ease of sorting.
pod_size_sentinel = 2 * max_pod_size
1399
1400
def pod_size(ipdltype):
    """Return the byte size of |ipdltype| when it is a known imported POD type.

    Non-imported types, and imported types absent from |pod_types|, get
    |pod_size_sentinel|, which is larger than every real POD size.
    """
    if isinstance(ipdltype, ipdl.type.ImportedCxxType):
        return pod_types.get(ipdltype.name(), pod_size_sentinel)
    return pod_size_sentinel
1406
1407
class _DecorateWithCxxStuff(ipdl.ast.Visitor):
    """Phase 1 of lowering: decorate the IPDL AST with information
    relevant to C++ code generation.

    This pass results in an AST that is a poor man's "IR"; in reality, a
    "hybrid" AST mainly consisting of IPDL nodes with new C++ info along
    with some new IPDL/C++ nodes that are tuned for C++ codegen."""

    def __init__(self):
        # Translation units already decorated; each is processed only once.
        self.visitedTus = set()
        # Sorted snapshot of |typedefSet|. It is refilled in place after each
        # TU, so protocols that hold a reference to it see the final list.
        self.typedefs = []
        # the set of typedefs that allow generated classes to
        # reference known C++ types by their "short name" rather than
        # fully-qualified name. e.g. |Foo| rather than |a::b::Foo|.
        self.typedefSet = {
            Typedef(Type("mozilla::ipc::ActorHandle"), "ActorHandle"),
            Typedef(Type("base::ProcessId"), "ProcessId"),
            Typedef(Type("mozilla::ipc::ProtocolId"), "ProtocolId"),
            Typedef(Type("mozilla::ipc::Transport"), "Transport"),
            Typedef(Type("mozilla::ipc::Endpoint"), "Endpoint", ["FooSide"]),
            Typedef(
                Type("mozilla::ipc::ManagedEndpoint"), "ManagedEndpoint", ["FooSide"]
            ),
            Typedef(Type("mozilla::ipc::TransportDescriptor"), "TransportDescriptor"),
            Typedef(Type("mozilla::UniquePtr"), "UniquePtr", ["T"]),
            Typedef(
                Type("mozilla::ipc::ResponseRejectReason"), "ResponseRejectReason"
            ),
        }
        # Name of the protocol being visited; stamped onto its messages.
        self.protocolName = None

    def visitTranslationUnit(self, tu):
        if tu in self.visitedTus:
            return
        self.visitedTus.add(tu)
        ipdl.ast.Visitor.visitTranslationUnit(self, tu)
        if not isinstance(tu, TranslationUnit):
            TranslationUnit.upgrade(tu)
        # Refill in place rather than rebinding: visitProtocol stored this
        # very list object on the protocol's decl.
        self.typedefs[:] = sorted(self.typedefSet)

    def visitInclude(self, inc):
        # Only header-type includes are decorated by this pass.
        if inc.tu.filetype != "header":
            return
        inc.tu.accept(self)

    def visitProtocol(self, pro):
        self.protocolName = pro.name
        pro.decl.cxxtypedefs = self.typedefs
        Protocol.upgrade(pro)
        return ipdl.ast.Visitor.visitProtocol(self, pro)

    def visitUsingStmt(self, using):
        fullname = using.decl.fullname
        if fullname is not None:
            self.typedefSet.add(Typedef(Type(fullname), using.decl.shortname))

    def visitStructDecl(self, sd):
        if not isinstance(sd, StructDecl):
            sd.decl.special = False
            lowered = []
            for field in sd.fields:
                ftype = field.decl.type
                if not _hasVisibleActor(ftype):
                    lowered.append(_StructField(ftype, field.name, sd))
                    continue
                # A field with a visible actor needs both |ActorParent| and
                # |ActorChild| members.
                sd.decl.special = True
                lowered.extend(
                    _StructField(ftype, field.name, sd, side=side)
                    for side in ("parent", "child")
                )

            sd.fields = lowered
            # Storage order for a well-packed layout: field indices sorted by
            # descending pod_size, so non-POD fields (sentinel size) come
            # first. sorted() is stable, so ties keep declaration order.
            sd.packed_field_order = sorted(
                range(len(lowered)),
                key=lambda idx: pod_size(lowered[idx].ipdltype),
                reverse=True,
            )
            StructDecl.upgrade(sd)

        if sd.decl.fullname is not None:
            self.typedefSet.add(Typedef(Type(sd.fqClassName()), sd.name))

    def visitUnionDecl(self, ud):
        ud.decl.special = False
        members = []
        for ctype in ud.decl.type.components:
            if not _hasVisibleActor(ctype):
                members.append(_UnionMember(ctype, ud))
                continue
            # A component with a visible actor needs both |ActorParent| and
            # |ActorChild| union members.
            ud.decl.special = True
            members.extend(
                _UnionMember(ctype, ud, side=side) for side in ("parent", "child")
            )
        ud.components = members
        UnionDecl.upgrade(ud)

        if ud.decl.fullname is not None:
            self.typedefSet.add(Typedef(Type(ud.fqClassName()), ud.name))

    def visitDecl(self, decl):
        """Wrap an IPDL decl into a |_HybridDecl| carrying C++ info."""
        return _HybridDecl(decl.type, decl.progname, decl.attributes)

    def visitMessageDecl(self, md):
        md.namespace = self.protocolName
        md.params = [inparam.accept(self) for inparam in md.inParams]
        md.returns = [outparam.accept(self) for outparam in md.outParams]
        MessageDecl.upgrade(md)
1528
1529
1530# -----------------------------------------------------------------------------
1531
1532
def msgenums(protocol, pretty=False):
    """Build the |MessageType| enum for |protocol|.

    The enum opens with |<Protocol>Start| (valued at the protocol's message
    start shifted left by 16), then lists every message id followed by its
    reply id when the message has a reply, and closes with |<Protocol>End|.
    With |pretty| the Msg_/Reply_ names replace the |__ID| identifiers.
    """
    enum = TypeEnum("MessageType")
    enum.addId(
        protocol.name + "Start", _messageStartName(protocol.decl.type) + " << 16"
    )

    for md in protocol.messageDecls:
        if pretty:
            msgname, replyname = md.prettyMsgName, md.prettyReplyName
        else:
            msgname, replyname = md.msgId, md.replyId
        enum.addId(msgname())
        if md.hasReply():
            enum.addId(replyname())

    enum.addId(protocol.name + "End")
    return enum
1545
1546
class _GenerateProtocolCode(ipdl.ast.Visitor):
    """Creates code common to both the parent and child actors.

    Writes |hdrfile| (Protocol.h): includes, dependency-ordered struct/union
    declarations, and, for protocols, the message-type enum plus message
    factory declarations. Writes |cppfile| (Protocol.cpp): the collected
    includes, the protocol-namespace function definitions and the
    struct/union method and ParamTraits definitions.
    """

    def __init__(self):
        self.protocol = None  # protocol we're generating a class for
        self.hdrfile = None  # what will become Protocol.h
        self.cppfile = None  # what will become Protocol.cpp
        # "<Protocol>::<progname>" -> segment capacity; supplied by lower().
        self.segmentcapacitydict = None
        self.cppIncludeHeaders = []  # headers #included by the .cpp file
        self.structUnionDefns = []  # struct/union method + traits definitions
        self.funcDefns = []  # definitions placed in the protocol namespace

    def lower(self, tu, cxxHeaderFile, cxxFile, segmentcapacitydict):
        """Generate the common code for |tu| into |cxxHeaderFile| and |cxxFile|.

        |segmentcapacitydict| maps "<Protocol>::<progname>" message names to
        the segment capacity used when constructing that message.
        """
        self.protocol = tu.protocol
        self.hdrfile = cxxHeaderFile
        self.cppfile = cxxFile
        self.segmentcapacitydict = segmentcapacitydict
        tu.accept(self)

    def visitTranslationUnit(self, tu):
        """Emit the complete header and .cpp contents for |tu|, in file order."""
        hf = self.hdrfile

        hf.addthing(_DISCLAIMER)
        hf.addthings(_includeGuardStart(hf))
        hf.addthing(Whitespace.NL)

        for inc in builtinHeaderIncludes:
            self.visitBuiltinCxxInclude(inc)

        # Compute the set of includes we need for declared structure/union
        # classes for this protocol.
        # First map each |using| type spelling to its header; the same type
        # must always come from the same header.
        typesToIncludes = {}
        for using in tu.using:
            typestr = str(using.type.spec)
            if typestr not in typesToIncludes:
                typesToIncludes[typestr] = using.header
            else:
                assert typesToIncludes[typestr] == using.header

        # Then collect the headers of the |using| typedefs that
        # _ComputeTypeDeps reports for each struct field / union component.
        aggregateTypeIncludes = set()
        for su in tu.structsAndUnions:
            typedeps = _ComputeTypeDeps(su.decl.type, True)
            if isinstance(su, ipdl.ast.StructDecl):
                for f in su.fields:
                    f.ipdltype.accept(typedeps)
            elif isinstance(su, ipdl.ast.UnionDecl):
                for c in su.components:
                    c.ipdltype.accept(typedeps)

            for typename in [t.fromtype.name for t in typedeps.usingTypedefs]:
                if typename in typesToIncludes:
                    aggregateTypeIncludes.add(typesToIncludes[typename])

        if len(aggregateTypeIncludes) != 0:
            hf.addthing(Whitespace.NL)
            hf.addthings([Whitespace("// Headers for typedefs"), Whitespace.NL])

            # Sorted so the generated header is deterministic.
            for headername in sorted(aggregateTypeIncludes):
                hf.addthing(CppDirective("include", '"' + headername + '"'))

        # Manually run Visitor.visitTranslationUnit. For dependency resolution
        # we need to handle structs and unions separately.
        for cxxInc in tu.cxxIncludes:
            cxxInc.accept(self)
        for inc in tu.includes:
            inc.accept(self)
        self.generateStructsAndUnions(tu)
        for using in tu.builtinUsing:
            using.accept(self)
        for using in tu.using:
            using.accept(self)
        if tu.protocol:
            tu.protocol.accept(self)

        if tu.filetype == "header":
            self.cppIncludeHeaders.append(_ipdlhHeaderName(tu) + ".h")

        hf.addthing(Whitespace.NL)
        hf.addthings(_includeGuardEnd(hf))

        cf = self.cppfile
        cf.addthings(
            (
                [_DISCLAIMER, Whitespace.NL]
                + [
                    CppDirective("include", '"' + h + '"')
                    for h in self.cppIncludeHeaders
                ]
                + [Whitespace.NL]
                + [
                    CppDirective("include", '"%s"' % filename)
                    for filename in ipdl.builtin.CppIncludes
                ]
                + [Whitespace.NL]
            )
        )

        if self.protocol:
            # construct the namespace into which we'll stick all our defns
            ns = Namespace(self.protocol.name)
            cf.addthing(_putInNamespaces(ns, self.protocol.namespaces))
            ns.addstmts(([Whitespace.NL] + self.funcDefns + [Whitespace.NL]))

        cf.addthings(self.structUnionDefns)

    def visitBuiltinCxxInclude(self, inc):
        """#include a builtin header directly in the generated header."""
        self.hdrfile.addthing(CppDirective("include", '"' + inc.file + '"'))

    def visitCxxInclude(self, inc):
        """Queue a C++ include for the generated .cpp file."""
        self.cppIncludeHeaders.append(inc.file)

    def visitInclude(self, inc):
        """Include another IPDL unit.

        Header units (.ipdlh) are #included from the generated header;
        protocol units get their parent and child actor headers #included
        from the generated .cpp.
        """
        if inc.tu.filetype == "header":
            self.hdrfile.addthing(
                CppDirective("include", '"' + _ipdlhHeaderName(inc.tu) + '.h"')
            )
        else:
            self.cppIncludeHeaders += [
                _protocolHeaderName(inc.tu.protocol, "parent") + ".h",
                _protocolHeaderName(inc.tu.protocol, "child") + ".h",
            ]

    def generateStructsAndUnions(self, tu):
        """Generate the definitions for all structs and unions. This will
        re-order the declarations if needed in the C++ code such that
        dependencies have already been defined."""
        # su.decl.type -> (types it needs fully declared, header things)
        decls = OrderedDict()
        for su in tu.structsAndUnions:
            if isinstance(su, StructDecl):
                which = "struct"
                forwarddecls, fulldecltypes, cls = _generateCxxStruct(su)
                traitsdecl, traitsdefns = _ParamTraits.structPickling(su.decl.type)
            else:
                assert isinstance(su, UnionDecl)
                which = "union"
                forwarddecls, fulldecltypes, cls = _generateCxxUnion(su)
                traitsdecl, traitsdefns = _ParamTraits.unionPickling(su.decl.type)

            clsdecl, methoddefns = _splitClassDeclDefn(cls)

            # Store the declarations in the decls map so we can emit in
            # dependency order.
            decls[su.decl.type] = (
                fulldecltypes,
                [Whitespace.NL]
                + forwarddecls
                + [
                    Whitespace(
                        """
//-----------------------------------------------------------------------------
// Declaration of the IPDL type |%s %s|
//
"""
                        % (which, su.name)
                    ),
                    _putInNamespaces(clsdecl, su.namespaces),
                ]
                + [Whitespace.NL, traitsdecl],
            )

            # Definitions go to the .cpp, in declaration order.
            self.structUnionDefns.extend(
                [
                    Whitespace(
                        """
//-----------------------------------------------------------------------------
// Method definitions for the IPDL type |%s %s|
//
"""
                        % (which, su.name)
                    ),
                    _putInNamespaces(methoddefns, su.namespaces),
                    Whitespace.NL,
                    traitsdefns,
                ]
            )

        # Generate the declarations structs in dependency order.
        def gen_struct(deps, defn):
            # Emit every still-pending dependency first. Entries are removed
            # from |decls| before recursing, so each is emitted exactly once.
            for dep in deps:
                if dep in decls:
                    d, t = decls[dep]
                    del decls[dep]
                    gen_struct(d, t)
            self.hdrfile.addthings(defn)

        while len(decls) > 0:
            # Oldest (first-declared) pending entry first.
            _, (d, t) = decls.popitem(last=False)
            gen_struct(d, t)

    def visitProtocol(self, p):
        """Emit the protocol-common header section and queue its definitions.

        Declares |CreateEndpoints|, the |MessageType| enum, and one message
        factory per message (plus a reply factory when it has a reply)
        inside the protocol namespace; the definitions go to |funcDefns|.
        """
        self.cppIncludeHeaders.append(_protocolHeaderName(self.protocol, "") + ".h")
        self.cppIncludeHeaders.append(
            _protocolHeaderName(self.protocol, "Parent") + ".h"
        )
        self.cppIncludeHeaders.append(
            _protocolHeaderName(self.protocol, "Child") + ".h"
        )

        # Forward declare our own actors.
        self.hdrfile.addthings(
            [
                Whitespace.NL,
                _makeForwardDeclForActor(p.decl.type, "Parent"),
                _makeForwardDeclForActor(p.decl.type, "Child"),
            ]
        )

        self.hdrfile.addthing(
            Whitespace(
                """
//-----------------------------------------------------------------------------
// Code common to %sChild and %sParent
//
"""
                % (p.name, p.name)
            )
        )

        # construct the namespace into which we'll stick all our decls
        ns = Namespace(self.protocol.name)
        self.hdrfile.addthing(_putInNamespaces(ns, p.namespaces))
        ns.addstmt(Whitespace.NL)

        edecl, edefn = _splitFuncDeclDefn(self.genEndpointFunc())
        ns.addstmts([edecl, Whitespace.NL])
        self.funcDefns.append(edefn)

        # spit out message type enum and classes
        msgenum = msgenums(self.protocol)
        ns.addstmts([StmtDecl(Decl(msgenum, "")), Whitespace.NL])

        for md in p.messageDecls:
            decls = []

            # Look up the segment capacity used for serializing this
            # message. If the capacity is not specified, use '0' for
            # the default capacity (defined in ipc_message.cc)
            name = "%s::%s" % (md.namespace, md.decl.progname)
            segmentcapacity = self.segmentcapacitydict.get(name, 0)

            mfDecl, mfDefn = _splitFuncDeclDefn(
                _generateMessageConstructor(md, segmentcapacity, p, forReply=False)
            )
            decls.append(mfDecl)
            self.funcDefns.append(mfDefn)

            if md.hasReply():
                # Replies always use the default segment capacity.
                rfDecl, rfDefn = _splitFuncDeclDefn(
                    _generateMessageConstructor(md, 0, p, forReply=True)
                )
                decls.append(rfDecl)
                self.funcDefns.append(rfDefn)

            decls.append(Whitespace.NL)
            ns.addstmts(decls)

        ns.addstmts([Whitespace.NL, Whitespace.NL])

    # Generate code for PFoo::CreateEndpoints.
    def genEndpointFunc(self):
        """Build |CreateEndpoints(aParentDestPid, aChildDestPid, aParent, aChild)|.

        The generated function returns nsresult and forwards all four
        arguments to |mozilla::ipc::CreateEndpoints|.
        """
        p = self.protocol.decl.type
        tparent = _cxxBareType(ActorType(p), "Parent", fq=True)
        tchild = _cxxBareType(ActorType(p), "Child", fq=True)

        openfunc = MethodDefn(
            MethodDecl(
                "CreateEndpoints",
                params=[
                    Decl(Type("base::ProcessId"), "aParentDestPid"),
                    Decl(Type("base::ProcessId"), "aChildDestPid"),
                    Decl(
                        Type("mozilla::ipc::Endpoint<" + tparent.name + ">", ptr=True),
                        "aParent",
                    ),
                    Decl(
                        Type("mozilla::ipc::Endpoint<" + tchild.name + ">", ptr=True),
                        "aChild",
                    ),
                ],
                ret=Type.NSRESULT,
            )
        )
        openfunc.addcode(
            """
            return mozilla::ipc::CreateEndpoints(
                mozilla::ipc::PrivateIPDLInterface(),
                aParentDestPid, aChildDestPid,
                aParent, aChild);
            """
        )
        return openfunc
1837
1838
1839# --------------------------------------------------
1840
1841
def _generateMessageConstructor(md, segmentSize, protocol, forReply=False):
    """Build the C++ factory function returning a new |IPC::Message*| for |md|.

    With |forReply| the reply factory (Reply_*, REPLY flag) is produced;
    otherwise the request factory (Msg_*, NOT_REPLY). The header flags come
    from the message type's nesting, priority, compression, constructor,
    sync and interrupt properties. A non-zero |segmentSize| yields
    |new IPC::Message(...)| with that capacity and write-latency recording
    enabled; zero yields |IPC::Message::IPDLMessage(...)|.

    |protocol| is not used by this function.
    """
    msgtype = md.decl.type
    if forReply:
        clsname, msgid, replyEnum = md.replyCtorFunc(), md.replyId(), "REPLY"
    else:
        clsname, msgid, replyEnum = md.msgCtorFunc(), md.msgId(), "NOT_REPLY"

    routingId = ExprVar("routingId")
    func = FunctionDefn(
        FunctionDecl(
            clsname,
            params=[Decl(Type("int32_t"), routingId.name)],
            ret=Type("IPC::Message", ptr=True),
        )
    )

    compress = msgtype.compress
    if not compress:
        compression = "COMPRESSION_NONE"
    elif compress.value == "all":
        compression = "COMPRESSION_ALL"
    else:
        assert compress.value is None
        compression = "COMPRESSION_ENABLED"

    nested = msgtype.nested
    if nested == ipdl.ast.NOT_NESTED:
        nestedEnum = "NOT_NESTED"
    elif nested == ipdl.ast.INSIDE_SYNC_NESTED:
        nestedEnum = "NESTED_INSIDE_SYNC"
    else:
        assert nested == ipdl.ast.INSIDE_CPOW_NESTED
        nestedEnum = "NESTED_INSIDE_CPOW"

    # First matching priority wins; anything unlisted is a control message.
    prio = msgtype.prio
    prioNames = (
        (ipdl.ast.NORMAL_PRIORITY, "NORMAL_PRIORITY"),
        (ipdl.ast.INPUT_PRIORITY, "INPUT_PRIORITY"),
        (ipdl.ast.VSYNC_PRIORITY, "VSYNC_PRIORITY"),
        (ipdl.ast.MEDIUMHIGH_PRIORITY, "MEDIUMHIGH_PRIORITY"),
    )
    prioEnum = next(
        (label for value, label in prioNames if prio == value), "CONTROL_PRIORITY"
    )

    syncEnum = "SYNC" if msgtype.isSync() else "ASYNC"
    interruptEnum = "INTERRUPT" if msgtype.isInterrupt() else "NOT_INTERRUPT"
    ctorEnum = "CONSTRUCTOR" if msgtype.isCtor() else "NOT_CONSTRUCTOR"

    flagNames = (
        nestedEnum,
        prioEnum,
        compression,
        ctorEnum,
        syncEnum,
        interruptEnum,
        replyEnum,
    )
    flags = ExprCall(
        ExprVar("IPC::Message::HeaderFlags"),
        args=[ExprVar("IPC::Message::" + flag) for flag in flagNames],
    )

    segmentSize = int(segmentSize)
    if segmentSize:
        message = ExprNew(
            Type("IPC::Message"),
            args=[
                routingId,
                ExprVar(msgid),
                ExprLiteral.Int(segmentSize),
                flags,
                # Pass `true` to recordWriteLatency to collect telemetry
                ExprLiteral.TRUE,
            ],
        )
    else:
        message = ExprCall(
            ExprVar("IPC::Message::IPDLMessage"),
            args=[routingId, ExprVar(msgid), flags],
        )
    func.addstmt(StmtReturn(message))

    return func
1952
1953
1954# --------------------------------------------------
1955
1956
class _ParamTraits:
    """Code generators for IPDLParamTraits<T> specializations.

    The classmethods here emit the C++ statements that serialize IPDL
    actors, structs and unions. Each serialized value is followed by a
    sentinel. generateDecl() wraps the Write()/Read() bodies in an
    ``IPDLParamTraits<T>`` struct inside ``mozilla::ipc``.
    """

    # Parameter names of the generated static Write()/Read() methods.
    var = ExprVar("aVar")
    msgvar = ExprVar("aMsg")
    itervar = ExprVar("aIter")
    actor = ExprVar("aActor")

    @classmethod
    def ifsideis(cls, side, then, els=None):
        """Return an |if| comparing aActor->GetSide() against |side|.

        |side| == "parent" selects mozilla::ipc::ParentSide; any other value
        selects mozilla::ipc::ChildSide.  |then| runs when the sides match,
        and |els| (if given) runs otherwise.
        """
        cxxside = ExprVar("mozilla::ipc::ChildSide")
        if side == "parent":
            cxxside = ExprVar("mozilla::ipc::ParentSide")

        ifstmt = StmtIf(
            ExprBinary(cxxside, "==", ExprCall(ExprSelect(cls.actor, "->", "GetSide")))
        )
        ifstmt.addifstmt(then)
        if els is not None:
            ifstmt.addelsestmt(els)
        return ifstmt

    @classmethod
    def fatalError(cls, reason):
        """Return a statement calling aActor->FatalError(|reason|)."""
        return StmtCode(
            "aActor->FatalError(${reason});", reason=ExprLiteral.String(reason)
        )

    @classmethod
    def writeSentinel(cls, msgvar, sentinelKey):
        """Return statements that write hashfunc(|sentinelKey|) to |msgvar|.

        The write uses WriteSentinel() and is preceded by a comment naming
        the key.
        """
        return [
            Whitespace("// Sentinel = " + repr(sentinelKey) + "\n", indent=True),
            StmtExpr(
                ExprCall(
                    ExprSelect(msgvar, "->", "WriteSentinel"),
                    args=[ExprLiteral.Int(hashfunc(sentinelKey))],
                )
            ),
        ]

    @classmethod
    def readSentinel(cls, msgvar, itervar, sentinelKey, sentinelFail):
        """Return statements that read and check the sentinel for |sentinelKey|.

        The statement list |sentinelFail| runs when ReadSentinel() returns
        false.
        """
        # Read the sentinel
        read = ExprCall(
            ExprSelect(msgvar, "->", "ReadSentinel"),
            args=[itervar, ExprLiteral.Int(hashfunc(sentinelKey))],
        )
        ifsentinel = StmtIf(ExprNot(read))
        ifsentinel.addifstmts(sentinelFail)

        return [
            Whitespace("// Sentinel = " + repr(sentinelKey) + "\n", indent=True),
            ifsentinel,
        ]

    @classmethod
    def write(cls, var, msgvar, actor, ipdltype=None):
        """Return a ``WriteIPDLParam(msgvar, actor, var)`` call expression.

        |var| is wrapped in a move when |ipdltype| is given and
        _cxxTypeNeedsMoveForSend() says it must be moved to be sent.
        """
        # WARNING: This doesn't set AutoForActor for you, make sure this is
        # only called when the actor is already correctly set.
        if ipdltype and _cxxTypeNeedsMoveForSend(ipdltype):
            var = ExprMove(var)
        return ExprCall(ExprVar("WriteIPDLParam"), args=[msgvar, actor, var])

    @classmethod
    def checkedWrite(cls, ipdltype, var, msgvar, sentinelKey, actor):
        """Return a Block that writes |var| followed by a sentinel for |sentinelKey|.

        For non-nullable IPDL actor types, an abort-if-null check is emitted
        before the write.
        """
        assert sentinelKey
        block = Block()

        # Assert we aren't serializing a null non-nullable actor
        if (
            ipdltype
            and ipdltype.isIPDL()
            and ipdltype.isActor()
            and not ipdltype.nullable
        ):
            block.addstmt(
                _abortIfFalse(var, "NULL actor value passed to non-nullable param")
            )

        block.addstmts(
            [
                StmtExpr(cls.write(var, msgvar, actor, ipdltype)),
            ]
        )
        block.addstmts(cls.writeSentinel(msgvar, sentinelKey))
        return block

    @classmethod
    def bulkSentinelKey(cls, fields):
        """Sentinel key for a bulk-serialized run: field basenames joined by " | "."""
        return " | ".join(f.basename for f in fields)

    @classmethod
    def checkedBulkWrite(cls, size, fields):
        """Return a Block that writes |fields| of aVar with a single WriteBytes().

        Writes ``size * len(fields)`` bytes starting at the address of the
        first field's getter result, followed by the bulk sentinel.  This
        assumes the fields are laid out contiguously starting at the first
        one.
        """
        block = Block()
        first = fields[0]

        block.addstmts(
            [
                StmtExpr(
                    ExprCall(
                        ExprSelect(cls.msgvar, "->", "WriteBytes"),
                        args=[
                            ExprAddrOf(
                                ExprCall(first.getMethod(thisexpr=cls.var, sel="."))
                            ),
                            ExprLiteral.Int(size * len(fields)),
                        ],
                    )
                )
            ]
        )
        block.addstmts(cls.writeSentinel(cls.msgvar, cls.bulkSentinelKey(fields)))

        return block

    @classmethod
    def checkedBulkRead(cls, size, fields):
        """Return a Block that reads back what checkedBulkWrite() wrote.

        Calls ReadBytesInto() for ``size * len(fields)`` bytes at the first
        field's address.  On failure it reports a FatalError and returns
        false.  It then checks the bulk sentinel.
        """
        block = Block()
        first = fields[0]

        readbytes = ExprCall(
            ExprSelect(cls.msgvar, "->", "ReadBytesInto"),
            args=[
                cls.itervar,
                ExprAddrOf(ExprCall(first.getMethod(thisexpr=cls.var, sel="->"))),
                ExprLiteral.Int(size * len(fields)),
            ],
        )
        ifbad = StmtIf(ExprNot(readbytes))
        errmsg = "Error bulk reading fields from %s" % first.ipdltype.name()
        ifbad.addifstmts([cls.fatalError(errmsg), StmtReturn.FALSE])
        block.addstmt(ifbad)
        block.addstmts(
            cls.readSentinel(
                cls.msgvar,
                cls.itervar,
                cls.bulkSentinelKey(fields),
                errfnSentinel()(errmsg),
            )
        )

        return block

    @classmethod
    def checkedRead(
        cls,
        ipdltype,
        var,
        msgvar,
        itervar,
        errfn,
        paramtype,
        sentinelKey,
        errfnSentinel,
        actor,
    ):
        """Return a Block that reads |var| with ReadIPDLParam() and validates it.

        |paramtype| is either a description string, which becomes
        ``"Error deserializing " + paramtype``, or a list of arguments.
        Either way the arguments are forwarded to |errfn| and
        |errfnSentinel|, which must return statement lists.  |errfn| runs on
        a failed read, and also on a null value for non-nullable actor
        types.  |errfnSentinel| runs on a sentinel mismatch.
        """
        block = Block()

        # Read the data
        ifbad = StmtIf(
            ExprNot(
                ExprCall(ExprVar("ReadIPDLParam"), args=[msgvar, itervar, actor, var])
            )
        )
        if not isinstance(paramtype, list):
            paramtype = ["Error deserializing " + paramtype]
        ifbad.addifstmts(errfn(*paramtype))
        block.addstmt(ifbad)

        # Check if we got a null non-nullable actor
        if (
            ipdltype
            and ipdltype.isIPDL()
            and ipdltype.isActor()
            and not ipdltype.nullable
        ):
            ifnull = StmtIf(ExprNot(ExprDeref(var)))
            ifnull.addifstmts(errfn(*paramtype))
            block.addstmt(ifnull)

        block.addstmts(
            cls.readSentinel(msgvar, itervar, sentinelKey, errfnSentinel(*paramtype))
        )

        return block

    # Helper wrapper for checkedRead for use within _ParamTraits
    @classmethod
    def _checkedRead(cls, ipdltype, var, sentinelKey, what):
        """checkedRead() bound to this class's aMsg/aIter/aActor parameters.

        Read errors emit a FatalError and ``return false``.
        """

        def errfn(msg):
            return [cls.fatalError(msg), StmtReturn.FALSE]

        return cls.checkedRead(
            ipdltype,
            var,
            cls.msgvar,
            cls.itervar,
            errfn=errfn,
            paramtype=what,
            sentinelKey=sentinelKey,
            errfnSentinel=errfnSentinel(),
            actor=cls.actor,
        )

    @classmethod
    def generateDecl(cls, fortype, write, read, constin=True):
        """Build an ``IPDLParamTraits<fortype>`` specialization.

        The statement list |write| becomes the body of the static Write()
        method, and |read| the body of the static Read() method.  |constin|
        controls whether Write() takes its value by const reference.

        The class is split into a declaration and out-of-line method
        definitions, each wrapped in ``namespace mozilla::ipc``.  Returns
        ``(declaration, definitions)``.
        """
        # IPDLParamTraits impls are selected ignoring constness, and references.
        pt = Class(
            "IPDLParamTraits",
            specializes=Type(
                fortype.name, T=fortype.T, inner=fortype.inner, ptr=fortype.ptr
            ),
            struct=True,
        )

        # typedef T paramType;
        pt.addstmt(Typedef(fortype, "paramType"))

        iprotocoltype = Type("mozilla::ipc::IProtocol", ptr=True)

        # static void Write(IPC::Message*, mozilla::ipc::IProtocol*,
        #                   [const] paramType&);
        intype = Type("paramType", ref=True, const=constin)
        writemthd = MethodDefn(
            MethodDecl(
                "Write",
                params=[
                    Decl(Type("IPC::Message", ptr=True), cls.msgvar.name),
                    Decl(iprotocoltype, cls.actor.name),
                    Decl(intype, cls.var.name),
                ],
                methodspec=MethodSpec.STATIC,
            )
        )
        writemthd.addstmts(write)
        pt.addstmt(writemthd)

        # static bool Read(const IPC::Message*, PickleIterator*,
        #                  mozilla::ipc::IProtocol*, paramType*);
        outtype = Type("paramType", ptr=True)
        readmthd = MethodDefn(
            MethodDecl(
                "Read",
                params=[
                    Decl(Type("IPC::Message", ptr=True, const=True), cls.msgvar.name),
                    Decl(_iterType(ptr=True), cls.itervar.name),
                    Decl(iprotocoltype, cls.actor.name),
                    Decl(outtype, cls.var.name),
                ],
                ret=Type.BOOL,
                methodspec=MethodSpec.STATIC,
            )
        )
        readmthd.addstmts(read)
        pt.addstmt(readmthd)

        # Split the class into declaration and definition
        clsdecl, methoddefns = _splitClassDeclDefn(pt)

        namespaces = [Namespace("mozilla"), Namespace("ipc")]
        clsns = _putInNamespaces(clsdecl, namespaces)
        defns = _putInNamespaces(methoddefns, namespaces)
        return clsns, defns

    @classmethod
    def actorPickling(cls, actortype, side):
        """Generates pickling for IPDL actors. This is a |nullable| deserializer.
        Write and read callers will perform nullability validation.

        Write() sends the actor id (0 for null).  Read() resolves the id
        through aActor->ReadActor() and static_casts the result to the
        |side|-specific actor type.
        """

        cxxtype = _cxxBareType(actortype, side, fq=True)

        write = StmtCode(
            """
            int32_t id;
            if (!${var}) {
                id = 0;  // kNullActorId
            } else {
                id = ${var}->Id();
                if (id == 1) {  // kFreedActorId
                    ${var}->FatalError("Actor has been |delete|d");
                }
                MOZ_RELEASE_ASSERT(
                    ${actor}->GetIPCChannel() == ${var}->GetIPCChannel(),
                    "Actor must be from the same channel as the"
                    " actor it's being sent over");
                MOZ_RELEASE_ASSERT(
                    ${var}->CanSend(),
                    "Actor must still be open when sending");
            }

            ${write};
            """,
            var=cls.var,
            actor=cls.actor,
            write=cls.write(ExprVar("id"), cls.msgvar, cls.actor),
        )

        # bool Read(..) impl
        read = StmtCode(
            """
            mozilla::Maybe<mozilla::ipc::IProtocol*> actor =
                ${actor}->ReadActor(${msgvar}, ${itervar}, true, ${actortype}, ${protocolid});
            if (actor.isNothing()) {
                return false;
            }

            *${var} = static_cast<${cxxtype}>(actor.value());
            return true;
            """,
            actor=cls.actor,
            msgvar=cls.msgvar,
            itervar=cls.itervar,
            actortype=ExprLiteral.String(actortype.name()),
            protocolid=_protocolId(actortype),
            var=cls.var,
            cxxtype=cxxtype,
        )

        return cls.generateDecl(cxxtype, [write], [read])

    @classmethod
    def structPickling(cls, structtype):
        """Generate IPDLParamTraits for an IPDL struct.

        Fields are visited in member order and grouped into consecutive runs
        with the same pod_size().  Fields whose size is pod_size_sentinel are
        written and read one at a time, each with its own sentinel.  Special
        fields are wrapped in a side check.  Every other run is transferred
        with one bulk WriteBytes()/ReadBytesInto().
        """
        sd = structtype._ast
        # NOTE: Not using _cxxBareType here as we don't have a side
        cxxtype = Type(structtype.fullname())

        def get(sel, f):
            return ExprCall(f.getMethod(thisexpr=cls.var, sel=sel))

        write = []
        read = []

        for (size, fields) in itertools.groupby(
            sd.fields_member_order(), lambda f: pod_size(f.ipdltype)
        ):
            fields = list(fields)

            if size == pod_size_sentinel:
                for f in fields:
                    writefield = cls.checkedWrite(
                        f.ipdltype,
                        get(".", f),
                        cls.msgvar,
                        sentinelKey=f.basename,
                        actor=cls.actor,
                    )
                    readfield = cls._checkedRead(
                        f.ipdltype,
                        ExprAddrOf(get("->", f)),
                        f.basename,
                        "'"
                        + f.getMethod().name
                        + "' "
                        + "("
                        + f.ipdltype.name()
                        + ") member of "
                        + "'"
                        + structtype.name()
                        + "'",
                    )

                    # Wrap the read/write in a side check if the field is special.
                    if f.special:
                        writefield = cls.ifsideis(f.side, writefield)
                        readfield = cls.ifsideis(f.side, readfield)

                    write.append(writefield)
                    read.append(readfield)
            else:
                for f in fields:
                    assert not f.special

                writefield = cls.checkedBulkWrite(size, fields)
                readfield = cls.checkedBulkRead(size, fields)

                write.append(writefield)
                read.append(readfield)

        read.append(StmtReturn.TRUE)

        return cls.generateDecl(cxxtype, write, read)

    @classmethod
    def unionPickling(cls, uniontype):
        """Generate IPDLParamTraits for an IPDL union.

        Write() sends the discriminant (with a sentinel), then switches on
        it to write the active variant.  Read() reads the discriminant, then
        default-constructs the matching variant, assigns it into *aVar and
        reads into it.  Unknown discriminants report "unknown union type".
        """
        # NOTE: Not using _cxxBareType here as we don't have a side
        cxxtype = Type(uniontype.fullname())
        ud = uniontype._ast

        # Use typedef to set up an alias so it's easier to reference the union type.
        alias = "union__"
        typevar = ExprVar("type")

        prelude = [
            Typedef(cxxtype, alias),
        ]

        writeswitch = StmtSwitch(typevar)
        write = prelude + [
            StmtDecl(Decl(Type.INT, typevar.name), init=ud.callType(cls.var)),
            cls.checkedWrite(
                None, typevar, cls.msgvar, sentinelKey=uniontype.name(), actor=cls.actor
            ),
            Whitespace.NL,
            writeswitch,
        ]

        readswitch = StmtSwitch(typevar)
        read = prelude + [
            StmtDecl(Decl(Type.INT, typevar.name), init=ExprLiteral.ZERO),
            cls._checkedRead(
                None,
                ExprAddrOf(typevar),
                uniontype.name(),
                "type of union " + uniontype.name(),
            ),
            Whitespace.NL,
            readswitch,
        ]

        for c in ud.components:
            caselabel = CaseLabel(alias + "::" + c.enum())
            origenum = c.enum()

            writecase = StmtBlock()
            wstmt = cls.checkedWrite(
                c.ipdltype,
                ExprCall(ExprSelect(cls.var, ".", c.getTypeName())),
                cls.msgvar,
                sentinelKey=c.enum(),
                actor=cls.actor,
            )
            if c.special:
                # Report an error if the type is special and the side is wrong
                wstmt = cls.ifsideis(c.side, wstmt, els=cls.fatalError("wrong side!"))
            writecase.addstmts([wstmt, StmtReturn()])
            writeswitch.addcase(caselabel, writecase)

            # Component whose type is actually deserialized on this side.
            readc = c
            readcase = StmtBlock()
            if c.special:
                # The type comes across flipped from what the actor will be on
                # this side; i.e. child->parent messages will have PFooChild
                # when received on the parent side. Report an error if the sides
                # match, and handle c.other instead.
                readcase.addstmt(
                    cls.ifsideis(
                        c.side,
                        StmtBlock([cls.fatalError("wrong side!"), StmtReturn.FALSE]),
                    )
                )
                readc = c.other
            tmpvar = ExprVar("tmp")
            ct = readc.bareType(fq=True)
            readcase.addstmts(
                [
                    StmtDecl(Decl(ct, tmpvar.name), init=readc.defaultValue(fq=True)),
                    StmtExpr(ExprAssn(ExprDeref(cls.var), ExprMove(tmpvar))),
                    cls._checkedRead(
                        readc.ipdltype,
                        ExprAddrOf(
                            ExprCall(ExprSelect(cls.var, "->", readc.getTypeName()))
                        ),
                        origenum,
                        "variant " + origenum + " of union " + uniontype.name(),
                    ),
                    StmtReturn.TRUE,
                ]
            )
            readswitch.addcase(caselabel, readcase)

        # Add the error default case
        writeswitch.addcase(
            DefaultLabel(),
            StmtBlock([cls.fatalError("unknown union type"), StmtReturn()]),
        )
        readswitch.addcase(
            DefaultLabel(),
            StmtBlock([cls.fatalError("unknown union type"), StmtReturn.FALSE]),
        )

        return cls.generateDecl(cxxtype, write, read)
2434
2435
2436# --------------------------------------------------
2437
2438
class _ComputeTypeDeps(TypeVisitor):
    """Pass that gathers the C++ types that a particular IPDL type
    (recursively) depends on.  There are three kinds of dependencies: (i)
    types that need forward declaration; (ii) types that need a |using|
    stmt; (iii) IPDL structs or unions which must be fully declared
    before this struct.  Some types generate multiple kinds.

    Results accumulate in |usingTypedefs|, |forwardDeclStmts| and
    |fullDeclTypes|.  |fortype| (the type being generated) is never recorded
    as a dependency of itself.
    """

    def __init__(self, fortype, unqualifiedTypedefs=False):
        ipdl.type.TypeVisitor.__init__(self)
        # Typedef nodes aliasing short names to fully-qualified ones.
        self.usingTypedefs = []
        # Forward-declaration statements (interleaved with newlines).
        self.forwardDeclStmts = []
        # IPDL struct/union types that must be fully declared first.
        self.fullDeclTypes = []
        self.fortype = fortype
        # When true, emit typedefs even if the short and qualified names match.
        self.unqualifiedTypedefs = unqualifiedTypedefs

    def maybeTypedef(self, fqname, name, templateargs=None):
        """Record a Typedef aliasing |name| to |fqname|.

        The typedef is recorded only when the two names differ or
        unqualified typedefs were requested.  |templateargs| optionally
        lists template parameter names for the alias.
        """
        # Avoid a shared mutable default: each Typedef gets its own list.
        if templateargs is None:
            templateargs = []
        if fqname != name or self.unqualifiedTypedefs:
            self.usingTypedefs.append(Typedef(Type(fqname), name, templateargs))

    def visitImportedCxxType(self, t):
        if t in self.visited:
            return
        self.visited.add(t)
        self.maybeTypedef(t.fullname(), t.name())

    def visitActorType(self, t):
        if t in self.visited:
            return
        self.visited.add(t)

        fqname, name = t.fullname(), t.name()

        # Both sides of the actor may be referenced, so alias and forward
        # declare both the Parent and Child classes.
        self.maybeTypedef(_actorName(fqname, "Parent"), _actorName(name, "Parent"))
        self.maybeTypedef(_actorName(fqname, "Child"), _actorName(name, "Child"))

        self.forwardDeclStmts.extend(
            [
                _makeForwardDeclForActor(t.protocol, "parent"),
                Whitespace.NL,
                _makeForwardDeclForActor(t.protocol, "child"),
                Whitespace.NL,
            ]
        )

    def visitStructOrUnionType(self, su, defaultVisit):
        """Shared handler for struct and union types; |defaultVisit| is the
        base TypeVisitor method used to recurse into |su|'s members."""
        if su in self.visited or su == self.fortype:
            return
        self.visited.add(su)
        self.maybeTypedef(su.fullname(), su.name())

        # Mutually recursive fields in unions are behind indirection, so we only
        # need a forward decl, and don't need a full type declaration.
        if isinstance(self.fortype, UnionType) and self.fortype.mutuallyRecursiveWith(
            su
        ):
            self.forwardDeclStmts.append(_makeForwardDecl(su))
        else:
            self.fullDeclTypes.append(su)

        return defaultVisit(self, su)

    def visitStructType(self, t):
        return self.visitStructOrUnionType(t, TypeVisitor.visitStructType)

    def visitUnionType(self, t):
        return self.visitStructOrUnionType(t, TypeVisitor.visitUnionType)

    def visitArrayType(self, t):
        return TypeVisitor.visitArrayType(self, t)

    def visitMaybeType(self, m):
        return TypeVisitor.visitMaybeType(self, m)

    def visitShmemType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.maybeTypedef("mozilla::ipc::Shmem", "Shmem")

    def visitByteBufType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.maybeTypedef("mozilla::ipc::ByteBuf", "ByteBuf")

    def visitFDType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.maybeTypedef("mozilla::ipc::FileDescriptor", "FileDescriptor")

    def visitEndpointType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.maybeTypedef("mozilla::ipc::Endpoint", "Endpoint", ["FooSide"])
        self.visitActorType(s.actor)

    def visitManagedEndpointType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.maybeTypedef(
            "mozilla::ipc::ManagedEndpoint", "ManagedEndpoint", ["FooSide"]
        )
        self.visitActorType(s.actor)

    def visitUniquePtrType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)

    # The following types are not expected to be visited by this pass.
    def visitVoidType(self, v):
        assert 0

    def visitMessageType(self, v):
        assert 0

    def visitProtocolType(self, v):
        assert 0
2559
2560
def _fieldStaticAssertions(sd):
    """Return static_assert statements that check the packed layout of |sd|.

    Fields are taken in member order and grouped into consecutive runs that
    share the same pod_size().  Runs whose size is pod_size_sentinel, and
    runs with only one field, are skipped.  For every other run, one
    static_assert checks that the byte distance between the first and last
    member is ``size * (count - 1)``, i.e. that the run is contiguous.
    """
    template = """
            static_assert(
                (offsetof(${struct}, ${last}) - offsetof(${struct}, ${first})) == ${expected},
                "Bad assumptions about field layout!");
            """

    asserts = []
    runs = itertools.groupby(sd.fields_member_order(), lambda f: pod_size(f.ipdltype))
    for podsize, run in runs:
        if podsize == pod_size_sentinel:
            continue

        members = list(run)
        # A single field has nothing to compare against.
        if len(members) <= 1:
            continue

        asserts.append(
            StmtCode(
                template,
                struct=sd.name,
                first=members[0].memberVar(),
                last=members[-1].memberVar(),
                expected=ExprLiteral.Int(podsize * (len(members) - 1)),
            )
        )

    return asserts
2588
2589
def _generateCxxStruct(sd):
    """Generate the C++ class for the IPDL struct |sd|.

    The generated final class contains:
      - private typedefs for types its fields reference,
      - a default constructor that explicitly default-constructs every
        member, emitted only when the struct has fields,
      - a constructor taking every field in IPDL declaration order,
      - operator== / operator!= when |sd| has the "Comparable" attribute,
      - a mutable and a const getter for every field,
      - a private const StaticAssertions() method, when the field layout
        needs checking,
      - the data members, in member order.

    Returns ``(forwarddeclstmts, fulldecltypes, struct)``: forward
    declarations the class needs, IPDL struct/union types that must be fully
    declared before it, and the generated Class node.
    """
    # compute all the typedefs and forward decls we need to make
    gettypedeps = _ComputeTypeDeps(sd.decl.type)
    for f in sd.fields:
        f.ipdltype.accept(gettypedeps)

    usingTypedefs = gettypedeps.usingTypedefs
    forwarddeclstmts = gettypedeps.forwardDeclStmts
    fulldecltypes = gettypedeps.fullDeclTypes

    struct = Class(sd.name, final=True)
    struct.addstmts([Label.PRIVATE] + usingTypedefs + [Whitespace.NL, Label.PUBLIC])

    constreftype = Type(sd.name, const=True, ref=True)

    def fieldsAsParamList():
        # One constructor parameter per field, in IPDL declaration order.
        # Move-only field types use forceMoveType() instead of inType().
        # FIXME Bug 1547019 inType() should do the right thing once
        #                   _cxxTypeCanOnlyMove is replaced with
        #                   _cxxTypeNeedsMoveForSend
        return [
            Decl(
                f.forceMoveType() if _cxxTypeCanOnlyMove(f.ipdltype) else f.inType(),
                f.argVar().name,
            )
            for f in sd.fields_ipdl_order()
        ]

    # If this is an empty struct (no fields), then the default ctor
    # and "create-with-fields" ctors are equivalent.  So don't bother
    # with the default ctor.
    if sd.fields:
        assert len(sd.fields) == len(sd.packed_field_order)

        # Struct()
        defctor = ConstructorDefn(ConstructorDecl(sd.name, force_inline=True))

        # We want to explicitly default-construct every member of the struct.
        # This will initialize all primitives which wouldn't be initialized
        # normally to their default values, and will initialize any actor member
        # pointers to the correct default value of `nullptr`. Other C++ types
        # with custom constructors must also provide a default constructor.
        defctor.memberinits = [
            ExprMemberInit(f.memberVar()) for f in sd.fields_member_order()
        ]
        struct.addstmts([defctor, Whitespace.NL])

    # Struct(const field1& _f1, ...)
    valctor = ConstructorDefn(
        ConstructorDecl(sd.name, params=fieldsAsParamList(), force_inline=True)
    )
    valctor.memberinits = []
    for f in sd.fields_member_order():
        arg = f.argVar()
        # Move-only types must be moved from the constructor argument.
        if _cxxTypeCanOnlyMove(f.ipdltype):
            arg = ExprMove(arg)
        valctor.memberinits.append(ExprMemberInit(f.memberVar(), args=[arg]))

    struct.addstmts([valctor, Whitespace.NL])

    # The default copy, move, and assignment constructors, and the default
    # destructor, will do the right thing.

    if "Comparable" in sd.attributes:
        # bool operator==(const Struct& _o)
        ovar = ExprVar("_o")
        opeqeq = MethodDefn(
            MethodDecl(
                "operator==",
                params=[Decl(constreftype, ovar.name)],
                ret=Type.BOOL,
                const=True,
            )
        )
        # Compare field by field; the first mismatch returns false.
        for f in sd.fields_ipdl_order():
            ifneq = StmtIf(
                ExprNot(
                    ExprBinary(
                        ExprCall(f.getMethod()), "==", ExprCall(f.getMethod(ovar))
                    )
                )
            )
            ifneq.addifstmt(StmtReturn.FALSE)
            opeqeq.addstmt(ifneq)
        opeqeq.addstmt(StmtReturn.TRUE)
        struct.addstmts([opeqeq, Whitespace.NL])

        # bool operator!=(const Struct& _o)
        opneq = MethodDefn(
            MethodDecl(
                "operator!=",
                params=[Decl(constreftype, ovar.name)],
                ret=Type.BOOL,
                const=True,
            )
        )
        opneq.addstmt(StmtReturn(ExprNot(ExprCall(ExprVar("operator=="), args=[ovar]))))
        struct.addstmts([opneq, Whitespace.NL])

    # field1& f1()
    # const field1& f1() const
    for f in sd.fields_ipdl_order():
        get = MethodDefn(
            MethodDecl(
                f.getMethod().name, params=[], ret=f.refType(), force_inline=True
            )
        )
        get.addstmt(StmtReturn(f.refExpr()))

        # The const overload reuses the same declaration, adjusted.
        getconstdecl = deepcopy(get.decl)
        getconstdecl.ret = f.constRefType()
        getconstdecl.const = True
        getconst = MethodDefn(getconstdecl)
        getconst.addstmt(StmtReturn(f.constRefExpr()))

        struct.addstmts([get, getconst, Whitespace.NL])

    # private:
    struct.addstmt(Label.PRIVATE)

    # Static assertions to ensure our assumptions about field layout match
    # what the compiler is actually producing.  We define this as a member
    # function, rather than throwing the assertions in the constructor or
    # similar, because we don't want to evaluate the static assertions every
    # time the header file containing the structure is included.
    staticasserts = _fieldStaticAssertions(sd)
    if staticasserts:
        method = MethodDefn(
            MethodDecl("StaticAssertions", params=[], ret=Type.VOID, const=True)
        )
        method.addstmts(staticasserts)
        struct.addstmts([method])

    # members
    struct.addstmts(
        [
            StmtDecl(Decl(_effectiveMemberType(f), f.memberVar().name))
            for f in sd.fields_member_order()
        ]
    )

    return forwarddeclstmts, fulldecltypes, struct
2732
2733
2734def _effectiveMemberType(f):
2735    effective_type = f.bareType()
2736    # Structs must be copyable for backwards compatibility reasons, so we use
2737    # CopyableTArray<T> as their member type for arrays. This is not exposed
2738    # in the method signatures, these keep using nsTArray<T>, which is a base
2739    # class of CopyableTArray<T>.
2740    if effective_type.name == "nsTArray":
2741        effective_type.name = "CopyableTArray"
2742    return effective_type
2743
2744
2745# --------------------------------------------------
2746
2747
def _generateCxxUnion(ud):
    """Lower the IPDL union declaration |ud| to a C++ discriminated union.

    Returns a tuple |(forwarddeclstmts, fulldecltypes, cls)|: the forward
    declaration statements and fully-declared types that |_ComputeTypeDeps|
    collected from the union's components, and the generated |Class|.
    """
    # This Union class basically consists of a type (enum) and a
    # union for storage.  The union can contain POD and non-POD
    # types.  Each type needs a copy/move ctor, assignment operators,
    # and dtor.
    #
    # Rather than templating this class and only providing
    # specializations for the types we support, which is slightly
    # "unsafe" in that C++ code can add additional specializations
    # without the IPDL compiler's knowledge, we instead explicitly
    # implement non-templated methods for each supported type.
    #
    # The one complication that arises is that C++, for arcane
    # reasons, does not allow the placement destructor of a
    # builtin type, like int, to be directly invoked.  So we need
    # to hack around this by internally typedef'ing all
    # constituent types.  Sigh.
    #
    # So, for each type, this "Union" class needs:
    # (private)
    #  - entry in the type enum
    #  - entry in the storage union
    #  - [type]ptr() method to get a type* from the underlying union
    #  - same as above to get a const type*
    #  - typedef to hack around placement delete limitations
    # (public)
    #  - placement delete case for dtor
    #  - copy ctor
    #  - move ctor
    #  - case in generic copy ctor
    #  - copy operator= impl
    #  - move operator= impl
    #  - case in generic operator=
    #  - operator [type&]
    #  - operator [const type&] const
    #  - [type&] get_[type]()
    #  - [const type&] get_[type]() const
    #
    # (Copy constructors/assignments are omitted for move-only component
    # types; see the _cxxTypeCanOnlyMove checks below.)
    #
    cls = Class(ud.name, final=True)
    # const Union&, i.e., Union type with inparam semantics
    inClsType = Type(ud.name, const=True, ref=True)
    # Union& (returned by operator=) and Union&& (move ctor / move operator=)
    refClsType = Type(ud.name, ref=True)
    rvalueRefClsType = Type(ud.name, rvalref=True)
    # Names used in the generated C++: the tag enum, the storage union, the
    # tag/storage members, the private helpers and the sentinel enumerators.
    typetype = Type("Type")
    valuetype = Type("Value")
    mtypevar = ExprVar("mType")
    mvaluevar = ExprVar("mValue")
    maybedtorvar = ExprVar("MaybeDestroy")
    assertsanityvar = ExprVar("AssertSanity")
    tnonevar = ExprVar("T__None")
    tlastvar = ExprVar("T__Last")

    # Build |AssertSanity()| or |AssertSanity(expectTypeVar)|, called as
    # |uvar.AssertSanity(...)| when |uvar| is given, unqualified otherwise.
    def callAssertSanity(uvar=None, expectTypeVar=None):
        func = assertsanityvar
        args = []
        if uvar is not None:
            func = ExprSelect(uvar, ".", assertsanityvar.name)
        if expectTypeVar is not None:
            args.append(expectTypeVar)
        return ExprCall(func, args=args)

    # Build an unqualified |MaybeDestroy(newTypeVar)| call.
    def callMaybeDestroy(newTypeVar):
        return ExprCall(maybedtorvar, args=[newTypeVar])

    # Build |if (MaybeDestroy(newTypeVar)) { <memb.callCtor()>; }|: the
    # member's constructor only runs when MaybeDestroy reports that the
    # storage has to be (re)constructed.
    def maybeReconstruct(memb, newTypeVar):
        ifdied = StmtIf(callMaybeDestroy(newTypeVar))
        ifdied.addifstmt(StmtExpr(memb.callCtor()))
        return ifdied

    # Build |static_cast<void>(expr)|.
    def voidCast(expr):
        return ExprCast(expr, Type.VOID, static=True)

    # compute all the typedefs and forward decls we need to make
    gettypedeps = _ComputeTypeDeps(ud.decl.type)
    for c in ud.components:
        c.ipdltype.accept(gettypedeps)

    usingTypedefs = gettypedeps.usingTypedefs
    forwarddeclstmts = gettypedeps.forwardDeclStmts
    fulldecltypes = gettypedeps.fullDeclTypes

    # the |Type| enum, used to switch on the discunion's real type.
    # T__None is 0, the first component is 1, later components are added
    # without explicit values, and T__Last aliases the last component.
    # NOTE: components[0] and components[-1] are indexed, so
    # |ud.components| must be non-empty.
    cls.addstmt(Label.PUBLIC)
    typeenum = TypeEnum(typetype.name)
    typeenum.addId(tnonevar.name, 0)
    firstid = ud.components[0].enum()
    typeenum.addId(firstid, 1)
    for c in ud.components[1:]:
        typeenum.addId(c.enum())
    typeenum.addId(tlastvar.name, ud.components[-1].enum())
    cls.addstmts([StmtDecl(Decl(typeenum, "")), Whitespace.NL])

    cls.addstmt(Label.PRIVATE)
    cls.addstmts(
        usingTypedefs
        # hacky typedef's that allow placement dtors of builtins
        + [Typedef(c.internalType(), c.typedef()) for c in ud.components]
    )
    cls.addstmt(Whitespace.NL)

    # the C++ union the discunion uses for storage
    valueunion = TypeUnion(valuetype.name)
    for c in ud.components:
        valueunion.addComponent(c.unionType(), c.name)
    cls.addstmts([StmtDecl(Decl(valueunion, "")), Whitespace.NL])

    # for each constituent type T, add private accessors that
    # return a pointer to the Value union storage casted to |T*|
    # and |const T*|
    for c in ud.components:
        getptr = MethodDefn(
            MethodDecl(
                c.getPtrName(), params=[], ret=c.ptrToInternalType(), force_inline=True
            )
        )
        getptr.addstmt(StmtReturn(c.ptrToSelfExpr()))

        getptrconst = MethodDefn(
            MethodDecl(
                c.getConstPtrName(),
                params=[],
                ret=c.constPtrToType(),
                const=True,
                force_inline=True,
            )
        )
        getptrconst.addstmt(StmtReturn(c.constptrToSelfExpr()))

        cls.addstmts([getptr, getptrconst])
    cls.addstmt(Whitespace.NL)

    # add a helper method that invokes the placement dtor on the
    # current underlying value, only if |aNewType| is different
    # than the current type, and returns true if the underlying
    # value needs to be re-constructed
    newtypevar = ExprVar("aNewType")
    maybedtor = MethodDefn(
        MethodDecl(
            maybedtorvar.name, params=[Decl(typetype, newtypevar.name)], ret=Type.BOOL
        )
    )
    # wasn't /actually/ dtor'd, but it needs to be re-constructed
    ifnone = StmtIf(ExprBinary(mtypevar, "==", tnonevar))
    ifnone.addifstmt(StmtReturn.TRUE)
    # same type, nothing to see here
    ifnochange = StmtIf(ExprBinary(mtypevar, "==", newtypevar))
    ifnochange.addifstmt(StmtReturn.FALSE)
    # need to destroy.  switch on underlying type
    dtorswitch = StmtSwitch(mtypevar)
    for c in ud.components:
        dtorswitch.addcase(
            CaseLabel(c.enum()), StmtBlock([StmtExpr(c.callDtor()), StmtBreak()])
        )
    dtorswitch.addcase(
        DefaultLabel(), StmtBlock([_logicError("not reached"), StmtBreak()])
    )
    maybedtor.addstmts([ifnone, ifnochange, dtorswitch, StmtReturn.TRUE])
    cls.addstmts([maybedtor, Whitespace.NL])

    # add helper methods that ensure the discunion has a
    # valid type
    # AssertSanity(): abort unless T__None <= mType <= T__Last.
    sanity = MethodDefn(
        MethodDecl(assertsanityvar.name, ret=Type.VOID, const=True, force_inline=True)
    )
    sanity.addstmts(
        [
            _abortIfFalse(ExprBinary(tnonevar, "<=", mtypevar), "invalid type tag"),
            _abortIfFalse(ExprBinary(mtypevar, "<=", tlastvar), "invalid type tag"),
        ]
    )
    cls.addstmt(sanity)

    # AssertSanity(aType): run AssertSanity(), then abort unless
    # mType == aType.
    atypevar = ExprVar("aType")
    sanity2 = MethodDefn(
        MethodDecl(
            assertsanityvar.name,
            params=[Decl(typetype, atypevar.name)],
            ret=Type.VOID,
            const=True,
            force_inline=True,
        )
    )
    sanity2.addstmts(
        [
            StmtExpr(ExprCall(assertsanityvar)),
            _abortIfFalse(ExprBinary(mtypevar, "==", atypevar), "unexpected type tag"),
        ]
    )
    cls.addstmts([sanity2, Whitespace.NL])

    # ---- begin public methods -----

    # Union() default ctor
    cls.addstmts(
        [
            Label.PUBLIC,
            ConstructorDefn(
                ConstructorDecl(ud.name, force_inline=True),
                memberinits=[ExprMemberInit(mtypevar, [tnonevar])],
            ),
            Whitespace.NL,
        ]
    )

    # Union(const T&) copy & Union(T&&) move ctors
    othervar = ExprVar("aOther")
    for c in ud.components:
        # Move-only component types get no |const T&| constructor.
        if not _cxxTypeCanOnlyMove(c.ipdltype):
            copyctor = ConstructorDefn(
                ConstructorDecl(ud.name, params=[Decl(c.inType(), othervar.name)])
            )
            copyctor.addstmts(
                [
                    StmtExpr(c.callCtor(othervar)),
                    StmtExpr(ExprAssn(mtypevar, c.enumvar())),
                ]
            )
            cls.addstmts([copyctor, Whitespace.NL])

        # No separate |T&&| constructor when the type can't be moved or when
        # _cxxTypeNeedsMoveForSend holds.  NOTE(review): presumably the
        # c.inType() overload above already covers the latter — confirm.
        if not _cxxTypeCanMove(c.ipdltype) or _cxxTypeNeedsMoveForSend(c.ipdltype):
            continue
        movector = ConstructorDefn(
            ConstructorDecl(ud.name, params=[Decl(c.forceMoveType(), othervar.name)])
        )
        movector.addstmts(
            [
                StmtExpr(c.callCtor(ExprMove(othervar))),
                StmtExpr(ExprAssn(mtypevar, c.enumvar())),
            ]
        )
        cls.addstmts([movector, Whitespace.NL])

    # A single move-only component makes the whole union move-only: the
    # Union(const Union&) ctor and operator=(const Union&) are then skipped.
    unionNeedsMove = any(_cxxTypeCanOnlyMove(c.ipdltype) for c in ud.components)

    # Union(const Union&) copy ctor
    if not unionNeedsMove:
        copyctor = ConstructorDefn(
            ConstructorDecl(ud.name, params=[Decl(inClsType, othervar.name)])
        )
        othertype = ud.callType(othervar)
        copyswitch = StmtSwitch(othertype)
        for c in ud.components:
            copyswitch.addcase(
                CaseLabel(c.enum()),
                StmtBlock(
                    [
                        StmtExpr(
                            c.callCtor(
                                ExprCall(
                                    ExprSelect(othervar, ".", c.getConstTypeName())
                                )
                            )
                        ),
                        StmtBreak(),
                    ]
                ),
            )
        copyswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()]))
        # Every in-range tag has a case above and aOther.AssertSanity() runs
        # first, so the default label should be unreachable.
        copyswitch.addcase(
            DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()])
        )
        copyctor.addstmts(
            [
                StmtExpr(callAssertSanity(uvar=othervar)),
                copyswitch,
                StmtExpr(ExprAssn(mtypevar, othertype)),
            ]
        )
        cls.addstmts([copyctor, Whitespace.NL])

    # Union(Union&&) move ctor
    # Generated body: read aOther's tag into a local |t|, take over the
    # active member (move it, or steal the pointer for recursive
    # components), then set aOther.mType = T__None and mType = t.
    movector = ConstructorDefn(
        ConstructorDecl(ud.name, params=[Decl(rvalueRefClsType, othervar.name)])
    )
    othertypevar = ExprVar("t")
    moveswitch = StmtSwitch(othertypevar)
    for c in ud.components:
        case = StmtBlock()
        if c.recursive:
            # This is sound as we set othervar.mTypeVar to T__None after the
            # switch. The pointer in the union will be left dangling.
            case.addstmts(
                [
                    # ptr_C() = other.ptr_C()
                    StmtExpr(
                        ExprAssn(
                            c.callGetPtr(),
                            ExprCall(
                                ExprSelect(othervar, ".", ExprVar(c.getPtrName()))
                            ),
                        )
                    )
                ]
            )
        else:
            case.addstmts(
                [
                    # new ... (Move(other.get_C()))
                    StmtExpr(
                        c.callCtor(
                            ExprMove(
                                ExprCall(ExprSelect(othervar, ".", c.getTypeName()))
                            )
                        )
                    ),
                    # other.MaybeDestroy(T__None)
                    StmtExpr(
                        voidCast(
                            ExprCall(
                                ExprSelect(othervar, ".", maybedtorvar), args=[tnonevar]
                            )
                        )
                    ),
                ]
            )
        case.addstmts([StmtBreak()])
        moveswitch.addcase(CaseLabel(c.enum()), case)
    moveswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()]))
    moveswitch.addcase(
        DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()])
    )
    movector.addstmts(
        [
            StmtExpr(callAssertSanity(uvar=othervar)),
            StmtDecl(Decl(typetype, othertypevar.name), init=ud.callType(othervar)),
            moveswitch,
            StmtExpr(ExprAssn(ExprSelect(othervar, ".", mtypevar), tnonevar)),
            StmtExpr(ExprAssn(mtypevar, othertypevar)),
        ]
    )
    cls.addstmts([movector, Whitespace.NL])

    # ~Union()
    dtor = DestructorDefn(DestructorDecl(ud.name))
    # The void cast prevents Coverity from complaining about missing return
    # value checks.
    dtor.addstmt(StmtExpr(voidCast(callMaybeDestroy(tnonevar))))
    cls.addstmts([dtor, Whitespace.NL])

    # type()
    typemeth = MethodDefn(
        MethodDecl("type", ret=typetype, const=True, force_inline=True)
    )
    typemeth.addstmt(StmtReturn(mtypevar))
    cls.addstmts([typemeth, Whitespace.NL])

    # Union& operator= methods
    rhsvar = ExprVar("aRhs")
    for c in ud.components:
        # Same rules as the per-component constructors above.
        if not _cxxTypeCanOnlyMove(c.ipdltype):
            # Union& operator=(const T&)
            opeq = MethodDefn(
                MethodDecl(
                    "operator=", params=[Decl(c.inType(), rhsvar.name)], ret=refClsType
                )
            )
            opeq.addstmts(
                [
                    # might need to placement-delete old value first
                    maybeReconstruct(c, c.enumvar()),
                    StmtExpr(c.callOperatorEq(rhsvar)),
                    StmtExpr(ExprAssn(mtypevar, c.enumvar())),
                    StmtReturn(ExprDeref(ExprVar.THIS)),
                ]
            )
            cls.addstmts([opeq, Whitespace.NL])

        # Union& operator=(T&&)
        if not _cxxTypeCanMove(c.ipdltype) or _cxxTypeNeedsMoveForSend(c.ipdltype):
            continue

        opeq = MethodDefn(
            MethodDecl(
                "operator=",
                params=[Decl(c.forceMoveType(), rhsvar.name)],
                ret=refClsType,
            )
        )
        opeq.addstmts(
            [
                # might need to placement-delete old value first
                maybeReconstruct(c, c.enumvar()),
                StmtExpr(c.callOperatorEq(ExprMove(rhsvar))),
                StmtExpr(ExprAssn(mtypevar, c.enumvar())),
                StmtReturn(ExprDeref(ExprVar.THIS)),
            ]
        )
        cls.addstmts([opeq, Whitespace.NL])

    # Union& operator=(const Union&)
    if not unionNeedsMove:
        opeq = MethodDefn(
            MethodDecl(
                "operator=", params=[Decl(inClsType, rhsvar.name)], ret=refClsType
            )
        )
        rhstypevar = ExprVar("t")
        opeqswitch = StmtSwitch(rhstypevar)
        for c in ud.components:
            case = StmtBlock()
            case.addstmts(
                [
                    maybeReconstruct(c, rhstypevar),
                    StmtExpr(
                        c.callOperatorEq(
                            ExprCall(ExprSelect(rhsvar, ".", c.getConstTypeName()))
                        )
                    ),
                    StmtBreak(),
                ]
            )
            opeqswitch.addcase(CaseLabel(c.enum()), case)
        opeqswitch.addcase(
            CaseLabel(tnonevar.name),
            # The void cast prevents Coverity from complaining about missing return
            # value checks.
            StmtBlock(
                [
                    StmtExpr(
                        ExprCast(callMaybeDestroy(rhstypevar), Type.VOID, static=True)
                    ),
                    StmtBreak(),
                ]
            ),
        )
        opeqswitch.addcase(
            DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()])
        )
        opeq.addstmts(
            [
                StmtExpr(callAssertSanity(uvar=rhsvar)),
                StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)),
                opeqswitch,
                StmtExpr(ExprAssn(mtypevar, rhstypevar)),
                StmtReturn(ExprDeref(ExprVar.THIS)),
            ]
        )
        cls.addstmts([opeq, Whitespace.NL])

    # Union& operator=(Union&&)
    # Generated body: take over aRhs's active member (mirroring the move
    # ctor), then set aRhs.mType = T__None and mType = t.
    # NOTE(review): self-move-assignment (|u = std::move(u)|) is not guarded.
    # For a non-recursive member, MaybeDestroy(t) returns false, then
    # aRhs.MaybeDestroy(T__None) destroys the (shared) value and mType is
    # reset to t over destroyed storage.  Confirm no caller self-moves.
    opeq = MethodDefn(
        MethodDecl(
            "operator=", params=[Decl(rvalueRefClsType, rhsvar.name)], ret=refClsType
        )
    )
    rhstypevar = ExprVar("t")
    opeqswitch = StmtSwitch(rhstypevar)
    for c in ud.components:
        case = StmtBlock()
        if c.recursive:
            # Destroy our current value, then steal aRhs's pointer.
            case.addstmts(
                [
                    StmtExpr(voidCast(callMaybeDestroy(tnonevar))),
                    StmtExpr(
                        ExprAssn(
                            c.callGetPtr(),
                            ExprCall(ExprSelect(rhsvar, ".", ExprVar(c.getPtrName()))),
                        )
                    ),
                ]
            )
        else:
            case.addstmts(
                [
                    maybeReconstruct(c, rhstypevar),
                    StmtExpr(
                        c.callOperatorEq(
                            ExprMove(ExprCall(ExprSelect(rhsvar, ".", c.getTypeName())))
                        )
                    ),
                    # other.MaybeDestroy(T__None)
                    StmtExpr(
                        voidCast(
                            ExprCall(
                                ExprSelect(rhsvar, ".", maybedtorvar), args=[tnonevar]
                            )
                        )
                    ),
                ]
            )
        case.addstmts([StmtBreak()])
        opeqswitch.addcase(CaseLabel(c.enum()), case)
    opeqswitch.addcase(
        CaseLabel(tnonevar.name),
        # The void cast prevents Coverity from complaining about missing return
        # value checks.
        StmtBlock([StmtExpr(voidCast(callMaybeDestroy(rhstypevar))), StmtBreak()]),
    )
    opeqswitch.addcase(
        DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()])
    )
    opeq.addstmts(
        [
            StmtExpr(callAssertSanity(uvar=rhsvar)),
            StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)),
            opeqswitch,
            StmtExpr(ExprAssn(ExprSelect(rhsvar, ".", mtypevar), tnonevar)),
            StmtExpr(ExprAssn(mtypevar, rhstypevar)),
            StmtReturn(ExprDeref(ExprVar.THIS)),
        ]
    )
    cls.addstmts([opeq, Whitespace.NL])

    # Equality operators are only generated for unions that carry the
    # "Comparable" attribute.
    if "Comparable" in ud.attributes:
        # bool operator==(const T&)
        # NOTE(review): the generated body calls get_T(), which runs
        # AssertSanity(T__T); comparing against a T while the union holds a
        # different type therefore trips that assertion instead of
        # returning false.  Confirm this is intended.
        for c in ud.components:
            opeqeq = MethodDefn(
                MethodDecl(
                    "operator==",
                    params=[Decl(c.inType(), rhsvar.name)],
                    ret=Type.BOOL,
                    const=True,
                )
            )
            opeqeq.addstmt(
                StmtReturn(ExprBinary(ExprCall(ExprVar(c.getTypeName())), "==", rhsvar))
            )
            cls.addstmts([opeqeq, Whitespace.NL])

        # bool operator==(const Union&)
        # Different tags compare unequal; equal tags compare the members.
        opeqeq = MethodDefn(
            MethodDecl(
                "operator==",
                params=[Decl(inClsType, rhsvar.name)],
                ret=Type.BOOL,
                const=True,
            )
        )
        iftypesmismatch = StmtIf(ExprBinary(ud.callType(), "!=", ud.callType(rhsvar)))
        iftypesmismatch.addifstmt(StmtReturn.FALSE)
        opeqeq.addstmts([iftypesmismatch, Whitespace.NL])

        opeqeqswitch = StmtSwitch(ud.callType())
        for c in ud.components:
            case = StmtBlock()
            case.addstmt(
                StmtReturn(
                    ExprBinary(
                        ExprCall(ExprVar(c.getTypeName())),
                        "==",
                        ExprCall(ExprSelect(rhsvar, ".", c.getTypeName())),
                    )
                )
            )
            opeqeqswitch.addcase(CaseLabel(c.enum()), case)
        opeqeqswitch.addcase(
            DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn.FALSE])
        )
        opeqeq.addstmt(opeqeqswitch)

        cls.addstmts([opeqeq, Whitespace.NL])

    # accessors for each type: operator T&, operator const T&,
    # T& get(), const T& get()
    # Each getter first runs AssertSanity(<that component's enumerator>).
    for c in ud.components:
        getValueVar = ExprVar(c.getTypeName())
        getConstValueVar = ExprVar(c.getConstTypeName())

        getvalue = MethodDefn(
            MethodDecl(getValueVar.name, ret=c.refType(), force_inline=True)
        )
        getvalue.addstmts(
            [
                StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())),
                StmtReturn(ExprDeref(c.callGetPtr())),
            ]
        )

        getconstvalue = MethodDefn(
            MethodDecl(
                getConstValueVar.name,
                ret=c.constRefType(),
                const=True,
                force_inline=True,
            )
        )
        getconstvalue.addstmts(
            [
                StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())),
                StmtReturn(c.getConstValue()),
            ]
        )

        cls.addstmts([getvalue, getconstvalue])

        # Conversion operators simply forward to the getters above.
        optype = MethodDefn(MethodDecl("", typeop=c.refType(), force_inline=True))
        optype.addstmt(StmtReturn(ExprCall(getValueVar)))
        opconsttype = MethodDefn(
            MethodDecl("", const=True, typeop=c.constRefType(), force_inline=True)
        )
        opconsttype.addstmt(StmtReturn(ExprCall(getConstValueVar)))

        cls.addstmts([optype, opconsttype, Whitespace.NL])
    # private vars
    cls.addstmts(
        [
            Label.PRIVATE,
            StmtDecl(Decl(valuetype, mvaluevar.name)),
            StmtDecl(Decl(typetype, mtypevar.name)),
        ]
    )

    return forwarddeclstmts, fulldecltypes, cls
3351
3352
3353# -----------------------------------------------------------------------------
3354
3355
class _FindFriends(ipdl.ast.Visitor):
    """Compute which protocol types must be |friend|s of a given actor type.

    Starting from each top-level protocol of |mytype| and walking down the
    |manages| tree, every protocol other than |mytype| is visited.  A visited
    protocol is recorded as a friend when one of its message declarations
    has an actor of protocol |mytype| among its in- or out-parameters.
    """

    def __init__(self):
        self.mytype = None  # ProtocolType whose friends are being computed
        self.vtype = None  # ProtocolType whose AST is currently being visited
        self.friends = set()  # set<ProtocolType>

    def findFriends(self, ptype):
        """Return the set of ProtocolTypes that need to friend |ptype|."""
        self.mytype = ptype
        for root in ptype.toplevels():
            self.walkDownTheProtocolTree(root)
        return self.friends

    # TODO could make this into a _iterProtocolTreeHelper ...
    def walkDownTheProtocolTree(self, ptype):
        """Visit |ptype| and, recursively, every protocol it manages."""
        # |mytype| itself is never visited, so it cannot friend itself.
        if ptype != self.mytype:
            self.visit(ptype)
        # Skip self-management edges; following them would recurse forever.
        for managed in (m for m in ptype.manages if m is not ptype):
            self.walkDownTheProtocolTree(managed)

    def visit(self, ptype):
        """Walk |ptype|'s AST with |vtype| temporarily set to |ptype|."""
        outer = self.vtype
        self.vtype = ptype
        ptype._ast.accept(self)
        self.vtype = outer

    def visitMessageDecl(self, md):
        # The protocol being visited (|vtype|) has a message carrying an
        # actor of our protocol, so record it as a friend.
        carries_our_actor = any(
            actor.protocol == self.mytype for actor in self.iterActorParams(md)
        )
        if carries_our_actor:
            self.friends.add(self.vtype)

    def iterActorParams(self, md):
        """Yield each actor type found in |md|'s in-params, then out-params."""
        for params in (md.inParams, md.outParams):
            for param in params:
                yield from ipdl.type.iteractortypes(param.type)
3396
3397
3398class _GenerateProtocolActorCode(ipdl.ast.Visitor):
3399    def __init__(self, myside):
3400        self.side = myside  # "parent" or "child"
3401        self.prettyside = myside.title()
3402        self.clsname = None
3403        self.protocol = None
3404        self.hdrfile = None
3405        self.cppfile = None
3406        self.ns = None
3407        self.cls = None
3408        self.includedActorTypedefs = []
3409        self.protocolCxxIncludes = []
3410        self.actorForwardDecls = []
3411        self.usingDecls = []
3412        self.externalIncludes = set()
3413        self.nonForwardDeclaredHeaders = set()
3414
3415    def lower(self, tu, clsname, cxxHeaderFile, cxxFile):
3416        self.clsname = clsname
3417        self.hdrfile = cxxHeaderFile
3418        self.cppfile = cxxFile
3419        tu.accept(self)
3420
3421    def standardTypedefs(self):
3422        return [
3423            Typedef(Type("mozilla::ipc::IProtocol"), "IProtocol"),
3424            Typedef(Type("IPC::Message"), "Message"),
3425            Typedef(Type("base::ProcessHandle"), "ProcessHandle"),
3426            Typedef(Type("mozilla::ipc::MessageChannel"), "MessageChannel"),
3427            Typedef(Type("mozilla::ipc::SharedMemory"), "SharedMemory"),
3428        ]
3429
3430    def visitTranslationUnit(self, tu):
3431        self.protocol = tu.protocol
3432
3433        hf = self.hdrfile
3434        cf = self.cppfile
3435
3436        # make the C++ header
3437        hf.addthings(
3438            [_DISCLAIMER]
3439            + _includeGuardStart(hf)
3440            + [
3441                Whitespace.NL,
3442                CppDirective("include", '"' + _protocolHeaderName(tu.protocol) + '.h"'),
3443            ]
3444        )
3445
3446        for inc in tu.includes:
3447            inc.accept(self)
3448        for inc in tu.cxxIncludes:
3449            inc.accept(self)
3450
3451        for using in tu.using:
3452            using.accept(self)
3453
3454        # this generates the actor's full impl in self.cls
3455        tu.protocol.accept(self)
3456
3457        clsdecl, clsdefn = _splitClassDeclDefn(self.cls)
3458
3459        # XXX damn C++ ... return types in the method defn aren't in
3460        # class scope
3461        for stmt in clsdefn.stmts:
3462            if isinstance(stmt, MethodDefn):
3463                if stmt.decl.ret and stmt.decl.ret.name == "Result":
3464                    stmt.decl.ret.name = clsdecl.name + "::" + stmt.decl.ret.name
3465
3466        def setToIncludes(s):
3467            return [CppDirective("include", '"%s"' % i) for i in sorted(iter(s))]
3468
3469        def makeNamespace(p, file):
3470            if 0 == len(p.namespaces):
3471                return file
3472            ns = Namespace(p.namespaces[-1].name)
3473            outerns = _putInNamespaces(ns, p.namespaces[:-1])
3474            file.addthing(outerns)
3475            return ns
3476
3477        if len(self.nonForwardDeclaredHeaders) != 0:
3478            self.hdrfile.addthings(
3479                [
3480                    Whitespace("// Headers for things that cannot be forward declared"),
3481                    Whitespace.NL,
3482                ]
3483                + setToIncludes(self.nonForwardDeclaredHeaders)
3484                + [Whitespace.NL]
3485            )
3486        self.hdrfile.addthings(self.actorForwardDecls)
3487        self.hdrfile.addthings(self.usingDecls)
3488
3489        hdrns = makeNamespace(self.protocol, self.hdrfile)
3490        hdrns.addstmts(
3491            [Whitespace.NL, Whitespace.NL, clsdecl, Whitespace.NL, Whitespace.NL]
3492        )
3493
3494        actortype = ActorType(tu.protocol.decl.type)
3495        traitsdecl, traitsdefn = _ParamTraits.actorPickling(actortype, self.side)
3496
3497        self.hdrfile.addthings([traitsdecl, Whitespace.NL] + _includeGuardEnd(hf))
3498
3499        # make the .cpp file
3500        if (self.protocol.name, self.side) not in VIRTUAL_CALL_CLASSES:
3501            if (self.protocol.name, self.side) in DIRECT_CALL_OVERRIDES:
3502                (_, header_file) = DIRECT_CALL_OVERRIDES[self.protocol.name, self.side]
3503            else:
3504                assert self.protocol.name.startswith("P")
3505                header_file = "{}/{}{}.h".format(
3506                    "/".join(n.name for n in self.protocol.namespaces),
3507                    self.protocol.name[1:],
3508                    self.side.capitalize(),
3509                )
3510            self.externalIncludes.add(header_file)
3511
3512        cf.addthings(
3513            [
3514                _DISCLAIMER,
3515                Whitespace.NL,
3516                CppDirective(
3517                    "include",
3518                    '"' + _protocolHeaderName(self.protocol, self.side) + '.h"',
3519                ),
3520            ]
3521            + setToIncludes(self.externalIncludes)
3522        )
3523
3524        cf.addthings(
3525            (
3526                [Whitespace.NL]
3527                + [
3528                    CppDirective("include", '"%s.h"' % (inc))
3529                    for inc in self.protocolCxxIncludes
3530                ]
3531                + [Whitespace.NL]
3532                + [
3533                    CppDirective("include", '"%s"' % filename)
3534                    for filename in ipdl.builtin.CppIncludes
3535                ]
3536                + [Whitespace.NL]
3537            )
3538        )
3539
3540        cppns = makeNamespace(self.protocol, cf)
3541        cppns.addstmts(
3542            [Whitespace.NL, Whitespace.NL, clsdefn, Whitespace.NL, Whitespace.NL]
3543        )
3544
3545        cf.addthing(traitsdefn)
3546
3547    def visitUsingStmt(self, using):
3548        if using.header is None:
3549            return
3550
3551        if using.canBeForwardDeclared() and not using.decl.type.isUniquePtr():
3552            spec = using.type.spec
3553
3554            self.usingDecls.extend(
3555                [
3556                    _makeForwardDeclForQClass(
3557                        spec.baseid,
3558                        spec.quals,
3559                        cls=using.isClass(),
3560                        struct=using.isStruct(),
3561                    ),
3562                    Whitespace.NL,
3563                ]
3564            )
3565            self.externalIncludes.add(using.header)
3566        else:
3567            self.nonForwardDeclaredHeaders.add(using.header)
3568
3569    def visitCxxInclude(self, inc):
3570        self.externalIncludes.add(inc.file)
3571
3572    def visitInclude(self, inc):
3573        ip = inc.tu.protocol
3574        if not ip:
3575            return
3576
3577        self.actorForwardDecls.extend(
3578            [
3579                _makeForwardDeclForActor(ip.decl.type, self.side),
3580                _makeForwardDeclForActor(ip.decl.type, _otherSide(self.side)),
3581                Whitespace.NL,
3582            ]
3583        )
3584        self.protocolCxxIncludes.append(_protocolHeaderName(ip, self.side))
3585
3586        if ip.decl.fullname is not None:
3587            self.includedActorTypedefs.append(
3588                Typedef(
3589                    Type(_actorName(ip.decl.fullname, self.side.title())),
3590                    _actorName(ip.decl.shortname, self.side.title()),
3591                )
3592            )
3593
3594            self.includedActorTypedefs.append(
3595                Typedef(
3596                    Type(_actorName(ip.decl.fullname, _otherSide(self.side).title())),
3597                    _actorName(ip.decl.shortname, _otherSide(self.side).title()),
3598                )
3599            )
3600
    def visitProtocol(self, p):
        """Generate the abstract C++ actor base class for protocol |p|.

        Emits the common IPC includes into ``self.hdrfile``, then builds
        ``self.cls`` (a ``Class`` named ``self.clsname``). In order, it adds:
        - friend declarations and typedefs;
        - promise/resolver typedefs for messages with async returns;
        - Recv*() interface methods, plus Alloc*/Dealloc* pure virtuals for
          VIRTUAL_CALL_CLASSES;
        - the toplevel-only overrides;
        - the constructor and destructor;
        - the Manager() accessor (single-manager protocols only), the
          Managed*() accessors and AllManagedActors();
        - the managed-endpoint helpers;
        - the OnMessageReceived/OnCallReceived dispatchers;
        - the channel close/error handlers (toplevel only);
        - ClearSubtree() and the managed-actor containers.

        Side effects: sets ``self.protocol``, ``self.cls``, ``self.msgvar``,
        ``self.replyvar``, ``self.itervar``, ``self.var``,
        ``self.handlevar``, ``self.asyncSwitch``, ``self.syncSwitch`` and
        ``self.interruptSwitch``. The last two stay None unless the
        toplevel protocol is sync/interrupt. It also appends to
        ``self.actorForwardDecls`` and ``self.hdrfile``.
        """
        # Includes emitted into every generated actor header.
        self.hdrfile.addcode(
            """
            #ifdef DEBUG
            #include "prenv.h"
            #endif  // DEBUG

            #include "mozilla/Tainting.h"
            #include "mozilla/ipc/MessageChannel.h"
            #include "mozilla/ipc/ProtocolUtils.h"
            """
        )

        self.protocol = p
        ptype = p.decl.type
        toplevel = p.decl.type.toplevel()

        # Promise/resolver typedefs are only emitted when at least one
        # message in this protocol has async returns.
        hasAsyncReturns = False
        for md in p.messageDecls:
            if md.hasAsyncReturns():
                hasAsyncReturns = True
                break

        # Base class: toplevel actors use openedProtocolInterfaceType(),
        # managed actors use managerInterfaceType().
        inherits = []
        if ptype.isToplevel():
            inherits.append(Inherit(p.openedProtocolInterfaceType(), viz="public"))
        else:
            inherits.append(Inherit(p.managerInterfaceType(), viz="public"))

        # Toplevel parent actors get a forward declaration of nsIFile.
        if ptype.isToplevel() and self.side == "parent":
            self.hdrfile.addthings(
                [_makeForwardDeclForQClass("nsIFile", []), Whitespace.NL]
            )

        self.cls = Class(self.clsname, inherits=inherits, abstract=True)

        self.cls.addstmt(Label.PRIVATE)
        friends = _FindFriends().findFriends(ptype)
        if ptype.isManaged():
            friends.update(ptype.managers)

        # |friend| managed actors so that they can call our Dealloc*()
        friends.update(ptype.manages)

        # don't friend ourself if we're a self-managed protocol
        friends.discard(ptype)

        # Sorted by full name so the generated output is deterministic.
        for friend in sorted(friends, key=lambda f: f.fullname()):
            self.actorForwardDecls.extend(
                [_makeForwardDeclForActor(friend, self.prettyside), Whitespace.NL]
            )
            self.cls.addstmt(
                FriendClassDecl(_actorName(friend.fullname(), self.prettyside))
            )

        self.cls.addstmt(Label.PROTECTED)
        for typedef in p.cxxTypedefs():
            self.cls.addstmt(typedef)
        for typedef in self.includedActorTypedefs:
            self.cls.addstmt(typedef)

        self.cls.addstmt(Whitespace.NL)

        # Public promise typedefs for messages we send and resolver
        # typedefs for messages we receive.
        if hasAsyncReturns:
            self.cls.addstmt(Label.PUBLIC)
            for md in p.messageDecls:
                if self.sendsMessage(md) and md.hasAsyncReturns():
                    self.cls.addstmt(
                        Typedef(_makePromise(md.returns, self.side), md.promiseName())
                    )
                if self.receivesMessage(md) and md.hasAsyncReturns():
                    self.cls.addstmt(
                        Typedef(_makeResolver(md.returns, self.side), md.resolverName())
                    )
            self.cls.addstmt(Whitespace.NL)

        self.cls.addstmt(Label.PROTECTED)
        # interface methods that the concrete subclass has to impl
        for md in p.messageDecls:
            isctor, isdtor = md.decl.type.isCtor(), md.decl.type.isDtor()

            if self.receivesMessage(md):
                # generate Recv/Answer* interface
                implicit = not isdtor
                returnsems = "resolver" if md.decl.type.isAsync() else "out"
                recvDecl = MethodDecl(
                    md.recvMethod(),
                    params=md.makeCxxParams(
                        paramsems="move",
                        returnsems=returnsems,
                        side=self.side,
                        implicit=implicit,
                        direction="recv",
                    ),
                    ret=Type("mozilla::ipc::IPCResult"),
                    methodspec=MethodSpec.VIRTUAL,
                )

                # These method implementations cause problems when trying to
                # override them with different types in a direct call class.
                #
                # For the `isdtor` case there's a simple solution: it doesn't
                # make much sense to specify arguments and then completely
                # ignore them, and the no-arg case isn't a problem for
                # overriding.
                if isctor or (isdtor and not md.inParams):
                    defaultRecv = MethodDefn(recvDecl)
                    defaultRecv.addcode("return IPC_OK();\n")
                    self.cls.addstmt(defaultRecv)
                elif (self.protocol.name, self.side) in VIRTUAL_CALL_CLASSES:
                    # If we're using virtual calls, we need the methods to be
                    # declared on the base class.
                    recvDecl.methodspec = MethodSpec.PURE
                    self.cls.addstmt(StmtDecl(recvDecl))
                # Otherwise no Recv* declaration is emitted here.

        # If we're using virtual calls, we need the methods to be declared on
        # the base class.
        if (self.protocol.name, self.side) in VIRTUAL_CALL_CLASSES:
            for md in p.messageDecls:
                managed = md.decl.type.constructedType()
                if not ptype.isManagerOf(managed) or md.decl.type.isDtor():
                    continue

                # add the Alloc interface for managed actors
                actortype = md.actorDecl().bareType(self.side)

                if managed.isRefcounted():
                    # Refcounted managees only get an Alloc method on the
                    # receiving side, and it returns already_AddRefed<T>.
                    if not self.receivesMessage(md):
                        continue

                    actortype.ptr = False
                    actortype = _alreadyaddrefed(actortype)

                self.cls.addstmt(
                    StmtDecl(
                        MethodDecl(
                            _allocMethod(managed, self.side),
                            params=md.makeCxxParams(
                                side=self.side, implicit=False, direction="recv"
                            ),
                            ret=actortype,
                            methodspec=MethodSpec.PURE,
                        )
                    )
                )

            # add the Dealloc interface for all managed non-refcounted actors,
            # even without ctors. This is useful for protocols which use
            # ManagedEndpoint for construction.
            for managed in ptype.manages:
                if managed.isRefcounted():
                    continue

                self.cls.addstmt(
                    StmtDecl(
                        MethodDecl(
                            _deallocMethod(managed, self.side),
                            params=[
                                Decl(p.managedCxxType(managed, self.side), "aActor")
                            ],
                            ret=Type.BOOL,
                            methodspec=MethodSpec.PURE,
                        )
                    )
                )

        if ptype.isToplevel():
            # void ProcessingError(code); default to no-op
            processingerror = MethodDefn(
                MethodDecl(
                    p.processingErrorVar().name,
                    params=[
                        Param(_Result.Type(), "aCode"),
                        Param(Type("char", const=True, ptr=True), "aReason"),
                    ],
                    methodspec=MethodSpec.OVERRIDE,
                )
            )

            # bool ShouldContinueFromReplyTimeout(); default to |true|
            shouldcontinue = MethodDefn(
                MethodDecl(
                    p.shouldContinueFromTimeoutVar().name,
                    ret=Type.BOOL,
                    methodspec=MethodSpec.OVERRIDE,
                )
            )
            shouldcontinue.addcode("return true;\n")

            # void Entered*()/Exited*(); default to no-op
            entered = MethodDefn(
                MethodDecl(p.enteredCxxStackVar().name, methodspec=MethodSpec.OVERRIDE)
            )
            exited = MethodDefn(
                MethodDecl(p.exitedCxxStackVar().name, methodspec=MethodSpec.OVERRIDE)
            )
            enteredcall = MethodDefn(
                MethodDecl(p.enteredCallVar().name, methodspec=MethodSpec.OVERRIDE)
            )
            exitedcall = MethodDefn(
                MethodDecl(p.exitedCallVar().name, methodspec=MethodSpec.OVERRIDE)
            )

            self.cls.addstmts(
                [
                    processingerror,
                    shouldcontinue,
                    entered,
                    exited,
                    enteredcall,
                    exitedcall,
                    Whitespace.NL,
                ]
            )

        self.cls.addstmts(([Label.PUBLIC] + self.standardTypedefs() + [Whitespace.NL]))

        self.cls.addstmt(Label.PUBLIC)
        # Actor()
        # Toplevel actors initialize IToplevelProtocol with (name, protocol id,
        # side); managed actors initialize IProtocol with (protocol id, side).
        ctor = ConstructorDefn(ConstructorDecl(self.clsname))
        side = ExprVar("mozilla::ipc::" + self.side.title() + "Side")
        if ptype.isToplevel():
            name = ExprLiteral.String(_actorName(p.name, self.side))
            ctor.memberinits = [
                ExprMemberInit(
                    ExprVar("mozilla::ipc::IToplevelProtocol"),
                    [name, _protocolId(ptype), side],
                )
            ]
        else:
            ctor.memberinits = [
                ExprMemberInit(
                    ExprVar("mozilla::ipc::IProtocol"), [_protocolId(ptype), side]
                )
            ]

        ctor.addcode("MOZ_COUNT_CTOR(${clsname});\n", clsname=self.clsname)
        self.cls.addstmts([ctor, Whitespace.NL])

        # ~Actor()
        dtor = DestructorDefn(
            DestructorDecl(self.clsname, methodspec=MethodSpec.VIRTUAL)
        )
        dtor.addcode("MOZ_COUNT_DTOR(${clsname});\n", clsname=self.clsname)

        self.cls.addstmts([dtor, Whitespace.NL])

        # Refcounted actors tie ActorAlloc/ActorDealloc to AddRef/Release.
        if ptype.isRefcounted():
            self.cls.addcode(
                """
                NS_INLINE_DECL_PURE_VIRTUAL_REFCOUNTING
                """
            )
            self.cls.addstmt(Label.PROTECTED)
            self.cls.addcode(
                """
                void ActorAlloc() final { AddRef(); }
                void ActorDealloc() final { Release(); }
                """
            )

        self.cls.addstmt(Label.PUBLIC)
        # A typed Manager() accessor is only generated when the protocol has
        # exactly one manager.
        if not ptype.isToplevel():
            if 1 == len(p.managers):
                # manager() const
                managertype = p.managerActorType(self.side, ptr=True)
                managermeth = MethodDefn(
                    MethodDecl("Manager", ret=managertype, const=True)
                )
                managermeth.addcode(
                    """
                    return static_cast<${type}>(IProtocol::Manager());
                    """,
                    type=managertype,
                )

                self.cls.addstmts([managermeth, Whitespace.NL])

        # NOTE(review): actorFromIter and forLoopOverHashtable are not called
        # anywhere in this method as visible here.
        def actorFromIter(itervar):
            return ExprCode("${iter}.Get()->GetKey()", iter=itervar)

        def forLoopOverHashtable(hashtable, itervar, const=False):
            itermeth = "ConstIter" if const else "Iter"
            return StmtFor(
                init=ExprCode(
                    "auto ${itervar} = ${hashtable}.${itermeth}()",
                    itervar=itervar,
                    hashtable=hashtable,
                    itermeth=itermeth,
                ),
                cond=ExprCode("!${itervar}.Done()", itervar=itervar),
                update=ExprCode("${itervar}.Next()", itervar=itervar),
            )

        # Managed[T](Array& inout) const
        # const Array<T>& Managed() const
        for managed in ptype.manages:
            container = p.managedVar(managed, self.side)

            meth = MethodDefn(
                MethodDecl(
                    p.managedMethod(managed, self.side).name,
                    params=[
                        Decl(
                            _cxxArrayType(
                                p.managedCxxType(managed, self.side), ref=True
                            ),
                            "aArr",
                        )
                    ],
                    const=True,
                )
            )
            meth.addcode("${container}.ToArray(aArr);\n", container=container)

            refmeth = MethodDefn(
                MethodDecl(
                    p.managedMethod(managed, self.side).name,
                    params=[],
                    ret=p.managedVarType(managed, self.side, const=True, ref=True),
                    const=True,
                )
            )
            refmeth.addcode("return ${container};\n", container=container)

            self.cls.addstmts([meth, refmeth, Whitespace.NL])

        # AllManagedActors(Array& inout) const
        arrvar = ExprVar("arr__")
        managedmeth = MethodDefn(
            MethodDecl(
                "AllManagedActors",
                params=[
                    Decl(
                        _cxxArrayType(_refptr(_cxxLifecycleProxyType()), ref=True),
                        arrvar.name,
                    )
                ],
                methodspec=MethodSpec.OVERRIDE,
                const=True,
            )
        )

        # Count the number of managed actors, and allocate space in the output array.
        managedmeth.addcode(
            """
            uint32_t total = 0;
            """
        )
        for managed in ptype.manages:
            managedmeth.addcode(
                """
                total += ${container}.Count();
                """,
                container=p.managedVar(managed, self.side),
            )
        managedmeth.addcode(
            """
            arr__.SetCapacity(total);

            """
        )

        # Then append each managee's lifecycle proxy, container by container.
        for managed in ptype.manages:
            managedmeth.addcode(
                """
                for (auto* key : ${container}) {
                    arr__.AppendElement(key->GetLifecycleProxy());
                }

                """,
                container=p.managedVar(managed, self.side),
            )

        self.cls.addstmts([managedmeth, Whitespace.NL])

        # OpenPEndpoint(...)/BindPEndpoint(...)
        for managed in ptype.manages:
            self.genManagedEndpoint(managed)

        # OnMessageReceived()/OnCallReceived()

        # save these away for use in message handler case stmts
        msgvar = ExprVar("msg__")
        self.msgvar = msgvar
        replyvar = ExprVar("reply__")
        self.replyvar = replyvar
        itervar = ExprVar("iter__")
        self.itervar = itervar
        var = ExprVar("v__")
        self.var = var
        # for ctor recv cases, we can't read the actor ID into a PFoo*
        # because it doesn't exist on this side yet.  Use a "special"
        # actor handle instead
        handlevar = ExprVar("handle__")
        self.handlevar = handlevar

        # One switch per dispatch kind. The sync switch only exists for
        # sync/interrupt toplevels; the interrupt switch only for interrupt ones.
        msgtype = ExprCode("msg__.type()")
        self.asyncSwitch = StmtSwitch(msgtype)
        self.syncSwitch = None
        self.interruptSwitch = None
        if toplevel.isSync() or toplevel.isInterrupt():
            self.syncSwitch = StmtSwitch(msgtype)
            if toplevel.isInterrupt():
                self.interruptSwitch = StmtSwitch(msgtype)

        # implement Send*() methods and add dispatcher cases to
        # message switch()es
        for md in p.messageDecls:
            self.visitMessageDecl(md)

        # add default cases
        default = StmtCode(
            """
            return MsgNotKnown;
            """
        )
        self.asyncSwitch.addcase(DefaultLabel(), default)
        if toplevel.isSync() or toplevel.isInterrupt():
            self.syncSwitch.addcase(DefaultLabel(), default)
            if toplevel.isInterrupt():
                self.interruptSwitch.addcase(DefaultLabel(), default)

        # implementManagerIface() may also add SHMEM_* cases to asyncSwitch,
        # so it must run before the switches are emitted below.
        self.cls.addstmts(self.implementManagerIface())

        def makeHandlerMethod(name, switch, hasReply, dispatches=False):
            """Build one OnMessageReceived/OnCallReceived override.

            |switch| may be None, in which case the method asserts
            unreachable and returns MsgNotKnown. When |dispatches| is true,
            messages whose routing id is not MSG_ROUTING_CONTROL are
            forwarded to the actor found by Lookup().
            """
            params = [Decl(Type("Message", const=True, ref=True), msgvar.name)]
            if hasReply:
                params.append(Decl(Type("Message", ref=True, ptr=True), replyvar.name))

            method = MethodDefn(
                MethodDecl(
                    name,
                    methodspec=MethodSpec.OVERRIDE,
                    params=params,
                    ret=_Result.Type(),
                )
            )

            if not switch:
                method.addcode(
                    """
                    MOZ_ASSERT_UNREACHABLE("message protocol not supported");
                    return MsgNotKnown;
                    """
                )
                return method

            if dispatches:
                # A message for a dead actor is a route error when a reply is
                # expected; otherwise it is logged and treated as processed.
                if hasReply:
                    ondeadactor = [StmtReturn(_Result.RouteError)]
                else:
                    ondeadactor = [
                        self.logMessage(
                            None, ExprAddrOf(msgvar), "Ignored message for dead actor"
                        ),
                        StmtReturn(_Result.Processed),
                    ]

                # NOTE: |p| in the comprehension below is each Decl in
                # |params|, not the protocol; comprehension scope keeps the
                # outer |p| untouched.
                method.addcode(
                    """
                    int32_t route__ = ${msgvar}.routing_id();
                    if (MSG_ROUTING_CONTROL != route__) {
                        IProtocol* routed__ = Lookup(route__);
                        if (!routed__ || !routed__->GetLifecycleProxy()) {
                            $*{ondeadactor}
                        }

                        RefPtr<mozilla::ipc::ActorLifecycleProxy> proxy__ =
                            routed__->GetLifecycleProxy();
                        return proxy__->Get()->${name}($,{args});
                    }

                    """,
                    msgvar=msgvar,
                    ondeadactor=ondeadactor,
                    name=name,
                    args=[p.name for p in params],
                )

            # bug 509581: don't generate the switch stmt if there
            # is only the default case; MSVC doesn't like that
            if switch.nr_cases > 1:
                method.addstmt(switch)
            else:
                method.addstmt(StmtReturn(_Result.NotKnown))

            return method

        # Only toplevel managers route messages to managed actors.
        dispatches = ptype.isToplevel() and ptype.isManager()
        self.cls.addstmts(
            [
                makeHandlerMethod(
                    "OnMessageReceived",
                    self.asyncSwitch,
                    hasReply=False,
                    dispatches=dispatches,
                ),
                Whitespace.NL,
            ]
        )
        self.cls.addstmts(
            [
                makeHandlerMethod(
                    "OnMessageReceived",
                    self.syncSwitch,
                    hasReply=True,
                    dispatches=dispatches,
                ),
                Whitespace.NL,
            ]
        )
        self.cls.addstmts(
            [
                makeHandlerMethod(
                    "OnCallReceived",
                    self.interruptSwitch,
                    hasReply=True,
                    dispatches=dispatches,
                ),
                Whitespace.NL,
            ]
        )

        clearsubtreevar = ExprVar("ClearSubtree")

        if ptype.isToplevel():
            # OnChannelClose()
            onclose = MethodDefn(
                MethodDecl("OnChannelClose", methodspec=MethodSpec.OVERRIDE)
            )
            onclose.addcode(
                """
                DestroySubtree(NormalShutdown);
                ClearSubtree();
                DeallocShmems();
                if (GetLifecycleProxy()) {
                    GetLifecycleProxy()->Release();
                }
                """
            )
            self.cls.addstmts([onclose, Whitespace.NL])

            # OnChannelError()
            # Same teardown as OnChannelClose but with AbnormalShutdown.
            onerror = MethodDefn(
                MethodDecl("OnChannelError", methodspec=MethodSpec.OVERRIDE)
            )
            onerror.addcode(
                """
                DestroySubtree(AbnormalShutdown);
                ClearSubtree();
                DeallocShmems();
                if (GetLifecycleProxy()) {
                    GetLifecycleProxy()->Release();
                }
                """
            )
            self.cls.addstmts([onerror, Whitespace.NL])

        if ptype.isToplevel() and ptype.isInterrupt():
            processnative = MethodDefn(
                MethodDecl("ProcessNativeEventsInInterruptCall", ret=Type.VOID)
            )
            processnative.addcode(
                """
                #ifdef OS_WIN
                GetIPCChannel()->ProcessNativeEventsInInterruptCall();
                #else
                FatalError("This method is Windows-only");
                #endif
                """
            )

            self.cls.addstmts([processnative, Whitespace.NL])

        # private methods
        self.cls.addstmt(Label.PRIVATE)

        # ClearSubtree()
        # For each managed container: recurse into the kids, release the
        # lifecycle proxy each kid holds, then empty the container.
        clearsubtree = MethodDefn(MethodDecl(clearsubtreevar.name))
        for managed in ptype.manages:
            clearsubtree.addcode(
                """
                for (auto* key : ${container}) {
                    key->ClearSubtree();
                }
                for (auto* key : ${container}) {
                    // Recursively releasing ${container} kids.
                    auto* proxy = key->GetLifecycleProxy();
                    NS_IF_RELEASE(proxy);
                }
                ${container}.Clear();

                """,
                container=p.managedVar(managed, self.side),
            )

        # don't release our own IPC reference: either the manager will do it,
        # or we're toplevel
        self.cls.addstmts([clearsubtree, Whitespace.NL])

        # Member declarations for the per-managee containers.
        for managed in ptype.manages:
            self.cls.addstmts(
                [
                    StmtDecl(
                        Decl(
                            p.managedVarType(managed, self.side),
                            p.managedVar(managed, self.side).name,
                        )
                    )
                ]
            )
4213
4214    def genManagedEndpoint(self, managed):
4215        hereEp = "ManagedEndpoint<%s>" % _actorName(managed.name(), self.side)
4216        thereEp = "ManagedEndpoint<%s>" % _actorName(
4217            managed.name(), _otherSide(self.side)
4218        )
4219
4220        actor = _HybridDecl(ipdl.type.ActorType(managed), "aActor")
4221
4222        # ManagedEndpoint<PThere> OpenPEndpoint(PHere* aActor)
4223        openmeth = MethodDefn(
4224            MethodDecl(
4225                "Open%sEndpoint" % managed.name(),
4226                params=[
4227                    Decl(self.protocol.managedCxxType(managed, self.side), actor.name)
4228                ],
4229                ret=Type(thereEp),
4230            )
4231        )
4232        openmeth.addcode(
4233            """
4234            $*{bind}
4235            return ${thereEp}(mozilla::ipc::PrivateIPDLInterface(), aActor->Id());
4236            """,
4237            bind=self.bindManagedActor(actor, errfn=ExprCall(ExprVar(thereEp))),
4238            thereEp=thereEp,
4239        )
4240
4241        # void BindPEndpoint(ManagedEndpoint<PHere>&& aEndpoint, PHere* aActor)
4242        bindmeth = MethodDefn(
4243            MethodDecl(
4244                "Bind%sEndpoint" % managed.name(),
4245                params=[
4246                    Decl(Type(hereEp), "aEndpoint"),
4247                    Decl(self.protocol.managedCxxType(managed, self.side), actor.name),
4248                ],
4249                ret=Type.BOOL,
4250            )
4251        )
4252        bindmeth.addcode(
4253            """
4254            MOZ_RELEASE_ASSERT(aEndpoint.ActorId(), "Invalid Endpoint!");
4255            $*{bind}
4256            return true;
4257            """,
4258            bind=self.bindManagedActor(
4259                actor, errfn=ExprLiteral.FALSE, idexpr=ExprCode("*aEndpoint.ActorId()")
4260            ),
4261        )
4262
4263        self.cls.addstmts([openmeth, bindmeth, Whitespace.NL])
4264
4265    def implementManagerIface(self):
4266        p = self.protocol
4267        protocolbase = Type("IProtocol", ptr=True)
4268
4269        methods = []
4270
4271        if p.decl.type.isToplevel():
4272
4273            # "private" message that passes shmem mappings from one process
4274            # to the other
4275            if p.subtreeUsesShmem():
4276                self.asyncSwitch.addcase(
4277                    CaseLabel("SHMEM_CREATED_MESSAGE_TYPE"),
4278                    self.genShmemCreatedHandler(),
4279                )
4280                self.asyncSwitch.addcase(
4281                    CaseLabel("SHMEM_DESTROYED_MESSAGE_TYPE"),
4282                    self.genShmemDestroyedHandler(),
4283                )
4284            else:
4285                abort = StmtBlock()
4286                abort.addstmts(
4287                    [
4288                        _fatalError("this protocol tree does not use shmem"),
4289                        StmtReturn(_Result.NotKnown),
4290                    ]
4291                )
4292                self.asyncSwitch.addcase(CaseLabel("SHMEM_CREATED_MESSAGE_TYPE"), abort)
4293                self.asyncSwitch.addcase(
4294                    CaseLabel("SHMEM_DESTROYED_MESSAGE_TYPE"), abort
4295                )
4296
4297        # Keep track of types created with an INOUT ctor. We need to call
4298        # Register() or RegisterID() for them depending on the side the managee
4299        # is created.
4300        inoutCtorTypes = []
4301        for msg in p.messageDecls:
4302            msgtype = msg.decl.type
4303            if msgtype.isCtor() and msgtype.isInout():
4304                inoutCtorTypes.append(msgtype.constructedType())
4305
4306        # all protocols share the "same" RemoveManagee() implementation
4307        pvar = ExprVar("aProtocolId")
4308        listenervar = ExprVar("aListener")
4309        removemanagee = MethodDefn(
4310            MethodDecl(
4311                p.removeManageeMethod().name,
4312                params=[
4313                    Decl(_protocolIdType(), pvar.name),
4314                    Decl(protocolbase, listenervar.name),
4315                ],
4316                methodspec=MethodSpec.OVERRIDE,
4317            )
4318        )
4319
4320        if not len(p.managesStmts):
4321            removemanagee.addcode(
4322                """
4323                FatalError("unreached");
4324                return;
4325                """
4326            )
4327        else:
4328            switchontype = StmtSwitch(pvar)
4329            for managee in p.managesStmts:
4330                manageeipdltype = managee.decl.type
4331                manageecxxtype = _cxxBareType(
4332                    ipdl.type.ActorType(manageeipdltype), self.side
4333                )
4334                case = ExprCode(
4335                    """
4336                    {
4337                        ${manageecxxtype} actor = static_cast<${manageecxxtype}>(aListener);
4338
4339                        const bool removed = ${container}.EnsureRemoved(actor);
4340                        MOZ_RELEASE_ASSERT(removed, "actor not managed by this!");
4341
4342                        auto* proxy = actor->GetLifecycleProxy();
4343                        NS_IF_RELEASE(proxy);
4344                        return;
4345                    }
4346                    """,
4347                    manageecxxtype=manageecxxtype,
4348                    container=p.managedVar(manageeipdltype, self.side),
4349                )
4350                switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), case)
4351            switchontype.addcase(
4352                DefaultLabel(),
4353                ExprCode(
4354                    """
4355                FatalError("unreached");
4356                return;
4357                """
4358                ),
4359            )
4360            removemanagee.addstmt(switchontype)
4361
4362        # The `DeallocManagee` method is called for managed actors to trigger
4363        # deallocation when ActorLifecycleProxy is freed.
4364        deallocmanagee = MethodDefn(
4365            MethodDecl(
4366                p.deallocManageeMethod().name,
4367                params=[
4368                    Decl(_protocolIdType(), pvar.name),
4369                    Decl(protocolbase, listenervar.name),
4370                ],
4371                methodspec=MethodSpec.OVERRIDE,
4372            )
4373        )
4374
4375        if not len(p.managesStmts):
4376            deallocmanagee.addcode(
4377                """
4378                FatalError("unreached");
4379                return;
4380                """
4381            )
4382        else:
4383            switchontype = StmtSwitch(pvar)
4384            for managee in p.managesStmts:
4385                manageeipdltype = managee.decl.type
4386                # Reference counted actor types don't have corresponding
4387                # `Dealloc` methods, as they are deallocated by releasing the
4388                # IPDL-held reference.
4389                if manageeipdltype.isRefcounted():
4390                    continue
4391
4392                case = StmtCode(
4393                    """
4394                    ${concrete}->${dealloc}(static_cast<${type}>(aListener));
4395                    return;
4396                    """,
4397                    concrete=self.concreteThis(),
4398                    dealloc=_deallocMethod(manageeipdltype, self.side),
4399                    type=_cxxBareType(ipdl.type.ActorType(manageeipdltype), self.side),
4400                )
4401                switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), case)
4402            switchontype.addcase(
4403                DefaultLabel(),
4404                StmtCode(
4405                    """
4406                FatalError("unreached");
4407                return;
4408                """
4409                ),
4410            )
4411            deallocmanagee.addstmt(switchontype)
4412
4413        return methods + [removemanagee, deallocmanagee, Whitespace.NL]
4414
4415    def genShmemCreatedHandler(self):
4416        assert self.protocol.decl.type.isToplevel()
4417
4418        return StmtCode(
4419            """
4420            {
4421                if (!ShmemCreated(${msgvar})) {
4422                    return MsgPayloadError;
4423                }
4424                return MsgProcessed;
4425            }
4426            """,
4427            msgvar=self.msgvar,
4428        )
4429
4430    def genShmemDestroyedHandler(self):
4431        assert self.protocol.decl.type.isToplevel()
4432
4433        return StmtCode(
4434            """
4435            {
4436                if (!ShmemDestroyed(${msgvar})) {
4437                    return MsgPayloadError;
4438                }
4439                return MsgProcessed;
4440            }
4441            """,
4442            msgvar=self.msgvar,
4443        )
4444
4445    # -------------------------------------------------------------------------
4446    # The next few functions are the crux of the IPDL code generator.
4447    # They generate code for all the nasty work of message
4448    # serialization/deserialization and dispatching handlers for
4449    # received messages.
4450    ##
4451
4452    def concreteThis(self):
4453        if (self.protocol.name, self.side) in VIRTUAL_CALL_CLASSES:
4454            return ExprVar.THIS
4455
4456        if (self.protocol.name, self.side) in DIRECT_CALL_OVERRIDES:
4457            (class_name, _) = DIRECT_CALL_OVERRIDES[self.protocol.name, self.side]
4458        else:
4459            assert self.protocol.name.startswith("P")
4460            class_name = "{}{}".format(self.protocol.name[1:], self.side.capitalize())
4461
4462        return ExprCode("static_cast<${class_name}*>(this)", class_name=class_name)
4463
4464    def thisCall(self, function, args):
4465        return ExprCall(ExprSelect(self.concreteThis(), "->", function), args=args)
4466
    def visitMessageDecl(self, md):
        """Emit the Send method(s) and receive switch case(s) for message ``md``.

        Send side (``self.sendsMessage(md)``): picks a generator based on
        whether the message is a ctor/dtor and async or blocking, then appends
        the resulting Send / move-Send / promise-Send methods to ``self.cls``.
        Async generators may also return a reply case, which is added to the
        dispatch switch.

        Receive side (``self.receivesMessage(md)``): generates the case for the
        message id and adds it to the async, sync or interrupt switch that
        matches the message's send semantics.

        Destructors whose ``constructedType()`` is toplevel get neither a Send
        method nor a receive case (see the XXX notes below).
        """
        isctor = md.decl.type.isCtor()
        isdtor = md.decl.type.isDtor()
        decltype = md.decl.type
        sendmethod = None
        movesendmethod = None
        promisesendmethod = None
        recvlbl, recvcase = None, None

        def addRecvCase(lbl, case):
            # Route the case to the switch that matches the message's
            # send semantics; any other kind is a generator bug.
            if decltype.isAsync():
                self.asyncSwitch.addcase(lbl, case)
            elif decltype.isSync():
                self.syncSwitch.addcase(lbl, case)
            elif decltype.isInterrupt():
                self.interruptSwitch.addcase(lbl, case)
            else:
                assert 0

        if self.sendsMessage(md):
            isasync = decltype.isAsync()

            # NOTE: Don't generate helper ctors for refcounted types.
            #
            # Safety concerns around providing your own actor to a ctor (namely
            # that the return value won't be checked, and the argument will be
            # `delete`-ed) are less critical with refcounted actors, due to the
            # actor being held alive by the callsite.
            #
            # This allows refcounted actors to not implement crashing AllocPFoo
            # methods on the sending side.
            if isctor and not md.decl.type.constructedType().isRefcounted():
                self.cls.addstmts([self.genHelperCtor(md), Whitespace.NL])

            # Async generators also return the (label, case) pair that handles
            # the reply on this side; blocking ones do not.
            if isctor and isasync:
                sendmethod, (recvlbl, recvcase) = self.genAsyncCtor(md)
            elif isctor:
                sendmethod = self.genBlockingCtorMethod(md)
            elif isdtor and isasync:
                sendmethod, (recvlbl, recvcase) = self.genAsyncDtor(md)
            elif isdtor:
                sendmethod = self.genBlockingDtorMethod(md)
            elif isasync:
                (
                    sendmethod,
                    movesendmethod,
                    promisesendmethod,
                    (recvlbl, recvcase),
                ) = self.genAsyncSendMethod(md)
            else:
                sendmethod, movesendmethod = self.genBlockingSendMethod(md)

        # XXX figure out what to do here
        # The Send method of a dtor whose constructedType() is toplevel is dropped.
        if isdtor and md.decl.type.constructedType().isToplevel():
            sendmethod = None

        if sendmethod is not None:
            self.cls.addstmts([sendmethod, Whitespace.NL])
        if movesendmethod is not None:
            self.cls.addstmts([movesendmethod, Whitespace.NL])
        if promisesendmethod is not None:
            self.cls.addstmts([promisesendmethod, Whitespace.NL])
        if recvcase is not None:
            # Register the reply case produced by the send side, then reset
            # so the receive branch below starts from a clean slate.
            addRecvCase(recvlbl, recvcase)
            recvlbl, recvcase = None, None

        if self.receivesMessage(md):
            if isctor:
                recvlbl, recvcase = self.genCtorRecvCase(md)
            elif isdtor:
                recvlbl, recvcase = self.genDtorRecvCase(md)
            else:
                recvlbl, recvcase = self.genRecvCase(md)

            # XXX figure out what to do here
            # The case is generated but not registered for a dtor whose
            # constructedType() is toplevel.
            if isdtor and md.decl.type.constructedType().isToplevel():
                return

            addRecvCase(recvlbl, recvcase)
4546
4547    def genAsyncCtor(self, md):
4548        actor = md.actorDecl()
4549        method = MethodDefn(self.makeSendMethodDecl(md))
4550
4551        msgvar, stmts = self.makeMessage(md, errfnSendCtor)
4552        sendok, sendstmts = self.sendAsync(md, msgvar)
4553
4554        method.addcode(
4555            """
4556            $*{bind}
4557
4558            // Build our constructor message.
4559            $*{stmts}
4560
4561            // Notify the other side about the newly created actor. This can
4562            // fail if our manager has already been destroyed.
4563            //
4564            // NOTE: If the send call fails due to toplevel channel teardown,
4565            // the `IProtocol::ChannelSend` wrapper absorbs the error for us,
4566            // so we don't tear down actors unexpectedly.
4567            $*{sendstmts}
4568
4569            // Warn, destroy the actor, and return null if the message failed to
4570            // send. Otherwise, return the successfully created actor reference.
4571            if (!${sendok}) {
4572                NS_WARNING("Error sending ${actorname} constructor");
4573                $*{destroy}
4574                return nullptr;
4575            }
4576            return ${actor};
4577            """,
4578            bind=self.bindManagedActor(actor),
4579            stmts=stmts,
4580            sendstmts=sendstmts,
4581            sendok=sendok,
4582            destroy=self.destroyActor(
4583                md, actor.var(), why=_DestroyReason.FailedConstructor
4584            ),
4585            actor=actor.var(),
4586            actorname=actor.ipdltype.protocol.name() + self.side.capitalize(),
4587        )
4588
4589        lbl = CaseLabel(md.pqReplyId())
4590        case = StmtBlock()
4591        case.addstmt(StmtReturn(_Result.Processed))
4592        # TODO not really sure what to do with async ctor "replies" yet.
4593        # destroy actor if there was an error?  tricky ...
4594
4595        return method, (lbl, case)
4596
4597    def genBlockingCtorMethod(self, md):
4598        actor = md.actorDecl()
4599        method = MethodDefn(self.makeSendMethodDecl(md))
4600
4601        msgvar, stmts = self.makeMessage(md, errfnSendCtor)
4602
4603        replyvar = self.replyvar
4604        sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar)
4605        replystmts = self.deserializeReply(
4606            md,
4607            ExprAddrOf(replyvar),
4608            self.side,
4609            errfnSendCtor,
4610            errfnSentinel(ExprLiteral.NULL),
4611        )
4612
4613        method.addcode(
4614            """
4615            $*{bind}
4616
4617            // Build our constructor message.
4618            $*{stmts}
4619
4620            // Synchronously send the constructor message to the other side. If
4621            // the send fails, e.g. due to the remote side shutting down, the
4622            // actor will be destroyed and potentially freed.
4623            Message ${replyvar};
4624            $*{sendstmts}
4625
4626            if (!(${sendok})) {
4627                // Warn, destroy the actor and return null if the message
4628                // failed to send.
4629                NS_WARNING("Error sending constructor");
4630                $*{destroy}
4631                return nullptr;
4632            }
4633
4634            $*{replystmts}
4635            return ${actor};
4636            """,
4637            bind=self.bindManagedActor(actor),
4638            stmts=stmts,
4639            replyvar=replyvar,
4640            sendstmts=sendstmts,
4641            sendok=sendok,
4642            destroy=self.destroyActor(
4643                md, actor.var(), why=_DestroyReason.FailedConstructor
4644            ),
4645            replystmts=replystmts,
4646            actor=actor.var(),
4647            actorname=actor.ipdltype.protocol.name() + self.side.capitalize(),
4648        )
4649
4650        return method
4651
4652    def bindManagedActor(self, actordecl, errfn=ExprLiteral.NULL, idexpr=None):
4653        actorproto = actordecl.ipdltype.protocol
4654
4655        if idexpr is None:
4656            setManagerArgs = [ExprVar.THIS]
4657        else:
4658            setManagerArgs = [ExprVar.THIS, idexpr]
4659
4660        return [
4661            StmtCode(
4662                """
4663            if (!${actor}) {
4664                NS_WARNING("Cannot bind null ${actorname} actor");
4665                return ${errfn};
4666            }
4667
4668            ${actor}->SetManagerAndRegister($,{setManagerArgs});
4669            ${container}.Insert(${actor});
4670            """,
4671                actor=actordecl.var(),
4672                actorname=actorproto.name() + self.side.capitalize(),
4673                errfn=errfn,
4674                setManagerArgs=setManagerArgs,
4675                container=self.protocol.managedVar(actorproto, self.side),
4676            )
4677        ]
4678
4679    def genHelperCtor(self, md):
4680        helperdecl = self.makeSendMethodDecl(md)
4681        helperdecl.params = helperdecl.params[1:]
4682        helper = MethodDefn(helperdecl)
4683
4684        helper.addstmts(
4685            [
4686                self.callAllocActor(md, retsems="out", side=self.side),
4687                StmtReturn(ExprCall(ExprVar(helperdecl.name), args=md.makeCxxArgs())),
4688            ]
4689        )
4690        return helper
4691
4692    def genAsyncDtor(self, md):
4693        actor = md.actorDecl()
4694        actorvar = actor.var()
4695        method = MethodDefn(self.makeDtorMethodDecl(md))
4696
4697        method.addstmt(self.dtorPrologue(actorvar))
4698
4699        msgvar, stmts = self.makeMessage(md, errfnSendDtor, actorvar)
4700        sendok, sendstmts = self.sendAsync(md, msgvar, actorvar)
4701        method.addstmts(
4702            stmts
4703            + sendstmts
4704            + [Whitespace.NL]
4705            + self.dtorEpilogue(md, actor.var())
4706            + [StmtReturn(sendok)]
4707        )
4708
4709        lbl = CaseLabel(md.pqReplyId())
4710        case = StmtBlock()
4711        case.addstmt(StmtReturn(_Result.Processed))
4712        # TODO if the dtor is "inherently racy", keep the actor alive
4713        # until the other side acks
4714
4715        return method, (lbl, case)
4716
    def genBlockingDtorMethod(self, md):
        """Generate the Send method for a blocking (non-async) __delete__.

        The method guards on the actor (``dtorPrologue``), serializes the
        message from the actor being deleted, and sends it with
        ``sendBlocking`` into a local reply ``Message``. It then tears the
        actor down (``dtorEpilogue``) and returns the send result.
        """
        actor = md.actorDecl()
        actorvar = actor.var()
        method = MethodDefn(self.makeDtorMethodDecl(md))

        method.addstmt(self.dtorPrologue(actorvar))

        msgvar, stmts = self.makeMessage(md, errfnSendDtor, actorvar)

        replyvar = self.replyvar
        sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar, actorvar)
        method.addstmts(
            stmts
            + [Whitespace.NL, StmtDecl(Decl(Type("Message"), replyvar.name))]
            + sendstmts
        )

        destmts = self.deserializeReply(
            md, ExprAddrOf(replyvar), self.side, errfnSend, errfnSentinel(), actorvar
        )
        # NOTE(review): the reply-deserialization statements, and the
        # `sendok &= false` appended after them, are wrapped in
        # `if (false) { ... }`. The emitted C++ therefore never executes them;
        # the code is only generated. This looks intentional; confirm before
        # changing.
        ifsendok = StmtIf(ExprLiteral.FALSE)
        ifsendok.addifstmts(destmts)
        ifsendok.addifstmts(
            [Whitespace.NL, StmtExpr(ExprAssn(sendok, ExprLiteral.FALSE, "&="))]
        )

        method.addstmt(ifsendok)

        # Tear down the actor regardless of the send outcome, then report it.
        method.addstmts(
            self.dtorEpilogue(md, actor.var()) + [Whitespace.NL, StmtReturn(sendok)]
        )

        return method
4750
4751    def destroyActor(self, md, actorexpr, why=_DestroyReason.Deletion):
4752        if md.decl.type.isCtor():
4753            destroyedType = md.decl.type.constructedType()
4754        else:
4755            destroyedType = self.protocol.decl.type
4756
4757        return [
4758            StmtCode(
4759                """
4760            IProtocol* mgr = ${actor}->Manager();
4761            ${actor}->DestroySubtree(${why});
4762            ${actor}->ClearSubtree();
4763            mgr->RemoveManagee(${protoId}, ${actor});
4764            """,
4765                actor=actorexpr,
4766                why=why,
4767                protoId=_protocolId(destroyedType),
4768            )
4769        ]
4770
4771    def dtorPrologue(self, actorexpr):
4772        return StmtCode(
4773            """
4774            if (!${actor} || !${actor}->CanSend()) {
4775                NS_WARNING("Attempt to __delete__ missing or closed actor");
4776                return false;
4777            }
4778            """,
4779            actor=actorexpr,
4780        )
4781
4782    def dtorEpilogue(self, md, actorexpr):
4783        return self.destroyActor(md, actorexpr)
4784
    def genRecvAsyncReplyCase(self, md):
        """Generate the switch case that handles the reply to an async message.

        The emitted C++ pops the pending callback for the incoming message from
        the IPC channel (a FatalError occurs if none is found). It then reads
        the ``resolve__`` flag. On resolve, it deserializes the return values
        and calls ``callback->Resolve(...)``, with a ``MakeTuple`` of them when
        there are several. On reject, it reads ``reason__`` and calls
        ``callback->Reject(...)``.

        Requires ``md.returns`` to be non-empty (``md.returns[0]`` is indexed
        in the single-return branch).

        Returns ``(label, case)``.
        """
        lbl = CaseLabel(md.pqReplyId())
        case = StmtBlock()
        resolve, reason, prologue, desrej, desstmts = self.deserializeAsyncReply(
            md, self.side, errfnRecv, errfnSentinel(_Result.ValuError)
        )

        # The callback's resolve type is a tuple of all return types when
        # there are several, or the single return type otherwise.
        if len(md.returns) > 1:
            resolvetype = _tuple([d.bareType(self.side) for d in md.returns])
            resolvearg = ExprCall(
                ExprVar("MakeTuple"), args=[ExprMove(p.var()) for p in md.returns]
            )
        else:
            resolvetype = md.returns[0].bareType(self.side)
            resolvearg = ExprMove(md.returns[0].var())

        case.addcode(
            """
            $*{prologue}

            UniquePtr<MessageChannel::UntypedCallbackHolder> untypedCallback =
                GetIPCChannel()->PopCallback(${msgvar});

            typedef MessageChannel::CallbackHolder<${resolvetype}> CallbackHolder;
            auto* callback = static_cast<CallbackHolder*>(untypedCallback.get());
            if (!callback) {
                FatalError("Error unknown callback");
                return MsgProcessingError;
            }

            if (${resolve}) {
                $*{desstmts}
                callback->Resolve(${resolvearg});
            } else {
                $*{desrej}
                callback->Reject(std::move(${reason}));
            }
            return MsgProcessed;
            """,
            prologue=prologue,
            msgvar=self.msgvar,
            resolve=resolve,
            resolvetype=resolvetype,
            desstmts=desstmts,
            resolvearg=resolvearg,
            desrej=desrej,
            reason=reason,
        )

        return (lbl, case)
4835
4836    def genAsyncSendMethod(self, md):
4837        method = MethodDefn(self.makeSendMethodDecl(md))
4838        msgvar, stmts = self.makeMessage(md, errfnSend)
4839        retvar, sendstmts = self.sendAsync(md, msgvar)
4840
4841        method.addstmts(stmts + [Whitespace.NL] + sendstmts + [StmtReturn(retvar)])
4842
4843        movemethod = None
4844
4845        # Add the promise overload if we need one.
4846        if md.returns:
4847            promisemethod = MethodDefn(self.makeSendMethodDecl(md, promise=True))
4848            stmts = self.sendAsyncWithPromise(md)
4849            promisemethod.addstmts(stmts)
4850
4851            (lbl, case) = self.genRecvAsyncReplyCase(md)
4852        else:
4853            (promisemethod, lbl, case) = (None, None, None)
4854
4855        return method, movemethod, promisemethod, (lbl, case)
4856
4857    def genBlockingSendMethod(self, md, fromActor=None):
4858        method = MethodDefn(self.makeSendMethodDecl(md))
4859
4860        msgvar, serstmts = self.makeMessage(md, errfnSend, fromActor)
4861        replyvar = self.replyvar
4862
4863        sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar)
4864        failif = StmtIf(ExprNot(sendok))
4865        failif.addifstmt(StmtReturn.FALSE)
4866
4867        desstmts = self.deserializeReply(
4868            md, ExprAddrOf(replyvar), self.side, errfnSend, errfnSentinel()
4869        )
4870
4871        method.addstmts(
4872            serstmts
4873            + [Whitespace.NL, StmtDecl(Decl(Type("Message"), replyvar.name))]
4874            + sendstmts
4875            + [failif]
4876            + desstmts
4877            + [Whitespace.NL, StmtReturn.TRUE]
4878        )
4879
4880        movemethod = None
4881
4882        return method, movemethod
4883
4884    def genCtorRecvCase(self, md):
4885        lbl = CaseLabel(md.pqMsgId())
4886        case = StmtBlock()
4887        actorhandle = self.handlevar
4888
4889        stmts = self.deserializeMessage(
4890            md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError)
4891        )
4892
4893        idvar, saveIdStmts = self.saveActorId(md)
4894        case.addstmts(
4895            stmts
4896            + [
4897                StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[])
4898                for r in md.returns
4899            ]
4900            # alloc the actor, register it under the foreign ID
4901            + [self.callAllocActor(md, retsems="in", side=self.side)]
4902            + self.bindManagedActor(
4903                md.actorDecl(), errfn=_Result.ValuError, idexpr=_actorHId(actorhandle)
4904            )
4905            + [Whitespace.NL]
4906            + saveIdStmts
4907            + self.invokeRecvHandler(md)
4908            + self.makeReply(md, errfnRecv, idvar)
4909            + [Whitespace.NL, StmtReturn(_Result.Processed)]
4910        )
4911
4912        return lbl, case
4913
4914    def genDtorRecvCase(self, md):
4915        lbl = CaseLabel(md.pqMsgId())
4916        case = StmtBlock()
4917
4918        stmts = self.deserializeMessage(
4919            md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError)
4920        )
4921
4922        idvar, saveIdStmts = self.saveActorId(md)
4923        case.addstmts(
4924            stmts
4925            + [
4926                StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[])
4927                for r in md.returns
4928            ]
4929            + self.invokeRecvHandler(md, implicit=False)
4930            + [Whitespace.NL]
4931            + saveIdStmts
4932            + self.makeReply(md, errfnRecv, routingId=idvar)
4933            + [Whitespace.NL]
4934            + self.dtorEpilogue(md, md.actorDecl().var())
4935            + [Whitespace.NL, StmtReturn(_Result.Processed)]
4936        )
4937
4938        return lbl, case
4939
4940    def genRecvCase(self, md):
4941        lbl = CaseLabel(md.pqMsgId())
4942        case = StmtBlock()
4943
4944        stmts = self.deserializeMessage(
4945            md, self.side, errfn=errfnRecv, errfnSent=errfnSentinel(_Result.ValuError)
4946        )
4947
4948        idvar, saveIdStmts = self.saveActorId(md)
4949        declstmts = [
4950            StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[])
4951            for r in md.returns
4952        ]
4953        if md.decl.type.isAsync() and md.returns:
4954            declstmts = self.makeResolver(md, errfnRecv, routingId=idvar)
4955        case.addstmts(
4956            stmts
4957            + saveIdStmts
4958            + declstmts
4959            + self.invokeRecvHandler(md)
4960            + [Whitespace.NL]
4961            + self.makeReply(md, errfnRecv, routingId=idvar)
4962            + [StmtReturn(_Result.Processed)]
4963        )
4964
4965        return lbl, case
4966
4967    # helper methods
4968
4969    def makeMessage(self, md, errfn, fromActor=None):
4970        msgvar = self.msgvar
4971        routingId = self.protocol.routingId(fromActor)
4972        this = ExprVar.THIS
4973        if md.decl.type.isDtor():
4974            this = md.actorDecl().var()
4975
4976        stmts = (
4977            [
4978                StmtDecl(
4979                    Decl(Type("IPC::Message", ptr=True), msgvar.name),
4980                    init=ExprCall(ExprVar(md.pqMsgCtorFunc()), args=[routingId]),
4981                )
4982            ]
4983            + [Whitespace.NL]
4984            + [
4985                _ParamTraits.checkedWrite(
4986                    p.ipdltype, p.var(), msgvar, sentinelKey=p.name, actor=this
4987                )
4988                for p in md.params
4989            ]
4990            + [Whitespace.NL]
4991            + self.setMessageFlags(md, msgvar)
4992        )
4993        return msgvar, stmts
4994
    def makeResolver(self, md, errfn, routingId):
        """Return statements that build the reply resolver for an async message.

        Only async messages that have a reply get statements; anything else
        returns ``[]``. The emitted C++ creates the reply message (copying the
        incoming message's seqno) and hands it to an ``IPDLResolverInner``.
        It then defines ``resolver``, a lambda of type ``md.resolverName()``.
        When called with the resolved value, the lambda serializes the return
        values into the reply and logs it.

        ``routingId`` defaults to this protocol's routing id when None.
        ``errfn`` is not used here.
        """
        if routingId is None:
            routingId = self.protocol.routingId()
        if not md.decl.type.isAsync() or not md.hasReply():
            return []

        def paramValue(idx):
            # With several returns the resolver receives a tuple, so element
            # ``idx`` is extracted with mozilla::Get; a single return is
            # passed directly as ``aParam``.
            assert idx < len(md.returns)
            if len(md.returns) > 1:
                return ExprCode("mozilla::Get<${idx}>(aParam)", idx=idx)
            return ExprVar("aParam")

        # Writes run inside the Resolve callback, where the actor is the
        # lambda's ``self__`` parameter rather than ``this``.
        serializeParams = [
            _ParamTraits.checkedWrite(
                p.ipdltype,
                paramValue(idx),
                self.replyvar,
                sentinelKey=p.name,
                actor=ExprVar("self__"),
            )
            for idx, p in enumerate(md.returns)
        ]

        return [
            StmtCode(
                """
                UniquePtr<IPC::Message> ${replyvar}(${replyCtor}(${routingId}));
                ${replyvar}->set_seqno(${msgvar}.seqno());

                RefPtr<mozilla::ipc::IPDLResolverInner> resolver__ =
                    new mozilla::ipc::IPDLResolverInner(std::move(${replyvar}), this);

                ${resolvertype} resolver = [resolver__ = std::move(resolver__)](${resolveType} aParam) {
                    resolver__->Resolve([&] (IPC::Message* ${replyvar}, IProtocol* self__) {
                        $*{serializeParams}
                        ${logSendingReply}
                    });
                };
                """,
                msgvar=self.msgvar,
                resolvertype=Type(md.resolverName()),
                routingId=routingId,
                resolveType=_resolveType(md.returns, self.side),
                replyvar=self.replyvar,
                replyCtor=ExprVar(md.pqReplyCtorFunc()),
                serializeParams=serializeParams,
                logSendingReply=self.logMessage(
                    md,
                    self.replyvar,
                    "Sending reply ",
                    actor=ExprVar("self__"),
                ),
            )
        ]
5049
5050    def makeReply(self, md, errfn, routingId):
5051        if routingId is None:
5052            routingId = self.protocol.routingId()
5053        # TODO special cases for async ctor/dtor replies
5054        if not md.decl.type.hasReply():
5055            return []
5056        if md.decl.type.isAsync() and md.decl.type.hasReply():
5057            return []
5058
5059        replyvar = self.replyvar
5060        return (
5061            [
5062                StmtExpr(
5063                    ExprAssn(
5064                        replyvar,
5065                        ExprCall(ExprVar(md.pqReplyCtorFunc()), args=[routingId]),
5066                    )
5067                ),
5068                Whitespace.NL,
5069            ]
5070            + [
5071                _ParamTraits.checkedWrite(
5072                    r.ipdltype,
5073                    r.var(),
5074                    replyvar,
5075                    sentinelKey=r.name,
5076                    actor=ExprVar.THIS,
5077                )
5078                for r in md.returns
5079            ]
5080            + self.setMessageFlags(md, replyvar)
5081            + [self.logMessage(md, replyvar, "Sending reply ")]
5082        )
5083
5084    def setMessageFlags(self, md, var, seqno=None):
5085        stmts = []
5086
5087        if seqno:
5088            stmts.append(
5089                StmtExpr(ExprCall(ExprSelect(var, "->", "set_seqno"), args=[seqno]))
5090            )
5091
5092        return stmts + [Whitespace.NL]
5093
    def deserializeMessage(self, md, side, errfn, errfnSent):
        """Return statements that log and deserialize the parameters of ``md``.

        The statements always log the received message and add a profiler
        label. When the message has parameters, they also declare an iterator
        over the message and a local per parameter, read each parameter
        (``errfn``/``errfnSent`` are used on failure), and finish with
        ``endRead``.

        Parameter locals are wrapped in ``Tainted<...>`` when the message is
        tainted and the parameter lacks the ``NoTaint`` attribute. For
        constructor messages, the first parameter is read as a raw
        ``ActorHandle`` into ``self.handlevar`` instead of as an actor.
        """
        msgvar = self.msgvar
        itervar = self.itervar
        msgexpr = ExprAddrOf(msgvar)
        isctor = md.decl.type.isCtor()
        stmts = [
            self.logMessage(md, msgexpr, "Received ", receiving=True),
            self.profilerLabel(md),
            Whitespace.NL,
        ]

        if 0 == len(md.params):
            return stmts

        # ``start`` is the index of the first parameter read as a normal value.
        start, decls, reads = 0, [], []
        if isctor:
            # return the raw actor handle so that its ID can be used
            # to construct the "real" actor
            handlevar = self.handlevar
            handletype = Type("ActorHandle")
            decls = [StmtDecl(Decl(handletype, handlevar.name), initargs=[])]
            reads = [
                _ParamTraits.checkedRead(
                    None,
                    ExprAddrOf(handlevar),
                    msgexpr,
                    ExprAddrOf(self.itervar),
                    errfn,
                    "'%s'" % handletype.name,
                    sentinelKey="actor",
                    errfnSentinel=errfnSent,
                    actor=ExprVar.THIS,
                )
            ]
            start = 1

        decls.extend(
            [
                StmtDecl(
                    Decl(
                        (
                            Type("Tainted", T=p.bareType(side))
                            if md.decl.type.tainted and "NoTaint" not in p.attributes
                            else p.bareType(side)
                        ),
                        p.var().name,
                    ),
                    initargs=[],
                )
                for p in md.params[start:]
            ]
        )
        reads.extend(
            [
                _ParamTraits.checkedRead(
                    p.ipdltype,
                    ExprAddrOf(p.var()),
                    msgexpr,
                    ExprAddrOf(itervar),
                    errfn,
                    "'%s'" % p.ipdltype.name(),
                    sentinelKey=p.name,
                    errfnSentinel=errfnSent,
                    actor=ExprVar.THIS,
                )
                for p in md.params[start:]
            ]
        )

        # Iterator declaration, then all locals, then all reads, then endRead.
        stmts.extend(
            (
                [
                    StmtDecl(
                        Decl(_iterType(ptr=False), self.itervar.name), initargs=[msgvar]
                    )
                ]
                + decls
                + [Whitespace.NL]
                + reads
                + [self.endRead(msgvar, itervar)]
            )
        )

        return stmts
5178
    def deserializeAsyncReply(self, md, side, errfn, errfnSent):
        """Build the pieces needed to deserialize an async reply message.

        Returns ``(resolve, reason, prologue, desrej, stmts)``:

        * ``resolve`` / ``reason`` are the ``resolve__`` / ``reason__``
          variables.
        * ``prologue`` logs the message, adds a profiler label, declares the
          iterator and reads the ``resolve__`` flag.
        * ``desrej`` reads ``reason__`` and ends the read (reject path).
        * ``stmts`` declares a local per return value, reads them, and ends
          the read (resolve path). For constructor messages, an
          ``ActorHandle`` is declared and read first, and the normal reads
          skip the first return value.

        NOTE(review): when ``md.returns`` is empty, only ``prologue`` (a list)
        is returned, not the 5-tuple above. The visible caller
        (``genRecvAsyncReplyCase``) only handles messages with returns.
        Confirm before calling this from elsewhere.
        """
        msgvar = self.msgvar
        itervar = self.itervar
        msgexpr = ExprAddrOf(msgvar)
        isctor = md.decl.type.isCtor()
        resolve = ExprVar("resolve__")
        reason = ExprVar("reason__")

        # NOTE: The `resolve__` and `reason__` parameters don't have sentinels,
        # as they are serialized by the IPDLResolverInner type in
        # ProtocolUtils.cpp rather than by generated code.
        desresolve = [
            StmtCode(
                """
                bool resolve__ = false;
                if (!ReadIPDLParam(${msgexpr}, &${itervar}, this, &resolve__)) {
                    FatalError("Error deserializing bool");
                    return MsgValueError;
                }
                """,
                msgexpr=msgexpr,
                itervar=itervar,
            ),
        ]
        # Reject path: read the rejection reason, then finish reading.
        desrej = [
            StmtCode(
                """
                ResponseRejectReason reason__{};
                if (!ReadIPDLParam(${msgexpr}, &${itervar}, this, &reason__)) {
                    FatalError("Error deserializing ResponseRejectReason");
                    return MsgValueError;
                }
                """,
                msgexpr=msgexpr,
                itervar=itervar,
            ),
            self.endRead(msgvar, itervar),
        ]
        prologue = [
            self.logMessage(md, msgexpr, "Received ", receiving=True),
            self.profilerLabel(md),
            Whitespace.NL,
        ]

        if not md.returns:
            return prologue

        prologue.extend(
            [StmtDecl(Decl(_iterType(ptr=False), itervar.name), initargs=[msgvar])]
            + desresolve
        )

        # ``start`` is the index of the first return value read normally.
        start, decls, reads = 0, [], []
        if isctor:
            # return the raw actor handle so that its ID can be used
            # to construct the "real" actor
            handlevar = self.handlevar
            handletype = Type("ActorHandle")
            decls = [StmtDecl(Decl(handletype, handlevar.name), initargs=[])]
            reads = [
                _ParamTraits.checkedRead(
                    None,
                    ExprAddrOf(handlevar),
                    msgexpr,
                    ExprAddrOf(itervar),
                    errfn,
                    "'%s'" % handletype.name,
                    sentinelKey="actor",
                    errfnSentinel=errfnSent,
                    actor=ExprVar.THIS,
                )
            ]
            start = 1

        # Resolve path: declarations cover every return value, while the
        # typed reads start at ``start``.
        stmts = (
            decls
            + [
                StmtDecl(Decl(p.bareType(side), p.var().name), initargs=[])
                for p in md.returns
            ]
            + [Whitespace.NL]
            + reads
            + [
                _ParamTraits.checkedRead(
                    p.ipdltype,
                    ExprAddrOf(p.var()),
                    msgexpr,
                    ExprAddrOf(itervar),
                    errfn,
                    "'%s'" % p.ipdltype.name(),
                    sentinelKey=p.name,
                    errfnSentinel=errfnSent,
                    actor=ExprVar.THIS,
                )
                for p in md.returns[start:]
            ]
            + [self.endRead(msgvar, itervar)]
        )

        return resolve, reason, prologue, desrej, stmts
5279
5280    def deserializeReply(
5281        self, md, replyexpr, side, errfn, errfnSentinel, actor=None, decls=False
5282    ):
5283        stmts = [
5284            Whitespace.NL,
5285            self.logMessage(md, replyexpr, "Received reply ", actor, receiving=True),
5286        ]
5287        if 0 == len(md.returns):
5288            return stmts
5289
5290        itervar = self.itervar
5291        declstmts = []
5292        if decls:
5293            declstmts = [
5294                StmtDecl(Decl(p.bareType(side), p.var().name), initargs=[])
5295                for p in md.returns
5296            ]
5297        stmts.extend(
5298            [
5299                Whitespace.NL,
5300                StmtDecl(
5301                    Decl(_iterType(ptr=False), itervar.name), initargs=[self.replyvar]
5302                ),
5303            ]
5304            + declstmts
5305            + [Whitespace.NL]
5306            + [
5307                _ParamTraits.checkedRead(
5308                    r.ipdltype,
5309                    r.var(),
5310                    ExprAddrOf(self.replyvar),
5311                    ExprAddrOf(self.itervar),
5312                    errfn,
5313                    "'%s'" % r.ipdltype.name(),
5314                    sentinelKey=r.name,
5315                    errfnSentinel=errfnSentinel,
5316                    actor=ExprVar.THIS,
5317                )
5318                for r in md.returns
5319            ]
5320            + [self.endRead(self.replyvar, itervar)]
5321        )
5322
5323        return stmts
5324
5325    def sendAsync(self, md, msgexpr, actor=None):
5326        sendok = ExprVar("sendok__")
5327        resolvefn = ExprVar("aResolve")
5328        rejectfn = ExprVar("aReject")
5329
5330        stmts = [
5331            Whitespace.NL,
5332            self.logMessage(md, msgexpr, "Sending ", actor),
5333            self.profilerLabel(md),
5334        ]
5335        stmts.append(Whitespace.NL)
5336
5337        # Generate the actual call expression.
5338        send = ExprVar("ChannelSend")
5339        if actor is not None:
5340            send = ExprSelect(actor, "->", send.name)
5341        if md.returns:
5342            stmts.append(
5343                StmtExpr(
5344                    ExprCall(
5345                        send, args=[msgexpr, ExprMove(resolvefn), ExprMove(rejectfn)]
5346                    )
5347                )
5348            )
5349            retvar = None
5350        else:
5351            stmts.append(
5352                StmtDecl(
5353                    Decl(Type.BOOL, sendok.name), init=ExprCall(send, args=[msgexpr])
5354                )
5355            )
5356            retvar = sendok
5357
5358        return (retvar, stmts)
5359
5360    def sendBlocking(self, md, msgexpr, replyexpr, actor=None):
5361        send = ExprVar("ChannelSend")
5362        if md.decl.type.isInterrupt():
5363            send = ExprVar("ChannelCall")
5364        if actor is not None:
5365            send = ExprSelect(actor, "->", send.name)
5366
5367        sendok = ExprVar("sendok__")
5368        self.externalIncludes.add("mozilla/ProfilerMarkers.h")
5369        return (
5370            sendok,
5371            (
5372                [
5373                    Whitespace.NL,
5374                    self.logMessage(md, msgexpr, "Sending ", actor),
5375                    self.profilerLabel(md),
5376                ]
5377                + [
5378                    Whitespace.NL,
5379                    StmtDecl(Decl(Type.BOOL, sendok.name), init=ExprLiteral.FALSE),
5380                    StmtBlock(
5381                        [
5382                            StmtExpr(
5383                                ExprCall(
5384                                    ExprVar("AUTO_PROFILER_TRACING_MARKER"),
5385                                    [
5386                                        ExprLiteral.String("Sync IPC"),
5387                                        ExprLiteral.String(
5388                                            self.protocol.name
5389                                            + "::"
5390                                            + md.prettyMsgName()
5391                                        ),
5392                                        ExprVar("IPC"),
5393                                    ],
5394                                )
5395                            ),
5396                            StmtExpr(
5397                                ExprAssn(
5398                                    sendok,
5399                                    ExprCall(
5400                                        send, args=[msgexpr, ExprAddrOf(replyexpr)]
5401                                    ),
5402                                )
5403                            ),
5404                        ]
5405                    ),
5406                ]
5407            ),
5408        )
5409
5410    def sendAsyncWithPromise(self, md):
5411        # Create a new promise, and forward to the callback send overload.
5412        promise = _makePromise(md.returns, self.side, resolver=True)
5413
5414        if len(md.returns) > 1:
5415            resolvetype = _tuple([d.bareType(self.side) for d in md.returns])
5416        else:
5417            resolvetype = md.returns[0].bareType(self.side)
5418
5419        resolve = ExprCode(
5420            """
5421            [promise__](${resolvetype}&& aValue) {
5422                promise__->Resolve(std::move(aValue), __func__);
5423            }
5424            """,
5425            resolvetype=resolvetype,
5426        )
5427        reject = ExprCode(
5428            """
5429            [promise__](ResponseRejectReason&& aReason) {
5430                promise__->Reject(std::move(aReason), __func__);
5431            }
5432            """,
5433            resolvetype=resolvetype,
5434        )
5435
5436        args = [ExprMove(p.var()) for p in md.params] + [resolve, reject]
5437        stmt = StmtCode(
5438            """
5439            RefPtr<${promise}> promise__ = new ${promise}(__func__);
5440            promise__->UseDirectTaskDispatch(__func__);
5441            ${send}($,{args});
5442            return promise__;
5443            """,
5444            promise=promise,
5445            send=md.sendMethod(),
5446            args=args,
5447        )
5448        return [stmt]
5449
5450    def callAllocActor(self, md, retsems, side):
5451        actortype = md.actorDecl().bareType(self.side)
5452        if md.decl.type.constructedType().isRefcounted():
5453            actortype.ptr = False
5454            actortype = _refptr(actortype)
5455
5456        callalloc = self.thisCall(
5457            _allocMethod(md.decl.type.constructedType(), side),
5458            args=md.makeCxxArgs(retsems=retsems, retcallsems="out", implicit=False),
5459        )
5460
5461        return StmtDecl(Decl(actortype, md.actorDecl().var().name), init=callalloc)
5462
5463    def invokeRecvHandler(self, md, implicit=True):
5464        retsems = "in"
5465        if md.decl.type.isAsync() and md.returns:
5466            retsems = "resolver"
5467        failif = StmtIf(
5468            ExprNot(
5469                self.thisCall(
5470                    md.recvMethod(),
5471                    md.makeCxxArgs(
5472                        paramsems="move",
5473                        retsems=retsems,
5474                        retcallsems="out",
5475                        implicit=implicit,
5476                    ),
5477                )
5478            )
5479        )
5480        failif.addifstmts(
5481            [
5482                _protocolErrorBreakpoint("Handler returned error code!"),
5483                Whitespace(
5484                    "// Error handled in mozilla::ipc::IPCResult\n", indent=True
5485                ),
5486                StmtReturn(_Result.ProcessingError),
5487            ]
5488        )
5489        return [failif]
5490
5491    def makeDtorMethodDecl(self, md):
5492        decl = self.makeSendMethodDecl(md)
5493        decl.methodspec = MethodSpec.STATIC
5494        return decl
5495
5496    def makeSendMethodDecl(self, md, promise=False, paramsems="in"):
5497        implicit = md.decl.type.hasImplicitActorParam()
5498        if md.decl.type.isAsync() and md.returns:
5499            if promise:
5500                returnsems = "promise"
5501                rettype = _refptr(Type(md.promiseName()))
5502            else:
5503                returnsems = "callback"
5504                rettype = Type.VOID
5505        else:
5506            assert not promise
5507            returnsems = "out"
5508            rettype = Type.BOOL
5509        decl = MethodDecl(
5510            md.sendMethod(),
5511            params=md.makeCxxParams(
5512                paramsems,
5513                returnsems=returnsems,
5514                side=self.side,
5515                implicit=implicit,
5516                direction="send",
5517            ),
5518            warn_unused=(
5519                (self.side == "parent" and returnsems != "callback")
5520                or (md.decl.type.isCtor() and not md.decl.type.isAsync())
5521            ),
5522            ret=rettype,
5523        )
5524        if md.decl.type.isCtor():
5525            decl.ret = md.actorDecl().bareType(self.side)
5526        return decl
5527
5528    def logMessage(self, md, msgptr, pfx, actor=None, receiving=False):
5529        actorname = _actorName(self.protocol.name, self.side)
5530        return StmtCode(
5531            """
5532            if (mozilla::ipc::LoggingEnabledFor(${actorname})) {
5533                mozilla::ipc::LogMessageForProtocol(
5534                    ${actorname},
5535                    ${otherpid},
5536                    ${pfx},
5537                    ${msgptr}->type(),
5538                    mozilla::ipc::MessageDirection::${direction});
5539            }
5540            """,
5541            actorname=ExprLiteral.String(actorname),
5542            otherpid=self.protocol.callOtherPid(actor),
5543            pfx=ExprLiteral.String(pfx),
5544            msgptr=msgptr,
5545            direction="eReceiving" if receiving else "eSending",
5546        )
5547
5548    def profilerLabel(self, md):
5549        self.externalIncludes.add("mozilla/ProfilerLabels.h")
5550        return StmtCode(
5551            """
5552            AUTO_PROFILER_LABEL("${name}::${msgname}", OTHER);
5553            """,
5554            name=self.protocol.name,
5555            msgname=md.prettyMsgName(),
5556        )
5557
5558    def saveActorId(self, md):
5559        idvar = ExprVar("id__")
5560        if md.decl.type.hasReply():
5561            # only save the ID if we're actually going to use it, to
5562            # avoid unused-variable warnings
5563            saveIdStmts = [
5564                StmtDecl(Decl(_actorIdType(), idvar.name), self.protocol.routingId())
5565            ]
5566        else:
5567            saveIdStmts = []
5568        return idvar, saveIdStmts
5569
5570    def endRead(self, msgexpr, iterexpr):
5571        return StmtCode(
5572            """
5573            ${msg}.EndRead(${iter}, ${msg}.type());
5574            """,
5575            msg=msgexpr,
5576            iter=iterexpr,
5577        )
5578
5579
class _GenerateProtocolParentCode(_GenerateProtocolActorCode):
    """Actor code generator for the "parent" side of a protocol.

    The parent sends every message whose direction is not `in`, and receives
    `in` and `inout` messages.
    """

    def __init__(self):
        super().__init__("parent")

    def sendsMessage(self, md):
        msgtype = md.decl.type
        return not msgtype.isIn()

    def receivesMessage(self, md):
        msgtype = md.decl.type
        return msgtype.isInout() or msgtype.isIn()
5589
5590
class _GenerateProtocolChildCode(_GenerateProtocolActorCode):
    """Actor code generator for the "child" side of a protocol.

    The child sends every message whose direction is not `out`, and receives
    `out` and `inout` messages.
    """

    def __init__(self):
        super().__init__("child")

    def sendsMessage(self, md):
        msgtype = md.decl.type
        return not msgtype.isOut()

    def receivesMessage(self, md):
        msgtype = md.decl.type
        return msgtype.isInout() or msgtype.isOut()
5600
5601
5602# -----------------------------------------------------------------------------
5603# Utility passes
5604##
5605
5606
def _splitClassDeclDefn(cls):
    """Destructively split |cls|'s methods into declarations and definitions.

    Each MethodDefn in |cls.stmts| that is not force_inline is replaced, in
    place, by a StmtDecl of its declaration. The method's out-of-line
    definition, when there is one, is appended (followed by a newline) to a
    new Block. Returns |(cls, defns)|.
    """
    defns = Block()
    members = cls.stmts

    for index, member in enumerate(members):
        if not isinstance(member, MethodDefn) or member.decl.force_inline:
            continue
        decl, defn = _splitMethodDeclDefn(member, cls)
        members[index] = StmtDecl(decl)
        if defn:
            defns.addstmts([defn, Whitespace.NL])

    return cls, defns
5621
5622
def _splitMethodDeclDefn(md, cls):
    """Split method definition |md| of class |cls| into (declaration, definition).

    Pure methods yield |(md.decl, None)|, since they have no definition.
    Otherwise a deep copy of the original declaration is returned as the
    declaration, and |md| itself is mutated into an out-of-class definition:
    it is bound to |cls|, loses its method specifier and warn_unused, is
    marked only_for_definition, and its Param defaults are cleared.
    """
    decl = md.decl
    if decl.methodspec == MethodSpec.PURE:
        return decl, None

    # Copy before mutating, so the declaration keeps its specifiers/defaults.
    declonly = deepcopy(decl)

    decl.cls = cls
    decl.methodspec = MethodSpec.NONE
    decl.warn_unused = False
    decl.only_for_definition = True
    for param in decl.params:
        if not isinstance(param, Param):
            continue
        param.default = None

    return declonly, md
5638
5639
def _splitFuncDeclDefn(fun):
    """Split free function |fun| into (declaration statement, definition).

    Functions marked force_inline must not be passed here.
    """
    assert not fun.decl.force_inline
    declstmt = StmtDecl(fun.decl)
    return declstmt, fun
5643