# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import re
from copy import deepcopy
from collections import OrderedDict
import itertools

import ipdl.ast
import ipdl.builtin
from ipdl.cxx.ast import *
from ipdl.cxx.code import *
from ipdl.direct_call import VIRTUAL_CALL_CLASSES, DIRECT_CALL_OVERRIDES
from ipdl.type import ActorType, UnionType, TypeVisitor, builtinHeaderIncludes
from ipdl.util import hash_str


# -----------------------------------------------------------------------------
# "Public" interface to lowering
##


class LowerToCxx:
    """Entry point that lowers an IPDL translation unit to C++ files."""

    def lower(self, tu, segmentcapacitydict):
        """Lower |tu| to C++.

        Returns |(headers, cpps)|, two lists of |File|.  |<tu.name>.h| and
        |<tu.name>.cpp| are always produced.  When |tu| declares a protocol,
        |<Protocol>Parent.{h,cpp}| and |<Protocol>Child.{h,cpp}| are appended
        as well.  |segmentcapacitydict| is forwarded untouched to the
        protocol-code generator."""
        # annotate the AST with IPDL/C++ IR-type stuff used later
        tu.accept(_DecorateWithCxxStuff())

        # Any modifications to the filename scheme here need corresponding
        # modifications in the ipdl.py driver script.
        name = tu.name
        pheader, pcpp = File(name + ".h"), File(name + ".cpp")

        _GenerateProtocolCode().lower(tu, pheader, pcpp, segmentcapacitydict)
        headers = [pheader]
        cpps = [pcpp]

        if tu.protocol:
            pname = tu.protocol.name

            parentheader, parentcpp = (
                File(pname + "Parent.h"),
                File(pname + "Parent.cpp"),
            )
            _GenerateProtocolParentCode().lower(
                tu, pname + "Parent", parentheader, parentcpp
            )

            childheader, childcpp = File(pname + "Child.h"), File(pname + "Child.cpp")
            _GenerateProtocolChildCode().lower(
                tu, pname + "Child", childheader, childcpp
            )

            headers += [parentheader, childheader]
            cpps += [parentcpp, childcpp]

        return headers, cpps


# -----------------------------------------------------------------------------
# Helper code
##


def hashfunc(value):
    """Return |hash_str(value)| reduced into the unsigned 32-bit range.

    Python's ``%`` with a positive right operand always yields a result in
    ``[0, 2**32)``, so no correction for negative hashes is required."""
    return hash_str(value) % 2 ** 32


_NULL_ACTOR_ID = ExprLiteral.ZERO
_FREED_ACTOR_ID = ExprLiteral.ONE

# Banner emitted at the top of every generated file.
_DISCLAIMER = Whitespace(
    """//
// Automatically generated by ipdlc.
// Edit at your own risk
//

"""
)


class _struct:
    pass


def _namespacedHeaderName(name, namespaces):
    """Prefix |name| with the '/'-joined names of |namespaces| (if any)."""
    pfx = "/".join([ns.name for ns in namespaces])
    if pfx:
        return pfx + "/" + name
    else:
        return name


def _ipdlhHeaderName(tu):
    """Namespaced header name for an .ipdlh (header-type) translation unit."""
    assert tu.filetype == "header"
    return _namespacedHeaderName(tu.name, tu.namespaces)


def _protocolHeaderName(p, side=""):
    """Namespaced header name for protocol |p|, optionally suffixed with
    the title-cased |side| (e.g. 'parent' -> 'Parent')."""
    if side:
        side = side.title()
    base = p.name + side
    return _namespacedHeaderName(base, p.namespaces)


def _includeGuardMacroName(headerfile):
    """Include-guard macro: |headerfile.name| with '.' and '/' -> '_'."""
    return re.sub(r"[./]", "_", headerfile.name)


def _includeGuardStart(headerfile):
    guard = _includeGuardMacroName(headerfile)
    return [CppDirective("ifndef", guard), CppDirective("define", guard)]


def _includeGuardEnd(headerfile):
    guard = _includeGuardMacroName(headerfile)
    return [CppDirective("endif", "// ifndef " + guard)]


def _messageStartName(ptype):
    return ptype.name() + "MsgStart"


def _protocolId(ptype):
    return ExprVar(_messageStartName(ptype))


def _protocolIdType():
    return Type.INT32


def _actorName(pname, side):
    """|pname| is the protocol name. |side| is 'Parent' or 'Child'.

    A lower-case |side| (e.g. 'parent') is title-cased before appending."""
    tag = side
    if not tag[0].isupper():
        tag = side.title()
    return pname + tag


def _actorIdType():
    return Type.INT32


def _actorTypeTagType():
    return Type.INT32


def _actorId(actor=None):
    """|actor->Id()| when |actor| is given, otherwise a bare |Id()| call."""
    if actor is not None:
        return ExprCall(ExprSelect(actor, "->", "Id"))
    return ExprCall(ExprVar("Id"))


def _actorHId(actorhandle):
    return ExprSelect(actorhandle, ".", "mId")


def _backstagePass():
    return ExprCall(ExprVar("mozilla::ipc::PrivateIPDLInterface"))


def _iterType(ptr):
    return Type("PickleIterator", ptr=ptr)


def _deleteId():
    return ExprVar("Msg___delete____ID")


def _deleteReplyId():
    return ExprVar("Reply___delete____ID")


def _lookupListener(idexpr):
    return ExprCall(ExprVar("Lookup"), args=[idexpr])


def _nestInNamespaceNames(stmt, names):
    """Wrap |stmt| in nested |Namespace| nodes named by |names|.

    |names| is ordered outermost first.  Returns |stmt| unchanged when
    |names| is empty, otherwise the outermost |Namespace|."""
    if not names:
        return stmt

    outerns = Namespace(names[0])
    innerns = outerns
    for name in names[1:]:
        newns = Namespace(name)
        innerns.addstmt(newns)
        innerns = newns
    innerns.addstmt(stmt)
    return outerns


def _makeForwardDeclForQClass(clsname, quals, cls=True, struct=False):
    """Forward declaration of |clsname| nested inside the namespace names
    listed in |quals| (outermost first)."""
    fd = ForwardDecl(clsname, cls=cls, struct=struct)
    return _nestInNamespaceNames(fd, quals)


def _makeForwardDeclForActor(ptype, side):
    return _makeForwardDeclForQClass(
        _actorName(ptype.qname.baseid, side), ptype.qname.quals
    )


def _makeForwardDecl(type):
    return _makeForwardDeclForQClass(type.name(), type.qname.quals)


def _putInNamespaces(cxxthing, namespaces):
    """|namespaces| is in order [ outer, ..., inner ]; each item must have
    a |.name|.  Returns |cxxthing| itself when |namespaces| is empty."""
    return _nestInNamespaceNames(cxxthing, [ns.name for ns in namespaces])


def _sendPrefix(msgtype):
    """Prefix of the name of the C++ method that sends |msgtype|."""
    if msgtype.isInterrupt():
        return "Call"
    return "Send"


def _recvPrefix(msgtype):
    """Prefix of the name of the C++ method that handles |msgtype|."""
    if msgtype.isInterrupt():
        return "Answer"
    return "Recv"


def _flatTypeName(ipdltype):
    """Return a 'flattened' IPDL type name that can be used as an
    identifier.
    E.g., |Foo[]| --> |ArrayOfFoo|."""
    # NB: this logic depends heavily on what IPDL types are allowed to
    # be constructed; e.g., Foo[][] is disallowed.  needs to be kept in
    # sync with grammar.
    if ipdltype.isIPDL() and ipdltype.isArray():
        return "ArrayOf" + ipdltype.basetype.name()
    if ipdltype.isIPDL() and ipdltype.isMaybe():
        return "Maybe" + ipdltype.basetype.name()
    return ipdltype.name()


def _hasVisibleActor(ipdltype):
    """Return true iff a C++ decl of |ipdltype| would have an Actor* type.
    For example: |Actor[]| would turn into |Array<ActorParent*>|, so this
    function would return true for |Actor[]|."""
    return ipdltype.isIPDL() and (
        ipdltype.isActor()
        or (ipdltype.hasBaseType() and _hasVisibleActor(ipdltype.basetype))
    )


def _abortIfFalse(cond, msg):
    return StmtExpr(
        ExprCall(ExprVar("MOZ_RELEASE_ASSERT"), [cond, ExprLiteral.String(msg)])
    )


def _refptr(T):
    return Type("RefPtr", T=T)


def _uniqueptr(T):
    return Type("UniquePtr", T=T)


def _alreadyaddrefed(T):
    return Type("already_AddRefed", T=T)


def _tuple(types, const=False, ref=False):
    return Type("Tuple", T=types, const=const, ref=ref)


def _promise(resolvetype, rejecttype, tail, resolver=False):
    """|MozPromise<resolvetype, rejecttype, tail>|, or its inner |Private|
    type when |resolver| is true."""
    inner = Type("Private") if resolver else None
    return Type("MozPromise", T=[resolvetype, rejecttype, tail], inner=inner)


def _makePromise(returns, side, resolver=False):
    """MozPromise type for a message's |returns|: resolves with the bare
    return type (a Tuple when there are several) and rejects with
    ResponseRejectReason."""
    if len(returns) > 1:
        resolvetype = _tuple([d.bareType(side) for d in returns])
    else:
        resolvetype = returns[0].bareType(side)

    # MozPromise is purposefully made to be exclusive only.  Really, we mean it.
    return _promise(
        resolvetype, _ResponseRejectReason.Type(), ExprLiteral.TRUE, resolver=resolver
    )


def _resolveType(returns, side):
    """Argument type of a resolver: the move types of |returns|, wrapped in
    a Tuple when there is more than one."""
    if len(returns) > 1:
        return _tuple([d.moveType(side) for d in returns])
    return returns[0].moveType(side)


def _makeResolver(returns, side):
    return TypeFunction([Decl(_resolveType(returns, side), "")])


def _cxxArrayType(basetype, const=False, ref=False):
    return Type("nsTArray", T=basetype, const=const, ref=ref, hasimplicitcopyctor=False)


def _cxxMaybeType(basetype, const=False, ref=False):
    return Type(
        "mozilla::Maybe",
        T=basetype,
        const=const,
        ref=ref,
        hasimplicitcopyctor=basetype.hasimplicitcopyctor,
    )


def _cxxManagedContainerType(basetype, const=False, ref=False):
    return Type(
        "ManagedContainer", T=basetype, const=const, ref=ref, hasimplicitcopyctor=False
    )


def _cxxLifecycleProxyType(ptr=False):
    return Type("mozilla::ipc::ActorLifecycleProxy", ptr=ptr)


def _otherSide(side):
    """'child' <-> 'parent'; any other value is a programming error."""
    if side == "child":
        return "parent"
    if side == "parent":
        return "child"
    assert 0


def _ifLogging(topLevelProtocol, stmts):
    """Guard |stmts| with |mozilla::ipc::LoggingEnabledFor(proto)|."""
    return StmtCode(
        """
        if (mozilla::ipc::LoggingEnabledFor(${proto})) {
            $*{stmts}
        }
        """,
        proto=topLevelProtocol,
        stmts=stmts,
    )


# XXX we need to remove these and install proper error handling


def _printErrorMessage(msg):
    """|NS_ERROR(msg)|; a plain Python str becomes a string literal."""
    if isinstance(msg, str):
        msg = ExprLiteral.String(msg)
    return StmtExpr(ExprCall(ExprVar("NS_ERROR"), args=[msg]))


def _protocolErrorBreakpoint(msg):
    """|ProtocolErrorBreakpoint(msg)|; a plain str becomes a string literal."""
    if isinstance(msg, str):
        msg = ExprLiteral.String(msg)
    return StmtExpr(
        ExprCall(ExprVar("mozilla::ipc::ProtocolErrorBreakpoint"), args=[msg])
    )


def _printWarningMessage(msg):
    """|NS_WARNING(msg)|; a plain Python str becomes a string literal."""
    if isinstance(msg, str):
        msg = ExprLiteral.String(msg)
    return StmtExpr(ExprCall(ExprVar("NS_WARNING"), args=[msg]))


def _fatalError(msg):
    return StmtExpr(ExprCall(ExprVar("FatalError"), args=[ExprLiteral.String(msg)]))


def _logicError(msg):
    return StmtExpr(
        ExprCall(ExprVar("mozilla::ipc::LogicError"), args=[ExprLiteral.String(msg)])
    )


def _sentinelReadError(classname):
    return StmtExpr(
        ExprCall(
            ExprVar("mozilla::ipc::SentinelReadError"),
            args=[ExprLiteral.String(classname)],
        )
    )


# Results that IPDL-generated code returns back to *Channel code.
# Users never see these


class _Result:
    @staticmethod
    def Type():
        return Type("Result")

    Processed = ExprVar("MsgProcessed")
    NotKnown = ExprVar("MsgNotKnown")
    NotAllowed = ExprVar("MsgNotAllowed")
    PayloadError = ExprVar("MsgPayloadError")
    ProcessingError = ExprVar("MsgProcessingError")
    RouteError = ExprVar("MsgRouteError")
    ValuError = ExprVar("MsgValueError")  # [sic]


# these |errfn*| are functions that generate code to be executed on an
# error, such as "bad actor ID".  each is given a Python string
# containing a description of the error

# used in user-facing Send*() methods


def errfnSend(msg, errcode=ExprLiteral.FALSE):
    """FatalError(msg) followed by |return errcode| (false by default)."""
    return [_fatalError(msg), StmtReturn(errcode)]


def errfnSendCtor(msg):
    """Like errfnSend, but the generated code returns null."""
    return errfnSend(msg, errcode=ExprLiteral.NULL)


# TODO should this error handling be strengthened for dtors?
def errfnSendDtor(msg):
    """Error path for Send*() destructor methods.

    Emits NS_ERROR(msg) followed by |return false|.  Note this uses
    _printErrorMessage, not the FatalError used by errfnSend."""
    return [_printErrorMessage(msg), StmtReturn.FALSE]


# used in |OnMessage*()| handlers that hand in-messages off to Recv*()
# interface methods


def errfnRecv(msg, errcode=_Result.ValuError):
    """FatalError(msg) followed by |return errcode|; the default is the
    MsgValueError result."""
    return [_fatalError(msg), StmtReturn(errcode)]


def errfnSentinel(rvalue=ExprLiteral.FALSE):
    """Return an errfn that emits SentinelReadError(msg) and then
    |return rvalue| (false by default)."""

    def inner(msg):
        return [_sentinelReadError(msg), StmtReturn(rvalue)]

    return inner


def _destroyMethod():
    """Expression naming the generated |ActorDestroy| method."""
    return ExprVar("ActorDestroy")


def errfnUnreachable(msg):
    """Emit LogicError(msg) only; no return statement is generated."""
    return [_logicError(msg)]


class _DestroyReason:
    """C++ type and enumerator expressions for |ActorDestroyReason|."""

    @staticmethod
    def Type():
        return Type("ActorDestroyReason")

    Deletion = ExprVar("Deletion")
    AncestorDeletion = ExprVar("AncestorDeletion")
    NormalShutdown = ExprVar("NormalShutdown")
    AbnormalShutdown = ExprVar("AbnormalShutdown")
    FailedConstructor = ExprVar("FailedConstructor")


class _ResponseRejectReason:
    """C++ type and qualified enumerator expressions for
    |ResponseRejectReason|."""

    @staticmethod
    def Type():
        return Type("ResponseRejectReason")

    SendError = ExprVar("ResponseRejectReason::SendError")
    ChannelClosed = ExprVar("ResponseRejectReason::ChannelClosed")
    HandlerRejected = ExprVar("ResponseRejectReason::HandlerRejected")
    ActorDestroyed = ExprVar("ResponseRejectReason::ActorDestroyed")


# -----------------------------------------------------------------------------
# Intermediate representation (IR) nodes used during lowering


class _ConvertToCxxType(TypeVisitor):
    """TypeVisitor that maps an IPDL type to the C++ |Type| used to declare
    a value of it on |side|.

    When |fq| is true, type names come from |fullname()|; otherwise they
    come from |name()|.  Each visit returns a fresh |Type| object."""

    def __init__(self, side, fq):
        self.side = side
        self.fq = fq

    def typename(self, thing):
        if self.fq:
            return thing.fullname()
        return thing.name()

    def visitImportedCxxType(self, t):
        # Refcounted imported C++ types are held as RefPtr<T>.
        cxxtype = Type(self.typename(t))
        if t.isRefcounted():
            cxxtype = _refptr(cxxtype)
        return cxxtype

    def visitActorType(self, a):
        # Actors become |<Protocol><Side>*| pointers.
        return Type(_actorName(self.typename(a.protocol), self.side), ptr=True)

    def visitStructType(self, s):
        return Type(self.typename(s))

    def visitUnionType(self, u):
        return Type(self.typename(u))

    def visitArrayType(self, a):
        # nsTArray<converted base type>
        basecxxtype = a.basetype.accept(self)
        return _cxxArrayType(basecxxtype)

    def visitMaybeType(self, m):
        # mozilla::Maybe<converted base type>
        basecxxtype = m.basetype.accept(self)
        return _cxxMaybeType(basecxxtype)

    def visitShmemType(self, s):
        return Type(self.typename(s))

    def visitByteBufType(self, s):
        return Type(self.typename(s))

    def visitFDType(self, s):
        return Type(self.typename(s))

    def visitEndpointType(self, s):
        return Type(self.typename(s))

    def visitManagedEndpointType(self, s):
        return Type(self.typename(s))

    def visitUniquePtrType(self, s):
        return Type(self.typename(s))

    # The following kinds must never be asked for a C++ value type.
    def visitProtocolType(self, p):
        assert 0

    def visitMessageType(self, m):
        assert 0

    def visitVoidType(self, v):
        assert 0


def _cxxBareType(ipdltype, side, fq=False):
    """Plain (non-reference) C++ type for |ipdltype| on |side|."""
    return ipdltype.accept(_ConvertToCxxType(side, fq))


def _cxxRefType(ipdltype, side):
    """Bare C++ type for |ipdltype| with the reference flag set."""
    t = _cxxBareType(ipdltype, side)
    t.ref = True
    return t


def _cxxConstRefType(ipdltype, side):
    """C++ type used to pass |ipdltype| as a read-only argument.

    The checks below run in order and the first match returns:
      - IPDL actors: the bare |Actor*| pointer.
      - Shmem / ByteBuf: non-const reference.
      - IPDL types with a base type (e.g. arrays, maybes): reference, const
        unless the inner type's const-ref form is a non-const reference.
      - move-only C++ types: const reference.
      - refcounted C++ types: raw |T*| instead of |const RefPtr<T>&|.
      - UniquePtr: non-const reference.
      - everything else: const reference.
    """
    t = _cxxBareType(ipdltype, side)
    if ipdltype.isIPDL() and ipdltype.isActor():
        return t
    if ipdltype.isIPDL() and ipdltype.isShmem():
        t.ref = True
        return t
    if ipdltype.isIPDL() and ipdltype.isByteBuf():
        t.ref = True
        return t
    if ipdltype.isIPDL() and ipdltype.hasBaseType():
        # Keep same constness as inner type.
        inner = _cxxConstRefType(ipdltype.basetype, side)
        t.const = inner.const or not inner.ref
        t.ref = True
        return t
    if ipdltype.isCxx() and ipdltype.isMoveonly():
        t.const = True
        t.ref = True
        return t
    if ipdltype.isCxx() and ipdltype.isRefcounted():
        # Use T* instead of const RefPtr<T>&
        t = t.T
        t.ptr = True
        return t
    if ipdltype.isUniquePtr():
        t.ref = True
        return t
    t.const = True
    t.ref = True
    return t


def _cxxTypeCanMoveSend(ipdltype):
    """True only for UniquePtr types."""
    return ipdltype.isUniquePtr()


def _cxxTypeNeedsMove(ipdltype):
    """True if values of |ipdltype| must be moved: everything covered by
    _cxxTypeNeedsMoveForSend, plus IPDL arrays."""
    if _cxxTypeNeedsMoveForSend(ipdltype):
        return True

    if ipdltype.isIPDL():
        return ipdltype.isArray()

    return False


def _cxxTypeNeedsMoveForSend(ipdltype):
    """True if |ipdltype| must be moved when passed to a Send*() method.

    That covers UniquePtr, move-only C++ types, IPDL types whose base type
    needs a move (per _cxxTypeNeedsMove, which is mutually recursive with
    this function), and Shmem / ByteBuf / Endpoint / ManagedEndpoint."""
    if ipdltype.isUniquePtr():
        return True

    if ipdltype.isCxx():
        return ipdltype.isMoveonly()

    if ipdltype.isIPDL():
        if ipdltype.hasBaseType():
            return _cxxTypeNeedsMove(ipdltype.basetype)
        return (
            ipdltype.isShmem()
            or ipdltype.isByteBuf()
            or ipdltype.isEndpoint()
            or ipdltype.isManagedEndpoint()
        )

    return False


# FIXME Bug 1547019 This should be the same as _cxxTypeNeedsMoveForSend, but
#                   a lot of existing code needs to be updated and fixed before
#                   we can do that.
def _cxxTypeCanOnlyMove(ipdltype, visited=None):
    """Return whether values of |ipdltype| can only be moved, never copied.

    |visited| collects every type examined so far.  Struct/union components
    that were already seen are skipped, which keeps mutually recursive
    types from recursing forever.
    """
    seen = set() if visited is None else visited
    seen.add(ipdltype)

    if ipdltype.isCxx():
        return ipdltype.isMoveonly()

    if not ipdltype.isIPDL():
        return False

    if ipdltype.isMaybe() or ipdltype.isArray():
        return _cxxTypeCanOnlyMove(ipdltype.basetype, seen)

    if ipdltype.isStruct() or ipdltype.isUnion():
        for component in ipdltype.itercomponents():
            # |seen| grows while we recurse, so membership is re-checked
            # lazily for each component rather than filtered up front.
            if component in seen:
                continue
            if _cxxTypeCanOnlyMove(component, seen):
                return True
        return False

    return ipdltype.isManagedEndpoint()


def _cxxTypeCanMove(ipdltype):
    """Everything except an IPDL actor type may be moved."""
    is_actor = ipdltype.isIPDL() and ipdltype.isActor()
    return not is_actor


def _cxxMoveRefType(ipdltype, side):
    """Rvalue-reference C++ type when |ipdltype| needs a move; otherwise
    its const-reference form."""
    moved = _cxxBareType(ipdltype, side)
    if not _cxxTypeNeedsMove(ipdltype):
        return _cxxConstRefType(ipdltype, side)
    moved.rvalref = True
    return moved


def _cxxForceMoveRefType(ipdltype, side):
    """Rvalue-reference C++ type for |ipdltype|, which must be movable."""
    assert _cxxTypeCanMove(ipdltype)
    moved = _cxxBareType(ipdltype, side)
    moved.rvalref = True
    return moved


def _cxxPtrToType(ipdltype, side):
    """Pointer-to C++ type for |ipdltype|; actors (already pointers)
    become pointer-to-pointer."""
    result = _cxxBareType(ipdltype, side)
    if not (ipdltype.isIPDL() and ipdltype.isActor()):
        result.ptr = True
        return result
    result.ptr = False
    result.ptrptr = True
    return result


def _cxxConstPtrToType(ipdltype, side):
    """Const pointer-to C++ type for |ipdltype|; actors become
    |Actor* const*| via the ptrconstptr flag."""
    result = _cxxBareType(ipdltype, side)
    if not (ipdltype.isIPDL() and ipdltype.isActor()):
        result.const = True
        result.ptr = True
        return result
    result.ptr = False
    result.ptrconstptr = True
    return result


def _allocMethod(ptype, side):
    """Name of the |Alloc<Protocol><Side>| method, e.g. AllocPFooParent."""
    return "".join(("Alloc", ptype.name(), side.title()))


def _deallocMethod(ptype, side):
    """Name of the |Dealloc<Protocol><Side>| method, e.g. DeallocPFooChild."""
    return "".join(("Dealloc", ptype.name(), side.title()))


##
# A _HybridDecl straddles IPDL and C++ decls.  It knows which C++
# types correspond to which IPDL types, and it also knows how to
# serialize and deserialize "special" IPDL C++ types.
##


class _HybridDecl:
    """A hybrid decl stores both an IPDL type and all the C++ type
    info needed by later passes, along with a basic name for the decl.

    |attributes| maps IPDL attribute names for this decl; for example,
    MessageDecl.makeCxxParams checks it for "NoTaint".  Each instance gets
    its own dict when none is supplied."""

    def __init__(self, ipdltype, name, attributes=None):
        self.ipdltype = ipdltype
        self.name = name
        # Never share a default dict between instances: a mutable default
        # argument is evaluated once, so every decl created without explicit
        # attributes would otherwise alias the same object.
        self.attributes = {} if attributes is None else attributes

    def var(self):
        return ExprVar(self.name)

    def bareType(self, side, fq=False):
        """Return this decl's unqualified C++ type."""
        return _cxxBareType(self.ipdltype, side, fq=fq)

    def refType(self, side):
        """Return this decl's C++ type as a 'reference' type, which is not
        necessarily a C++ reference."""
        return _cxxRefType(self.ipdltype, side)

    def constRefType(self, side):
        """Return this decl's C++ type as a const, 'reference' type."""
        return _cxxConstRefType(self.ipdltype, side)

    def rvalueRefType(self, side):
        """Return this decl's C++ type as an r-value 'reference' type."""
        return _cxxMoveRefType(self.ipdltype, side)

    def ptrToType(self, side):
        return _cxxPtrToType(self.ipdltype, side)

    def constPtrToType(self, side):
        return _cxxConstPtrToType(self.ipdltype, side)

    def inType(self, side):
        """Return this decl's C++ Type with inparam semantics."""
        if self.ipdltype.isIPDL() and self.ipdltype.isActor():
            return self.bareType(side)
        elif _cxxTypeNeedsMoveForSend(self.ipdltype):
            return self.rvalueRefType(side)
        return self.constRefType(side)

    def moveType(self, side):
        """Return this decl's C++ Type with move semantics."""
        if self.ipdltype.isIPDL() and self.ipdltype.isActor():
            return self.bareType(side)
        return self.rvalueRefType(side)

    def outType(self, side):
        """Return this decl's C++ Type with outparam semantics."""
        t = self.bareType(side)
        if self.ipdltype.isIPDL() and self.ipdltype.isActor():
            t.ptr = False
            t.ptrptr = True
            return t
        t.ptr = True
        return t

    def forceMoveType(self, side):
        """Return this decl's C++ Type with forced move semantics."""
        assert _cxxTypeCanMove(self.ipdltype)
        return _cxxForceMoveRefType(self.ipdltype, side)


# --------------------------------------------------


class HasFQName:
    """Mixin: fully-qualified name of |self.decl.type|."""

    def fqClassName(self):
        return self.decl.type.fullname()


class _CompoundTypeComponent(_HybridDecl):
    """A _HybridDecl bound to one |side| (a struct field or union member).

    |special| records whether the C++ form exposes an actor pointer.  The
    |ct| argument is accepted for callers but not used here."""

    def __init__(self, ipdltype, name, side, ct):
        _HybridDecl.__init__(self, ipdltype, name)
        self.side = side
        self.special = _hasVisibleActor(ipdltype)

    # @override the following methods to pass |self.side| instead of
    # forcing the caller to remember which side we're declared to
    # represent.
    def bareType(self, side=None, fq=False):
        return _HybridDecl.bareType(self, self.side, fq=fq)

    def refType(self, side=None):
        return _HybridDecl.refType(self, self.side)

    def constRefType(self, side=None):
        return _HybridDecl.constRefType(self, self.side)

    def ptrToType(self, side=None):
        return _HybridDecl.ptrToType(self, self.side)

    def constPtrToType(self, side=None):
        return _HybridDecl.constPtrToType(self, self.side)

    def inType(self, side=None):
        return _HybridDecl.inType(self, self.side)

    def forceMoveType(self, side=None):
        return _HybridDecl.forceMoveType(self, self.side)


class StructDecl(ipdl.ast.StructDecl, HasFQName):
    """Lowering-time view of an IPDL struct; |upgrade| re-classes an
    existing ipdl.ast.StructDecl in place."""

    def fields_ipdl_order(self):
        """Yield fields in their declaration order."""
        yield from self.fields

    def fields_member_order(self):
        """Yield fields in the order given by |packed_field_order| (presumably
        the C++ member layout order — confirm against its producer)."""
        assert len(self.packed_field_order) == len(self.fields)

        for i in self.packed_field_order:
            yield self.fields[i]

    @staticmethod
    def upgrade(structDecl):
        assert isinstance(structDecl, ipdl.ast.StructDecl)
        structDecl.__class__ = StructDecl


class _StructField(_CompoundTypeComponent):
    """One field of an IPDL struct on a given side.

    Fields whose C++ type exposes an actor get the title-cased side
    appended to their name; |basename| keeps the original IPDL name."""

    def __init__(self, ipdltype, name, sd, side=None):
        self.basename = name
        fname = name
        special = _hasVisibleActor(ipdltype)
        if special:
            fname += side.title()

        _CompoundTypeComponent.__init__(self, ipdltype, fname, side, sd)

    def getMethod(self, thisexpr=None, sel="."):
        meth = self.var()
        if thisexpr is not None:
            return ExprSelect(thisexpr, sel, meth.name)
        return meth

    def refExpr(self, thisexpr=None):
        ref = self.memberVar()
        if thisexpr is not None:
            ref = ExprSelect(thisexpr, ".", ref.name)
        return ref

    def constRefExpr(self, thisexpr=None):
        # sigh, gross hack: these types are explicitly cast to a const
        # reference of themselves.
        refexpr = self.refExpr(thisexpr)
        typename = self.ipdltype.name()
        if typename in ("Shmem", "ByteBuf", "FileDescriptor"):
            refexpr = ExprCast(refexpr, Type(typename, ref=True), const=True)
        return refexpr

    def argVar(self):
        return ExprVar("_" + self.name)

    def memberVar(self):
        return ExprVar(self.name + "_")


class UnionDecl(ipdl.ast.UnionDecl, HasFQName):
    """Lowering-time view of an IPDL union; |upgrade| re-classes an
    existing ipdl.ast.UnionDecl in place."""

    def callType(self, var=None):
        """|type()|, or |var.type()| when |var| is given."""
        func = ExprVar("type")
        if var is not None:
            func = ExprSelect(var, ".", func.name)
        return ExprCall(func)

    @staticmethod
    def upgrade(unionDecl):
        assert isinstance(unionDecl, ipdl.ast.UnionDecl)
        unionDecl.__class__ = UnionDecl


class _UnionMember(_CompoundTypeComponent):
    """Not in the AFL sense, but rather a member (e.g. |int;|) of an
    IPDL union type.

    The member is named "V" + flattened type name, and its enumerator is
    "T" + flattened type name.  Members that are mutually recursive with
    the union are stored by pointer; others use AlignedStorage2."""

    def __init__(self, ipdltype, ud, side=None, other=None):
        flatname = _flatTypeName(ipdltype)
        special = _hasVisibleActor(ipdltype)
        if special:
            flatname += side.title()

        _CompoundTypeComponent.__init__(self, ipdltype, "V" + flatname, side, ud)
        self.flattypename = flatname
        if special:
            # Actor-bearing members are paired with their other-side twin.
            if other is not None:
                self.other = other
            else:
                self.other = _UnionMember(ipdltype, ud, _otherSide(side), self)

        # To create a finite object with a mutually recursive type, a union must
        # be present somewhere in the recursive loop. Because of that we only
        # need to care about introducing indirections inside unions.
        self.recursive = ud.decl.type.mutuallyRecursiveWith(ipdltype)

    def enum(self):
        return "T" + self.flattypename

    def enumvar(self):
        return ExprVar(self.enum())

    def internalType(self):
        if self.recursive:
            return self.ptrToType()
        else:
            return self.bareType()

    def unionType(self):
        """Type used for storage in generated C union decl."""
        if self.recursive:
            return self.ptrToType()
        else:
            return Type("mozilla::AlignedStorage2", T=self.internalType())

    def unionValue(self):
        # NB: knows that Union's storage C union is named |mValue|
        return ExprSelect(ExprVar("mValue"), ".", self.name)

    def typedef(self):
        return self.flattypename + "__tdef"

    def callGetConstPtr(self):
        """Return an expression of type self.constptrToSelfType()"""
        return ExprCall(ExprVar(self.getConstPtrName()))

    def callGetPtr(self):
        """Return an expression of type self.ptrToSelfType()"""
        return ExprCall(ExprVar(self.getPtrName()))

    def callOperatorEq(self, rhs):
        """|*ptr_X() = rhs|; actors are cast to the bare type and arrays not
        already being moved are deep-copied via |Clone()|."""
        if self.ipdltype.isIPDL() and self.ipdltype.isActor():
            rhs = ExprCast(rhs, self.bareType(), const=True)
        elif (
            self.ipdltype.isIPDL()
            and self.ipdltype.isArray()
            and not isinstance(rhs, ExprMove)
        ):
            rhs = ExprCall(ExprSelect(rhs, ".", "Clone"), args=[])
        return ExprAssn(ExprDeref(self.callGetPtr()), rhs)

    def callCtor(self, expr=None):
        """Construct the member from |expr| (or default-construct).

        Recursive members are heap-allocated and assigned to the pointer
        slot; others use placement new into the member's storage."""
        assert not isinstance(expr, list)

        if expr is None:
            args = None
        elif self.ipdltype.isIPDL() and self.ipdltype.isActor():
            args = [ExprCast(expr, self.bareType(), const=True)]
        elif (
            self.ipdltype.isIPDL()
            and self.ipdltype.isArray()
            and not isinstance(expr, ExprMove)
        ):
            args = [ExprCall(ExprSelect(expr, ".", "Clone"), args=[])]
        else:
            args = [expr]

        if self.recursive:
            return ExprAssn(
                self.callGetPtr(), ExprNew(self.bareType(self.side), args=args)
            )
        else:
            return ExprNew(
                self.bareType(self.side),
                args=args,
                newargs=[ExprVar("mozilla::KnownNotNull"), self.callGetPtr()],
            )

    def callDtor(self):
        if self.recursive:
            return ExprDelete(self.callGetPtr())
        else:
            return ExprCall(ExprSelect(self.callGetPtr(), "->", "~" + self.typedef()))

    def getTypeName(self):
        return "get_" + self.flattypename

    def getConstTypeName(self):
        return "get_" + self.flattypename

    def getOtherTypeName(self):
        # NOTE(review): |otherflattypename| is not assigned anywhere in this
        # class; this presumably meant |self.other.flattypename|.  Confirm
        # whether any caller relies on it before changing.
        return "get_" + self.otherflattypename

    def getPtrName(self):
        return "ptr_" + self.flattypename

    def getConstPtrName(self):
        return "constptr_" + self.flattypename

    def ptrToSelfExpr(self):
        """|*ptrToSelfExpr()| has type |self.bareType()|"""
        v = self.unionValue()
        if self.recursive:
            return v
        else:
            return ExprCall(ExprSelect(v, ".", "addr"))

    def constptrToSelfExpr(self):
        """|*constptrToSelfExpr()| has type |self.constType()|"""
        v = self.unionValue()
        if self.recursive:
            return v
        return ExprCall(ExprSelect(v, ".", "addr"))

    def ptrToInternalType(self):
        t = self.ptrToType()
        if self.recursive:
            t.ref = True
        return t

    def defaultValue(self, fq=False):
        # Use the default constructor for any class that does not have an
        # implicit copy constructor.
        if not self.bareType().hasimplicitcopyctor:
            return None

        if self.ipdltype.isIPDL() and self.ipdltype.isActor():
            return ExprLiteral.NULL
        # XXX sneaky here, maybe need ExprCtor()?
        return ExprCall(self.bareType(fq=fq))

    def getConstValue(self):
        v = ExprDeref(self.callGetConstPtr())
        # sigh: these types are explicitly cast to a const reference.
        typename = self.ipdltype.name()
        if typename in ("ByteBuf", "Shmem", "FileDescriptor"):
            v = ExprCast(v, Type(typename, ref=True), const=True)
        return v


# --------------------------------------------------


class MessageDecl(ipdl.ast.MessageDecl):
    """Lowering-time view of an IPDL message; |upgrade| re-classes an
    existing ipdl.ast.MessageDecl in place (adding an implicit |actor|
    param when the message type has one)."""

    def baseName(self):
        return self.name

    def recvMethod(self):
        name = _recvPrefix(self.decl.type) + self.baseName()
        if self.decl.type.isCtor():
            name += "Constructor"
        return name

    def sendMethod(self):
        name = _sendPrefix(self.decl.type) + self.baseName()
        if self.decl.type.isCtor():
            name += "Constructor"
        return name

    def hasReply(self):
        return (
            self.decl.type.hasReply()
            or self.decl.type.isCtor()
            or self.decl.type.isDtor()
        )

    def hasAsyncReturns(self):
        return self.decl.type.isAsync() and self.returns

    def msgCtorFunc(self):
        return "Msg_%s" % (self.decl.progname)

    def prettyMsgName(self, pfx=""):
        return pfx + self.msgCtorFunc()

    def pqMsgCtorFunc(self):
        return "%s::%s" % (self.namespace, self.msgCtorFunc())

    def msgId(self):
        return self.msgCtorFunc() + "__ID"

    def pqMsgId(self):
        return "%s::%s" % (self.namespace, self.msgId())

    def replyCtorFunc(self):
        return "Reply_%s" % (self.decl.progname)

    def pqReplyCtorFunc(self):
        return "%s::%s" % (self.namespace, self.replyCtorFunc())

    def replyId(self):
        return self.replyCtorFunc() + "__ID"

    def pqReplyId(self):
        return "%s::%s" % (self.namespace, self.replyId())

    def prettyReplyName(self, pfx=""):
        return pfx + self.replyCtorFunc()

    def promiseName(self):
        name = self.baseName()
        if self.decl.type.isCtor():
            name += "Constructor"
        name += "Promise"
        return name

    def resolverName(self):
        return self.baseName() + "Resolver"

    def actorDecl(self):
        return self.params[0]

    def makeCxxParams(
        self, paramsems="in", returnsems="out", side=None, implicit=True, direction=None
    ):
        """Return a list of C++ decls per the spec'd configuration.

        |paramsems| is 'in', 'move', 'out' or None (omit params).
        |returnsems| is one of those, or 'promise' (no return params),
        'callback' (resolve + reject callbacks) or 'resolver' (a single
        resolver); the last three only apply when the message has returns.
        When |direction| is 'recv' and the message is tainted, params without
        the "NoTaint" attribute are passed by value as |Tainted<T>|.
        With |implicit| false, the implicit actor param (if any) is dropped."""

        def makeDecl(d, sems):
            if (
                self.decl.type.tainted
                and "NoTaint" not in d.attributes
                and direction == "recv"
            ):
                # Tainted types are passed by-value, allowing the receiver to move them if desired.
                assert sems != "out"
                return Decl(Type("Tainted", T=d.bareType(side)), d.name)

            if sems == "in":
                return Decl(d.inType(side), d.name)
            elif sems == "move":
                return Decl(d.moveType(side), d.name)
            elif sems == "out":
                return Decl(d.outType(side), d.name)
            else:
                assert 0

        def makeResolverDecl(returns):
            return Decl(Type(self.resolverName(), rvalref=True), "aResolve")

        def makeCallbackResolveDecl(returns):
            if len(returns) > 1:
                resolvetype = _tuple([d.bareType(side) for d in returns])
            else:
                resolvetype = returns[0].bareType(side)

            return Decl(
                Type("mozilla::ipc::ResolveCallback", T=resolvetype, rvalref=True),
                "aResolve",
            )

        def makeCallbackRejectDecl(returns):
            return Decl(Type("mozilla::ipc::RejectCallback", rvalref=True), "aReject")

        cxxparams = []
        if paramsems is not None:
            cxxparams.extend([makeDecl(d, paramsems) for d in self.params])

        if returnsems == "promise" and self.returns:
            pass
        elif returnsems == "callback" and self.returns:
            cxxparams.extend(
                [
                    makeCallbackResolveDecl(self.returns),
                    makeCallbackRejectDecl(self.returns),
                ]
            )
        elif returnsems == "resolver" and self.returns:
            cxxparams.extend([makeResolverDecl(self.returns)])
        elif returnsems is not None:
            cxxparams.extend([makeDecl(r, returnsems) for r in self.returns])

        if not implicit and self.decl.type.hasImplicitActorParam():
            cxxparams = cxxparams[1:]

        return cxxparams

    def makeCxxArgs(
        self, paramsems="in", retsems="out", retcallsems="out", implicit=True
    ):
        """Return the C++ argument expressions for calling a method with
        this message's params and returns.

        |paramsems| is 'move' or 'in'.  |retsems| describes how the return
        vars are declared locally ('in', 'out' or 'resolver') and
        |retcallsems| how the callee expects them ('in' or 'out'); the
        combination decides whether |&ret|, |*ret| or |ret| is passed.  With
        |retsems| 'resolver', a moved |resolver| argument is appended.  With
        |implicit| false, the leading implicit actor arg is dropped."""
        assert not retcallsems or retsems  # retcallsems => returnsems
        cxxargs = []

        if paramsems == "move":
            # We don't std::move() RefPtr<T> types because current Recv*()
            # implementors take these parameters as T*, and
            # std::move(RefPtr<T>) doesn't coerce to T*.
            cxxargs.extend(
                [
                    p.var() if p.ipdltype.isRefcounted() else ExprMove(p.var())
                    for p in self.params
                ]
            )
        elif paramsems == "in":
            cxxargs.extend([p.var() for p in self.params])
        else:
            assert False

        for ret in self.returns:
            if retsems == "in":
                if retcallsems == "in":
                    cxxargs.append(ret.var())
                elif retcallsems == "out":
                    cxxargs.append(ExprAddrOf(ret.var()))
                else:
                    assert 0
            elif retsems == "out":
                if retcallsems == "in":
                    cxxargs.append(ExprDeref(ret.var()))
                elif retcallsems == "out":
                    cxxargs.append(ret.var())
                else:
                    assert 0
            elif retsems == "resolver":
                pass
        if retsems == "resolver":
            cxxargs.append(ExprMove(ExprVar("resolver")))

        if not implicit:
            assert self.decl.type.hasImplicitActorParam()
            cxxargs = cxxargs[1:]

        return cxxargs

    @staticmethod
    def upgrade(messageDecl):
        assert isinstance(messageDecl, ipdl.ast.MessageDecl)
        if messageDecl.decl.type.hasImplicitActorParam():
            messageDecl.params.insert(
                0,
                _HybridDecl(
                    ipdl.type.ActorType(messageDecl.decl.type.constructedType()),
                    "actor",
                ),
            )
        messageDecl.__class__ = MessageDecl


# --------------------------------------------------
def _usesShmem(p):
    """True if any message of protocol AST |p| has an in- or out-param
    whose type contains a Shmem (per ipdl.type.hasshmem)."""
    for md in p.messageDecls:
        for param in md.inParams:
            if ipdl.type.hasshmem(param.type):
                return True
        for ret in md.outParams:
            if ipdl.type.hasshmem(ret.type):
                return True
    return False


def _subtreeUsesShmem(p):
    """_usesShmem for |p| or, recursively, any protocol it manages.

    Only direct self-management is skipped.
    NOTE(review): a management cycle between distinct protocols would
    recurse without bound; confirm the type checker rules that out."""
    if _usesShmem(p):
        return True

    ptype = p.decl.type
    for mgd in ptype.manages:
        if ptype is not mgd:
            if _subtreeUsesShmem(mgd._ast):
                return True
    return False


class Protocol(ipdl.ast.Protocol):
    def cxxTypedefs(self):
        return self.decl.cxxtypedefs

    def managerInterfaceType(self, ptr=False):
        return Type("mozilla::ipc::IProtocol", ptr=ptr)

    def openedProtocolInterfaceType(self, ptr=False):
        return Type("mozilla::ipc::IToplevelProtocol", ptr=ptr)

    def _ipdlmgrtype(self):
        assert 1 == len(self.decl.type.managers)
        for mgr in self.decl.type.managers:
            return mgr

    def managerActorType(self, side, ptr=False):
        return Type(_actorName(self._ipdlmgrtype().name(), side), ptr=ptr)

    def unregisterMethod(self, actorThis=None):
        if actorThis is not None:
            return ExprSelect(actorThis, "->", "Unregister")
        return ExprVar("Unregister")

    def removeManageeMethod(self):
        return ExprVar("RemoveManagee")

    def deallocManageeMethod(self):
        return ExprVar("DeallocManagee")

    def otherPidMethod(self):
        return ExprVar("OtherPid")

    def callOtherPid(self, actorThis=None):
        fn = self.otherPidMethod()
        if actorThis is not None:
            fn = ExprSelect(actorThis, "->", fn.name)
        return ExprCall(fn)

    def getChannelMethod(self):
        return ExprVar("GetIPCChannel")

    def callGetChannel(self, actorThis=None):
        fn = self.getChannelMethod()
        if actorThis is not None:
            fn = ExprSelect(actorThis, "->", fn.name)
        return ExprCall(fn)

    def processingErrorVar(self):
        assert self.decl.type.isToplevel()
        return ExprVar("ProcessingError")

    def shouldContinueFromTimeoutVar(self):
        assert self.decl.type.isToplevel()
        return ExprVar("ShouldContinueFromReplyTimeout")

    def enteredCxxStackVar(self):
        assert self.decl.type.isToplevel()
        return ExprVar("EnteredCxxStack")

    def exitedCxxStackVar(self):
        assert self.decl.type.isToplevel()
        return ExprVar("ExitedCxxStack")

    def enteredCallVar(self):
        assert self.decl.type.isToplevel()
        return ExprVar("EnteredCall")

    def exitedCallVar(self):
        assert self.decl.type.isToplevel()
        return ExprVar("ExitedCall")
1333 def routingId(self, actorThis=None): 1334 if self.decl.type.isToplevel(): 1335 return ExprVar("MSG_ROUTING_CONTROL") 1336 if actorThis is not None: 1337 return ExprCall(ExprSelect(actorThis, "->", "Id")) 1338 return ExprCall(ExprVar("Id")) 1339 1340 def managerVar(self, thisexpr=None): 1341 assert thisexpr is not None or not self.decl.type.isToplevel() 1342 mvar = ExprCall(ExprVar("Manager"), args=[]) 1343 if thisexpr is not None: 1344 mvar = ExprCall(ExprSelect(thisexpr, "->", "Manager"), args=[]) 1345 return mvar 1346 1347 def managedCxxType(self, actortype, side): 1348 assert self.decl.type.isManagerOf(actortype) 1349 return Type(_actorName(actortype.name(), side), ptr=True) 1350 1351 def managedMethod(self, actortype, side): 1352 assert self.decl.type.isManagerOf(actortype) 1353 return ExprVar("Managed" + _actorName(actortype.name(), side)) 1354 1355 def managedVar(self, actortype, side): 1356 assert self.decl.type.isManagerOf(actortype) 1357 return ExprVar("mManaged" + _actorName(actortype.name(), side)) 1358 1359 def managedVarType(self, actortype, side, const=False, ref=False): 1360 assert self.decl.type.isManagerOf(actortype) 1361 return _cxxManagedContainerType( 1362 Type(_actorName(actortype.name(), side)), const=const, ref=ref 1363 ) 1364 1365 def subtreeUsesShmem(self): 1366 return _subtreeUsesShmem(self) 1367 1368 @staticmethod 1369 def upgrade(protocol): 1370 assert isinstance(protocol, ipdl.ast.Protocol) 1371 protocol.__class__ = Protocol 1372 1373 1374class TranslationUnit(ipdl.ast.TranslationUnit): 1375 @staticmethod 1376 def upgrade(tu): 1377 assert isinstance(tu, ipdl.ast.TranslationUnit) 1378 tu.__class__ = TranslationUnit 1379 1380 1381# ----------------------------------------------------------------------------- 1382 1383pod_types = { 1384 "int8_t": 1, 1385 "uint8_t": 1, 1386 "int16_t": 2, 1387 "uint16_t": 2, 1388 "int32_t": 4, 1389 "uint32_t": 4, 1390 "int64_t": 8, 1391 "uint64_t": 8, 1392 "float": 4, 1393 "double": 8, 1394} 
max_pod_size = max(pod_types.values())
# We claim that all types we don't recognize are automatically "bigger"
# than pod types for ease of sorting.
pod_size_sentinel = max_pod_size * 2


def pod_size(ipdltype):
    """Return the byte size of |ipdltype| if it is an imported C++ type named
    in pod_types; otherwise return pod_size_sentinel, which is larger than
    every POD size."""
    if not isinstance(ipdltype, ipdl.type.ImportedCxxType):
        return pod_size_sentinel

    return pod_types.get(ipdltype.name(), pod_size_sentinel)


class _DecorateWithCxxStuff(ipdl.ast.Visitor):
    """Phase 1 of lowering: decorate the IPDL AST with information
    relevant to C++ code generation.

    This pass results in an AST that is a poor man's "IR"; in reality, a
    "hybrid" AST mainly consisting of IPDL nodes with new C++ info along
    with some new IPDL/C++ nodes that are tuned for C++ codegen."""

    def __init__(self):
        # Translation units already visited, so a TU reached through
        # several includes is processed only once.
        self.visitedTus = set()
        # the set of typedefs that allow generated classes to
        # reference known C++ types by their "short name" rather than
        # fully-qualified name. e.g. |Foo| rather than |a::b::Foo|.
        self.typedefs = []
        # Typedefs accumulated while visiting, seeded with builtin IPC types;
        # copied, sorted, into self.typedefs after each visited TU.
        self.typedefSet = set(
            [
                Typedef(Type("mozilla::ipc::ActorHandle"), "ActorHandle"),
                Typedef(Type("base::ProcessId"), "ProcessId"),
                Typedef(Type("mozilla::ipc::ProtocolId"), "ProtocolId"),
                Typedef(Type("mozilla::ipc::Transport"), "Transport"),
                Typedef(Type("mozilla::ipc::Endpoint"), "Endpoint", ["FooSide"]),
                Typedef(
                    Type("mozilla::ipc::ManagedEndpoint"),
                    "ManagedEndpoint",
                    ["FooSide"],
                ),
                Typedef(
                    Type("mozilla::ipc::TransportDescriptor"), "TransportDescriptor"
                ),
                Typedef(Type("mozilla::UniquePtr"), "UniquePtr", ["T"]),
                Typedef(
                    Type("mozilla::ipc::ResponseRejectReason"), "ResponseRejectReason"
                ),
            ]
        )
        # Name of the protocol being visited; stamped on each MessageDecl
        # as its namespace by visitMessageDecl.
        self.protocolName = None

    def visitTranslationUnit(self, tu):
        if tu not in self.visitedTus:
            self.visitedTus.add(tu)
            ipdl.ast.Visitor.visitTranslationUnit(self, tu)
            if not isinstance(tu, TranslationUnit):
                TranslationUnit.upgrade(tu)
            # Assign through a slice so the list object is kept:
            # visitProtocol stored this very list on the protocol decl, which
            # therefore sees the final sorted contents.
            self.typedefs[:] = sorted(list(self.typedefSet))

    def visitInclude(self, inc):
        # Only includes whose TU has filetype "header" are walked; their
        # usings, structs and unions then contribute typedefs.
        if inc.tu.filetype == "header":
            inc.tu.accept(self)

    def visitProtocol(self, pro):
        self.protocolName = pro.name
        # Shared by reference; filled in by visitTranslationUnit.
        pro.decl.cxxtypedefs = self.typedefs
        Protocol.upgrade(pro)
        return ipdl.ast.Visitor.visitProtocol(self, pro)

    def visitUsingStmt(self, using):
        # Let generated code refer to the imported type by its short name.
        if using.decl.fullname is not None:
            self.typedefSet.add(
                Typedef(Type(using.decl.fullname), using.decl.shortname)
            )

    def visitStructDecl(self, sd):
        # The field rewrite only runs once per struct: after
        # StructDecl.upgrade() the isinstance check below is true.
        if not isinstance(sd, StructDecl):
            sd.decl.special = False
            newfields = []
            for f in sd.fields:
                ftype = f.decl.type
                if _hasVisibleActor(ftype):
                    sd.decl.special = True
                    # if ftype has a visible actor, we need both
                    # |ActorParent| and |ActorChild| fields
                    newfields.append(_StructField(ftype, f.name, sd, side="parent"))
                    newfields.append(_StructField(ftype, f.name, sd, side="child"))
                else:
                    newfields.append(_StructField(ftype, f.name, sd))

            # Compute a permutation of the fields for in-memory storage such
            # that the memory layout of the structure will be well-packed.
            permutation = list(range(len(newfields)))

            # Note that the results of `pod_size` ensure that non-POD fields
            # sort before POD ones.
            def size(idx):
                return pod_size(newfields[idx].ipdltype)

            # Largest first. list.sort is stable even with reverse=True, so
            # equal-size fields keep their declaration order.
            permutation.sort(key=size, reverse=True)

            sd.fields = newfields
            sd.packed_field_order = permutation
            StructDecl.upgrade(sd)

        if sd.decl.fullname is not None:
            self.typedefSet.add(Typedef(Type(sd.fqClassName()), sd.name))

    def visitUnionDecl(self, ud):
        # Build one _UnionMember per component type (two, per side, when the
        # type contains a visible actor) and upgrade the node.
        ud.decl.special = False
        newcomponents = []
        for ctype in ud.decl.type.components:
            if _hasVisibleActor(ctype):
                ud.decl.special = True
                # if ctype has a visible actor, we need both
                # |ActorParent| and |ActorChild| union members
                newcomponents.append(_UnionMember(ctype, ud, side="parent"))
                newcomponents.append(_UnionMember(ctype, ud, side="child"))
            else:
                newcomponents.append(_UnionMember(ctype, ud))
        ud.components = newcomponents
        UnionDecl.upgrade(ud)

        if ud.decl.fullname is not None:
            self.typedefSet.add(Typedef(Type(ud.fqClassName()), ud.name))

    def visitDecl(self, decl):
        # Replace a plain decl with a _HybridDecl carrying the same type,
        # name and attributes.
        return _HybridDecl(decl.type, decl.progname, decl.attributes)

    def visitMessageDecl(self, md):
        md.namespace = self.protocolName
        # accept() on each param/return presumably dispatches to visitDecl,
        # yielding _HybridDecls -- confirm against ipdl.ast.
        md.params = [param.accept(self) for param in md.inParams]
        md.returns = [ret.accept(self) for ret in md.outParams]
        MessageDecl.upgrade(md)


# -----------------------------------------------------------------------------


def msgenums(protocol, pretty=False):
    """Build the MessageType enum for |protocol|.

    It holds a <Protocol>Start enumerator valued
    `<message start name> << 16`, one enumerator per message plus one per
    reply (for messages with hasReply()), then <Protocol>End. With
    |pretty|, the pretty message/reply names are used instead of the IDs."""
    msgenum = TypeEnum("MessageType")
    msgstart = _messageStartName(protocol.decl.type) + " << 16"
    msgenum.addId(protocol.name + "Start", msgstart)

    for md in protocol.messageDecls:
        msgenum.addId(md.prettyMsgName() if pretty else md.msgId())
        if md.hasReply():
            msgenum.addId(md.prettyReplyName() if pretty else md.replyId())

    msgenum.addId(protocol.name + "End")
    return msgenum


class _GenerateProtocolCode(ipdl.ast.Visitor):
    """Creates code common to both the parent and child actors."""

    def __init__(self):
        self.protocol = None  # protocol we're generating a class for
        self.hdrfile = None  # what will become Protocol.h
        self.cppfile = None  # what will become Protocol.cpp
        # Header names #included at the top of the .cpp file.
        self.cppIncludeHeaders = []
        # Out-of-line struct/union method and ParamTraits definitions (.cpp).
        self.structUnionDefns = []
        # Out-of-line function definitions placed in the protocol namespace
        # of the .cpp file.
        self.funcDefns = []

    def lower(self, tu, cxxHeaderFile, cxxFile, segmentcapacitydict):
        """Generate the shared header/.cpp for |tu| into the given File
        objects. |segmentcapacitydict| maps "<Protocol>::<message>" names to
        a message segment capacity (see visitProtocol)."""
        self.protocol = tu.protocol
        self.hdrfile = cxxHeaderFile
        self.cppfile = cxxFile
        self.segmentcapacitydict = segmentcapacitydict
        tu.accept(self)

    def visitTranslationUnit(self, tu):
        hf = self.hdrfile

        hf.addthing(_DISCLAIMER)
        hf.addthings(_includeGuardStart(hf))
        hf.addthing(Whitespace.NL)

        for inc in builtinHeaderIncludes:
            self.visitBuiltinCxxInclude(inc)

        # Compute the set of includes we need for declared structure/union
        # classes for this protocol.
        typesToIncludes = {}
        for using in tu.using:
            typestr = str(using.type.spec)
            if typestr not in typesToIncludes:
                typesToIncludes[typestr] = using.header
            else:
                # The same type must always come from the same header.
                assert typesToIncludes[typestr] == using.header

        aggregateTypeIncludes = set()
        for su in tu.structsAndUnions:
            typedeps = _ComputeTypeDeps(su.decl.type, True)
            if isinstance(su, ipdl.ast.StructDecl):
                for f in su.fields:
                    f.ipdltype.accept(typedeps)
            elif isinstance(su, ipdl.ast.UnionDecl):
                for c in su.components:
                    c.ipdltype.accept(typedeps)

            for typename in [t.fromtype.name for t in typedeps.usingTypedefs]:
                if typename in typesToIncludes:
                    aggregateTypeIncludes.add(typesToIncludes[typename])

        if len(aggregateTypeIncludes) != 0:
            hf.addthing(Whitespace.NL)
            hf.addthings([Whitespace("// Headers for typedefs"), Whitespace.NL])

            # Sorted for deterministic output.
            for headername in sorted(iter(aggregateTypeIncludes)):
                hf.addthing(CppDirective("include", '"' + headername + '"'))

        # Manually run Visitor.visitTranslationUnit. For dependency resolution
        # we need to handle structs and unions separately.
        for cxxInc in tu.cxxIncludes:
            cxxInc.accept(self)
        for inc in tu.includes:
            inc.accept(self)
        self.generateStructsAndUnions(tu)
        for using in tu.builtinUsing:
            using.accept(self)
        for using in tu.using:
            using.accept(self)
        if tu.protocol:
            tu.protocol.accept(self)

        if tu.filetype == "header":
            self.cppIncludeHeaders.append(_ipdlhHeaderName(tu) + ".h")

        hf.addthing(Whitespace.NL)
        hf.addthings(_includeGuardEnd(hf))

        # The .cpp file: disclaimer, collected includes, builtin includes,
        # then the protocol-namespace definitions and struct/union defns.
        cf = self.cppfile
        cf.addthings(
            (
                [_DISCLAIMER, Whitespace.NL]
                + [
                    CppDirective("include", '"' + h + '"')
                    for h in self.cppIncludeHeaders
                ]
                + [Whitespace.NL]
                + [
                    CppDirective("include", '"%s"' % filename)
                    for filename in ipdl.builtin.CppIncludes
                ]
                + [Whitespace.NL]
            )
        )

        if self.protocol:
            # construct the namespace into which we'll stick all our defns
            ns = Namespace(self.protocol.name)
            cf.addthing(_putInNamespaces(ns, self.protocol.namespaces))
            ns.addstmts(([Whitespace.NL] + self.funcDefns + [Whitespace.NL]))

        cf.addthings(self.structUnionDefns)

    def visitBuiltinCxxInclude(self, inc):
        # Builtin includes go straight into the header.
        self.hdrfile.addthing(CppDirective("include", '"' + inc.file + '"'))

    def visitCxxInclude(self, inc):
        # C++ includes declared in the IPDL file go into the .cpp.
        self.cppIncludeHeaders.append(inc.file)

    def visitInclude(self, inc):
        # Header-type IPDL includes are #included by the header; protocol
        # includes pull both actor headers into the .cpp.
        if inc.tu.filetype == "header":
            self.hdrfile.addthing(
                CppDirective("include", '"' + _ipdlhHeaderName(inc.tu) + '.h"')
            )
        else:
            self.cppIncludeHeaders += [
                _protocolHeaderName(inc.tu.protocol, "parent") + ".h",
                _protocolHeaderName(inc.tu.protocol, "child") + ".h",
            ]

    def generateStructsAndUnions(self, tu):
        """Generate the definitions for all structs and unions. This will
        re-order the declarations if needed in the C++ code such that
        dependencies have already been defined."""
        # Maps IPDL type -> (types that must be fully declared first,
        # header items declaring this type).
        decls = OrderedDict()
        for su in tu.structsAndUnions:
            if isinstance(su, StructDecl):
                which = "struct"
                forwarddecls, fulldecltypes, cls = _generateCxxStruct(su)
                traitsdecl, traitsdefns = _ParamTraits.structPickling(su.decl.type)
            else:
                assert isinstance(su, UnionDecl)
                which = "union"
                forwarddecls, fulldecltypes, cls = _generateCxxUnion(su)
                traitsdecl, traitsdefns = _ParamTraits.unionPickling(su.decl.type)

            clsdecl, methoddefns = _splitClassDeclDefn(cls)

            # Store the declarations in the decls map so we can emit in
            # dependency order.
            decls[su.decl.type] = (
                fulldecltypes,
                [Whitespace.NL]
                + forwarddecls
                + [
                    Whitespace(
                        """
//-----------------------------------------------------------------------------
// Declaration of the IPDL type |%s %s|
//
"""
                        % (which, su.name)
                    ),
                    _putInNamespaces(clsdecl, su.namespaces),
                ]
                + [Whitespace.NL, traitsdecl],
            )

            # Definitions have no ordering constraints; emit them in source
            # order into the .cpp.
            self.structUnionDefns.extend(
                [
                    Whitespace(
                        """
//-----------------------------------------------------------------------------
// Method definitions for the IPDL type |%s %s|
//
"""
                        % (which, su.name)
                    ),
                    _putInNamespaces(methoddefns, su.namespaces),
                    Whitespace.NL,
                    traitsdefns,
                ]
            )

        # Generate the declarations structs in dependency order.
        # Depth-first: every dependency still pending in |decls| is removed
        # and emitted before |defn|, so each declaration is emitted once.
        def gen_struct(deps, defn):
            for dep in deps:
                if dep in decls:
                    d, t = decls[dep]
                    del decls[dep]
                    gen_struct(d, t)
            self.hdrfile.addthings(defn)

        # popitem(False) takes entries in insertion (source) order.
        while len(decls) > 0:
            _, (d, t) = decls.popitem(False)
            gen_struct(d, t)

    def visitProtocol(self, p):
        self.cppIncludeHeaders.append(_protocolHeaderName(self.protocol, "") + ".h")
        self.cppIncludeHeaders.append(
            _protocolHeaderName(self.protocol, "Parent") + ".h"
        )
        self.cppIncludeHeaders.append(
            _protocolHeaderName(self.protocol, "Child") + ".h"
        )

        # Forward declare our own actors.
        self.hdrfile.addthings(
            [
                Whitespace.NL,
                _makeForwardDeclForActor(p.decl.type, "Parent"),
                _makeForwardDeclForActor(p.decl.type, "Child"),
            ]
        )

        self.hdrfile.addthing(
            Whitespace(
                """
//-----------------------------------------------------------------------------
// Code common to %sChild and %sParent
//
"""
                % (p.name, p.name)
            )
        )

        # construct the namespace into which we'll stick all our decls
        ns = Namespace(self.protocol.name)
        self.hdrfile.addthing(_putInNamespaces(ns, p.namespaces))
        ns.addstmt(Whitespace.NL)

        # CreateEndpoints: declaration in the header namespace, definition
        # collected for the .cpp.
        edecl, edefn = _splitFuncDeclDefn(self.genEndpointFunc())
        ns.addstmts([edecl, Whitespace.NL])
        self.funcDefns.append(edefn)

        # spit out message type enum and classes
        msgenum = msgenums(self.protocol)
        ns.addstmts([StmtDecl(Decl(msgenum, "")), Whitespace.NL])

        for md in p.messageDecls:
            decls = []

            # Look up the segment capacity used for serializing this
            # message. If the capacity is not specified, use '0' for
            # the default capacity (defined in ipc_message.cc)
            name = "%s::%s" % (md.namespace, md.decl.progname)
            segmentcapacity = self.segmentcapacitydict.get(name, 0)

            mfDecl, mfDefn = _splitFuncDeclDefn(
                _generateMessageConstructor(md, segmentcapacity, p, forReply=False)
            )
            decls.append(mfDecl)
            self.funcDefns.append(mfDefn)

            # Reply constructors always use the default capacity.
            if md.hasReply():
                rfDecl, rfDefn = _splitFuncDeclDefn(
                    _generateMessageConstructor(md, 0, p, forReply=True)
                )
                decls.append(rfDecl)
                self.funcDefns.append(rfDefn)

            decls.append(Whitespace.NL)
            ns.addstmts(decls)

        ns.addstmts([Whitespace.NL, Whitespace.NL])

    # Generate code for PFoo::CreateEndpoints.
    # The generated function forwards its pids and Endpoint out-pointers to
    # mozilla::ipc::CreateEndpoints and returns its nsresult.
    def genEndpointFunc(self):
        p = self.protocol.decl.type
        tparent = _cxxBareType(ActorType(p), "Parent", fq=True)
        tchild = _cxxBareType(ActorType(p), "Child", fq=True)

        openfunc = MethodDefn(
            MethodDecl(
                "CreateEndpoints",
                params=[
                    Decl(Type("base::ProcessId"), "aParentDestPid"),
                    Decl(Type("base::ProcessId"), "aChildDestPid"),
                    Decl(
                        Type("mozilla::ipc::Endpoint<" + tparent.name + ">", ptr=True),
                        "aParent",
                    ),
                    Decl(
                        Type("mozilla::ipc::Endpoint<" + tchild.name + ">", ptr=True),
                        "aChild",
                    ),
                ],
                ret=Type.NSRESULT,
            )
        )
        openfunc.addcode(
            """
            return mozilla::ipc::CreateEndpoints(
                mozilla::ipc::PrivateIPDLInterface(),
                aParentDestPid, aChildDestPid,
                aParent, aChild);
            """
        )
        return openfunc


# --------------------------------------------------


def _generateMessageConstructor(md, segmentSize, protocol, forReply=False):
    """Build the C++ function that allocates the IPC::Message for |md|, or
    for its reply when |forReply| is True.

    The function takes an int32_t routingId and returns IPC::Message*. The
    header flags come from the message type's nesting, priority,
    compression, sync, interrupt and constructor attributes, plus the
    reply bit. A nonzero |segmentSize| produces
    `new IPC::Message(routingId, id, segmentSize, flags, true)`; zero uses
    `IPC::Message::IPDLMessage(routingId, id, flags)`. |protocol| is not
    used by this function."""
    if forReply:
        clsname = md.replyCtorFunc()
        msgid = md.replyId()
        replyEnum = "REPLY"
    else:
        clsname = md.msgCtorFunc()
        msgid = md.msgId()
        replyEnum = "NOT_REPLY"

    nested = md.decl.type.nested
    prio = md.decl.type.prio
    compress = md.decl.type.compress

    routingId = ExprVar("routingId")

    func = FunctionDefn(
        FunctionDecl(
            clsname,
            params=[Decl(Type("int32_t"), routingId.name)],
            ret=Type("IPC::Message", ptr=True),
        )
    )

    # No compress attribute -> NONE; compress="all" -> ALL; a bare
    # `compress` (value None) -> ENABLED.
    if not compress:
        compression = "COMPRESSION_NONE"
    elif compress.value == "all":
        compression = "COMPRESSION_ALL"
    else:
        assert compress.value is None
        compression = "COMPRESSION_ENABLED"

    if nested == ipdl.ast.NOT_NESTED:
        nestedEnum = "NOT_NESTED"
    elif nested == ipdl.ast.INSIDE_SYNC_NESTED:
        nestedEnum = "NESTED_INSIDE_SYNC"
    else:
        assert nested == ipdl.ast.INSIDE_CPOW_NESTED
        nestedEnum = "NESTED_INSIDE_CPOW"

    # Any priority not listed explicitly maps to CONTROL_PRIORITY.
    if prio == ipdl.ast.NORMAL_PRIORITY:
        prioEnum = "NORMAL_PRIORITY"
    elif prio == ipdl.ast.INPUT_PRIORITY:
        prioEnum = "INPUT_PRIORITY"
    elif prio == ipdl.ast.VSYNC_PRIORITY:
        prioEnum = "VSYNC_PRIORITY"
    elif prio == ipdl.ast.MEDIUMHIGH_PRIORITY:
        prioEnum = "MEDIUMHIGH_PRIORITY"
    else:
        prioEnum = "CONTROL_PRIORITY"

    if md.decl.type.isSync():
        syncEnum = "SYNC"
    else:
        syncEnum = "ASYNC"

    if md.decl.type.isInterrupt():
        interruptEnum = "INTERRUPT"
    else:
        interruptEnum = "NOT_INTERRUPT"

    if md.decl.type.isCtor():
        ctorEnum = "CONSTRUCTOR"
    else:
        ctorEnum = "NOT_CONSTRUCTOR"

    def messageEnum(valname):
        return ExprVar("IPC::Message::" + valname)

    flags = ExprCall(
        ExprVar("IPC::Message::HeaderFlags"),
        args=[
            messageEnum(nestedEnum),
            messageEnum(prioEnum),
            messageEnum(compression),
            messageEnum(ctorEnum),
            messageEnum(syncEnum),
            messageEnum(interruptEnum),
            messageEnum(replyEnum),
        ],
    )

    segmentSize = int(segmentSize)
    if segmentSize:
        func.addstmt(
            StmtReturn(
                ExprNew(
                    Type("IPC::Message"),
                    args=[
                        routingId,
                        ExprVar(msgid),
                        ExprLiteral.Int(int(segmentSize)),
                        flags,
                        # Pass `true` to recordWriteLatency to collect telemetry
                        ExprLiteral.TRUE,
                    ],
                )
            )
        )
    else:
        func.addstmt(
            StmtReturn(
                ExprCall(
                    ExprVar("IPC::Message::IPDLMessage"),
                    args=[routingId, ExprVar(msgid), flags],
                )
            )
        )

    return func


# --------------------------------------------------


class _ParamTraits:
    """Generators for mozilla::ipc::IPDLParamTraits<T> specializations (static
    Write/Read methods) for IPDL actors, structs and unions.

    The class attributes are the names of the parameters of the generated
    Write/Read methods (see generateDecl)."""

    var = ExprVar("aVar")
    msgvar = ExprVar("aMsg")
    itervar = ExprVar("aIter")
    actor = ExprVar("aActor")

    @classmethod
    def ifsideis(cls, side, then, els=None):
        """Wrap |then| (and optional |els|) in
        `if (<side> == aActor->GetSide())`. Any |side| other than "parent"
        tests mozilla::ipc::ChildSide."""
        cxxside = ExprVar("mozilla::ipc::ChildSide")
        if side == "parent":
            cxxside = ExprVar("mozilla::ipc::ParentSide")

        ifstmt = StmtIf(
            ExprBinary(cxxside, "==", ExprCall(ExprSelect(cls.actor, "->", "GetSide")))
        )
        ifstmt.addifstmt(then)
        if els is not None:
            ifstmt.addelsestmt(els)
        return ifstmt

    @classmethod
    def fatalError(cls, reason):
        # Statement calling aActor->FatalError(<reason string literal>).
        return StmtCode(
            "aActor->FatalError(${reason});", reason=ExprLiteral.String(reason)
        )

    @classmethod
    def writeSentinel(cls, msgvar, sentinelKey):
        """Return a comment plus `msgvar->WriteSentinel(hashfunc(key))`. The
        reader must check the same |sentinelKey| (see readSentinel)."""
        return [
            Whitespace("// Sentinel = " + repr(sentinelKey) + "\n", indent=True),
            StmtExpr(
                ExprCall(
                    ExprSelect(msgvar, "->", "WriteSentinel"),
                    args=[ExprLiteral.Int(hashfunc(sentinelKey))],
                )
            ),
        ]

    @classmethod
    def readSentinel(cls, msgvar, itervar, sentinelKey, sentinelFail):
        """Return a comment plus an `if (!msgvar->ReadSentinel(itervar,
        hashfunc(key)))` whose body is the |sentinelFail| statements."""
        # Read the sentinel
        read = ExprCall(
            ExprSelect(msgvar, "->", "ReadSentinel"),
            args=[itervar, ExprLiteral.Int(hashfunc(sentinelKey))],
        )
        ifsentinel = StmtIf(ExprNot(read))
        ifsentinel.addifstmts(sentinelFail)

        return [
            Whitespace("// Sentinel = " + repr(sentinelKey) + "\n", indent=True),
            ifsentinel,
        ]
# _ParamTraits code generators, continued.
    @classmethod
    def write(cls, var, msgvar, actor, ipdltype=None):
        """Return `WriteIPDLParam(msgvar, actor, var)`, with |var| wrapped in
        ExprMove when _cxxTypeNeedsMoveForSend(ipdltype) says so."""
        # WARNING: This doesn't set AutoForActor for you, make sure this is
        # only called when the actor is already correctly set.
        if ipdltype and _cxxTypeNeedsMoveForSend(ipdltype):
            var = ExprMove(var)
        return ExprCall(ExprVar("WriteIPDLParam"), args=[msgvar, actor, var])

    @classmethod
    def checkedWrite(cls, ipdltype, var, msgvar, sentinelKey, actor):
        """Return a Block that writes |var| and then a sentinel for
        |sentinelKey|. For non-nullable actor types, it first aborts if |var|
        is null."""
        assert sentinelKey
        block = Block()

        # Assert we aren't serializing a null non-nullable actor
        if (
            ipdltype
            and ipdltype.isIPDL()
            and ipdltype.isActor()
            and not ipdltype.nullable
        ):
            block.addstmt(
                _abortIfFalse(var, "NULL actor value passed to non-nullable param")
            )

        block.addstmts(
            [
                StmtExpr(cls.write(var, msgvar, actor, ipdltype)),
            ]
        )
        block.addstmts(cls.writeSentinel(msgvar, sentinelKey))
        return block

    @classmethod
    def bulkSentinelKey(cls, fields):
        # Sentinel key shared by a bulk write/read: "a | b | c" of basenames.
        return " | ".join(f.basename for f in fields)

    @classmethod
    def checkedBulkWrite(cls, size, fields):
        """Return a Block that writes all |fields| (each |size| bytes) with a
        single WriteBytes call, starting at the address of the first field,
        followed by one sentinel.

        NOTE(review): this assumes the fields are contiguous in memory in
        the generated struct, presumably ensured by the packed field order;
        confirm against the struct generator."""
        block = Block()
        first = fields[0]

        block.addstmts(
            [
                StmtExpr(
                    ExprCall(
                        ExprSelect(cls.msgvar, "->", "WriteBytes"),
                        args=[
                            ExprAddrOf(
                                ExprCall(first.getMethod(thisexpr=cls.var, sel="."))
                            ),
                            ExprLiteral.Int(size * len(fields)),
                        ],
                    )
                )
            ]
        )
        block.addstmts(cls.writeSentinel(cls.msgvar, cls.bulkSentinelKey(fields)))

        return block

    @classmethod
    def checkedBulkRead(cls, size, fields):
        """Read counterpart of checkedBulkWrite: one ReadBytesInto of
        size * len(fields) bytes into the first field's storage. On failure
        it calls FatalError and returns false; then it checks the sentinel.
        Read() receives a pointer, hence the "->" selector."""
        block = Block()
        first = fields[0]

        readbytes = ExprCall(
            ExprSelect(cls.msgvar, "->", "ReadBytesInto"),
            args=[
                cls.itervar,
                ExprAddrOf(ExprCall(first.getMethod(thisexpr=cls.var, sel="->"))),
                ExprLiteral.Int(size * len(fields)),
            ],
        )
        ifbad = StmtIf(ExprNot(readbytes))
        errmsg = "Error bulk reading fields from %s" % first.ipdltype.name()
        ifbad.addifstmts([cls.fatalError(errmsg), StmtReturn.FALSE])
        block.addstmt(ifbad)
        block.addstmts(
            cls.readSentinel(
                cls.msgvar,
                cls.itervar,
                cls.bulkSentinelKey(fields),
                errfnSentinel()(errmsg),
            )
        )

        return block

    @classmethod
    def checkedRead(
        cls,
        ipdltype,
        var,
        msgvar,
        itervar,
        errfn,
        paramtype,
        sentinelKey,
        errfnSentinel,
        actor,
    ):
        """Return a Block that reads into |var| via ReadIPDLParam and
        validates the result.

        If the read fails, or a non-nullable actor type reads back null
        (`!*var`), the statements from errfn(*paramtype) are emitted.
        |paramtype| is either a list of errfn arguments or a string, which
        becomes ["Error deserializing " + paramtype]. The sentinel for
        |sentinelKey| is then checked; its failure runs
        errfnSentinel(*paramtype). The |errfnSentinel| parameter shadows the
        module-level helper of the same name."""
        block = Block()

        # Read the data
        ifbad = StmtIf(
            ExprNot(
                ExprCall(ExprVar("ReadIPDLParam"), args=[msgvar, itervar, actor, var])
            )
        )
        if not isinstance(paramtype, list):
            paramtype = ["Error deserializing " + paramtype]
        ifbad.addifstmts(errfn(*paramtype))
        block.addstmt(ifbad)

        # Check if we got a null non-nullable actor
        if (
            ipdltype
            and ipdltype.isIPDL()
            and ipdltype.isActor()
            and not ipdltype.nullable
        ):
            ifnull = StmtIf(ExprNot(ExprDeref(var)))
            ifnull.addifstmts(errfn(*paramtype))
            block.addstmt(ifnull)

        block.addstmts(
            cls.readSentinel(msgvar, itervar, sentinelKey, errfnSentinel(*paramtype))
        )

        return block

    # Helper wrapper for checkedRead for use within _ParamTraits
    # (uses this class's aMsg/aIter/aActor names; on failure it calls
    # FatalError(msg) and returns false).
    @classmethod
    def _checkedRead(cls, ipdltype, var, sentinelKey, what):
        def errfn(msg):
            return [cls.fatalError(msg), StmtReturn.FALSE]

        return cls.checkedRead(
            ipdltype,
            var,
            cls.msgvar,
            cls.itervar,
            errfn=errfn,
            paramtype=what,
            sentinelKey=sentinelKey,
            errfnSentinel=errfnSentinel(),
            actor=cls.actor,
        )

    @classmethod
    def generateDecl(cls, fortype, write, read, constin=True):
        """Build a `struct IPDLParamTraits<fortype>` in namespace
        mozilla::ipc, whose static Write/Read bodies are the |write| and
        |read| statement lists.

        |constin| controls whether Write takes its value by const reference.
        Returns (class declaration, out-of-line method definitions), each
        wrapped in the mozilla::ipc namespaces."""
        # IPDLParamTraits impls are selected ignoring constness, and references.
        pt = Class(
            "IPDLParamTraits",
            specializes=Type(
                fortype.name, T=fortype.T, inner=fortype.inner, ptr=fortype.ptr
            ),
            struct=True,
        )

        # typedef T paramType;
        pt.addstmt(Typedef(fortype, "paramType"))

        iprotocoltype = Type("mozilla::ipc::IProtocol", ptr=True)

        # static void Write(IPC::Message*, mozilla::ipc::IProtocol*,
        #                   const paramType&);  (non-const & if !constin)
        intype = Type("paramType", ref=True, const=constin)
        writemthd = MethodDefn(
            MethodDecl(
                "Write",
                params=[
                    Decl(Type("IPC::Message", ptr=True), cls.msgvar.name),
                    Decl(iprotocoltype, cls.actor.name),
                    Decl(intype, cls.var.name),
                ],
                methodspec=MethodSpec.STATIC,
            )
        )
        writemthd.addstmts(write)
        pt.addstmt(writemthd)

        # static bool Read(const IPC::Message*, <iterator from _iterType>*,
        #                  mozilla::ipc::IProtocol*, paramType*);
        outtype = Type("paramType", ptr=True)
        readmthd = MethodDefn(
            MethodDecl(
                "Read",
                params=[
                    Decl(Type("IPC::Message", ptr=True, const=True), cls.msgvar.name),
                    Decl(_iterType(ptr=True), cls.itervar.name),
                    Decl(iprotocoltype, cls.actor.name),
                    Decl(outtype, cls.var.name),
                ],
                ret=Type.BOOL,
                methodspec=MethodSpec.STATIC,
            )
        )
        readmthd.addstmts(read)
        pt.addstmt(readmthd)

        # Split the class into declaration and definition
        clsdecl, methoddefns = _splitClassDeclDefn(pt)

        namespaces = [Namespace("mozilla"), Namespace("ipc")]
        clsns = _putInNamespaces(clsdecl, namespaces)
        defns = _putInNamespaces(methoddefns, namespaces)
        return clsns, defns

    @classmethod
    def actorPickling(cls, actortype, side):
        """Generates pickling for IPDL actors. This is a |nullable| deserializer.
        Write and read callers will perform nullability validation.

        Write sends the actor's Id(), or 0 for null. It calls FatalError for
        a freed actor (id 1) and release-asserts that the actor shares the
        sending actor's channel and can still send. Read resolves the id
        through aActor->ReadActor() and static_casts it to the side's actor
        class."""

        cxxtype = _cxxBareType(actortype, side, fq=True)

        write = StmtCode(
            """
            int32_t id;
            if (!${var}) {
                id = 0; // kNullActorId
            } else {
                id = ${var}->Id();
                if (id == 1) { // kFreedActorId
                    ${var}->FatalError("Actor has been |delete|d");
                }
                MOZ_RELEASE_ASSERT(
                    ${actor}->GetIPCChannel() == ${var}->GetIPCChannel(),
                    "Actor must be from the same channel as the"
                    " actor it's being sent over");
                MOZ_RELEASE_ASSERT(
                    ${var}->CanSend(),
                    "Actor must still be open when sending");
            }

            ${write};
            """,
            var=cls.var,
            actor=cls.actor,
            write=cls.write(ExprVar("id"), cls.msgvar, cls.actor),
        )

        # bool Read(..) impl
        read = StmtCode(
            """
            mozilla::Maybe<mozilla::ipc::IProtocol*> actor =
                ${actor}->ReadActor(${msgvar}, ${itervar}, true, ${actortype}, ${protocolid});
            if (actor.isNothing()) {
                return false;
            }

            *${var} = static_cast<${cxxtype}>(actor.value());
            return true;
            """,
            actor=cls.actor,
            msgvar=cls.msgvar,
            itervar=cls.itervar,
            actortype=ExprLiteral.String(actortype.name()),
            protocolid=_protocolId(actortype),
            var=cls.var,
            cxxtype=cxxtype,
        )

        return cls.generateDecl(cxxtype, [write], [read])

    @classmethod
    def structPickling(cls, structtype):
        """Generate IPDLParamTraits for an IPDL struct.

        Fields are walked in sd.fields_member_order() and grouped by
        pod_size. Fields of sentinel size (non-POD / unknown types) each get
        a checked write/read, wrapped in a side check when special. Each
        run of equal-size POD fields is serialized with one bulk write/read.
        itertools.groupby only merges *consecutive* equal keys; this
        presumably relies on fields_member_order() yielding fields sorted
        by size -- confirm."""
        sd = structtype._ast
        # NOTE: Not using _cxxBareType here as we don't have a side
        cxxtype = Type(structtype.fullname())

        # Accessor call on aVar; Write has a reference (".") and Read a
        # pointer ("->").
        def get(sel, f):
            return ExprCall(f.getMethod(thisexpr=cls.var, sel=sel))

        write = []
        read = []

        for (size, fields) in itertools.groupby(
            sd.fields_member_order(), lambda f: pod_size(f.ipdltype)
        ):
            fields = list(fields)

            if size == pod_size_sentinel:
                for f in fields:
                    writefield = cls.checkedWrite(
                        f.ipdltype,
                        get(".", f),
                        cls.msgvar,
                        sentinelKey=f.basename,
                        actor=cls.actor,
                    )
                    readfield = cls._checkedRead(
                        f.ipdltype,
                        ExprAddrOf(get("->", f)),
                        f.basename,
                        "'"
                        + f.getMethod().name
                        + "' "
                        + "("
                        + f.ipdltype.name()
                        + ") member of "
                        + "'"
                        + structtype.name()
                        + "'",
                    )

                    # Wrap the read/write in a side check if the field is special.
                    if f.special:
                        writefield = cls.ifsideis(f.side, writefield)
                        readfield = cls.ifsideis(f.side, readfield)

                    write.append(writefield)
                    read.append(readfield)
            else:
                # POD fields never contain actors, so none can be special.
                for f in fields:
                    assert not f.special

                writefield = cls.checkedBulkWrite(size, fields)
                readfield = cls.checkedBulkRead(size, fields)

                write.append(writefield)
                read.append(readfield)

        read.append(StmtReturn.TRUE)

        return cls.generateDecl(cxxtype, write, read)

    @classmethod
    def unionPickling(cls, uniontype):
        # NOTE: Not using _cxxBareType here as we don't have a side
        cxxtype = Type(uniontype.fullname())
        ud = uniontype._ast

        # Use typedef to set up an alias so it's easier to reference the struct type.
        alias = "union__"
        typevar = ExprVar("type")

        prelude = [
            Typedef(cxxtype, alias),
        ]

        writeswitch = StmtSwitch(typevar)
        write = prelude + [
            StmtDecl(Decl(Type.INT, typevar.name), init=ud.callType(cls.var)),
            cls.checkedWrite(
                None, typevar, cls.msgvar, sentinelKey=uniontype.name(), actor=cls.actor
            ),
            Whitespace.NL,
            writeswitch,
        ]

        readswitch = StmtSwitch(typevar)
        read = prelude + [
            StmtDecl(Decl(Type.INT, typevar.name), init=ExprLiteral.ZERO),
            cls._checkedRead(
                None,
                ExprAddrOf(typevar),
                uniontype.name(),
                "type of union " + uniontype.name(),
            ),
            Whitespace.NL,
            readswitch,
        ]

        for c in ud.components:
            ct = c.ipdltype
            caselabel = CaseLabel(alias + "::" + c.enum())
            origenum = c.enum()

            writecase = StmtBlock()
            wstmt = cls.checkedWrite(
                c.ipdltype,
                ExprCall(ExprSelect(cls.var, ".", c.getTypeName())),
                cls.msgvar,
                sentinelKey=c.enum(),
                actor=cls.actor,
            )
            if c.special:
                # Report an error if the type is special and the side is wrong
                wstmt = cls.ifsideis(c.side, wstmt, els=cls.fatalError("wrong side!"))
            writecase.addstmts([wstmt, StmtReturn()])
            writeswitch.addcase(caselabel, writecase)

            readcase = StmtBlock()
            if c.special:
                # The type comes across flipped from what the actor will be on
                # this side; i.e. child->parent messages will have PFooChild
                # when received on the parent side. Report an error if the sides
                # match, and handle c.other instead.
2397 readcase.addstmt( 2398 cls.ifsideis( 2399 c.side, 2400 StmtBlock([cls.fatalError("wrong side!"), StmtReturn.FALSE]), 2401 ) 2402 ) 2403 c = c.other 2404 tmpvar = ExprVar("tmp") 2405 ct = c.bareType(fq=True) 2406 readcase.addstmts( 2407 [ 2408 StmtDecl(Decl(ct, tmpvar.name), init=c.defaultValue(fq=True)), 2409 StmtExpr(ExprAssn(ExprDeref(cls.var), ExprMove(tmpvar))), 2410 cls._checkedRead( 2411 c.ipdltype, 2412 ExprAddrOf( 2413 ExprCall(ExprSelect(cls.var, "->", c.getTypeName())) 2414 ), 2415 origenum, 2416 "variant " + origenum + " of union " + uniontype.name(), 2417 ), 2418 StmtReturn.TRUE, 2419 ] 2420 ) 2421 readswitch.addcase(caselabel, readcase) 2422 2423 # Add the error default case 2424 writeswitch.addcase( 2425 DefaultLabel(), 2426 StmtBlock([cls.fatalError("unknown union type"), StmtReturn()]), 2427 ) 2428 readswitch.addcase( 2429 DefaultLabel(), 2430 StmtBlock([cls.fatalError("unknown union type"), StmtReturn.FALSE]), 2431 ) 2432 2433 return cls.generateDecl(cxxtype, write, read) 2434 2435 2436# -------------------------------------------------- 2437 2438 2439class _ComputeTypeDeps(TypeVisitor): 2440 """Pass that gathers the C++ types that a particular IPDL type 2441 (recursively) depends on. There are three kinds of dependencies: (i) 2442 types that need forward declaration; (ii) types that need a |using| 2443 stmt; (iii) IPDL structs or unions which must be fully declared 2444 before this struct. 
Some types generate multiple kinds.""" 2445 2446 def __init__(self, fortype, unqualifiedTypedefs=False): 2447 ipdl.type.TypeVisitor.__init__(self) 2448 self.usingTypedefs = [] 2449 self.forwardDeclStmts = [] 2450 self.fullDeclTypes = [] 2451 self.fortype = fortype 2452 self.unqualifiedTypedefs = unqualifiedTypedefs 2453 2454 def maybeTypedef(self, fqname, name, templateargs=[]): 2455 if fqname != name or self.unqualifiedTypedefs: 2456 self.usingTypedefs.append(Typedef(Type(fqname), name, templateargs)) 2457 2458 def visitImportedCxxType(self, t): 2459 if t in self.visited: 2460 return 2461 self.visited.add(t) 2462 self.maybeTypedef(t.fullname(), t.name()) 2463 2464 def visitActorType(self, t): 2465 if t in self.visited: 2466 return 2467 self.visited.add(t) 2468 2469 fqname, name = t.fullname(), t.name() 2470 2471 self.maybeTypedef(_actorName(fqname, "Parent"), _actorName(name, "Parent")) 2472 self.maybeTypedef(_actorName(fqname, "Child"), _actorName(name, "Child")) 2473 2474 self.forwardDeclStmts.extend( 2475 [ 2476 _makeForwardDeclForActor(t.protocol, "parent"), 2477 Whitespace.NL, 2478 _makeForwardDeclForActor(t.protocol, "child"), 2479 Whitespace.NL, 2480 ] 2481 ) 2482 2483 def visitStructOrUnionType(self, su, defaultVisit): 2484 if su in self.visited or su == self.fortype: 2485 return 2486 self.visited.add(su) 2487 self.maybeTypedef(su.fullname(), su.name()) 2488 2489 # Mutually recursive fields in unions are behind indirection, so we only 2490 # need a forward decl, and don't need a full type declaration. 
2491 if isinstance(self.fortype, UnionType) and self.fortype.mutuallyRecursiveWith( 2492 su 2493 ): 2494 self.forwardDeclStmts.append(_makeForwardDecl(su)) 2495 else: 2496 self.fullDeclTypes.append(su) 2497 2498 return defaultVisit(self, su) 2499 2500 def visitStructType(self, t): 2501 return self.visitStructOrUnionType(t, TypeVisitor.visitStructType) 2502 2503 def visitUnionType(self, t): 2504 return self.visitStructOrUnionType(t, TypeVisitor.visitUnionType) 2505 2506 def visitArrayType(self, t): 2507 return TypeVisitor.visitArrayType(self, t) 2508 2509 def visitMaybeType(self, m): 2510 return TypeVisitor.visitMaybeType(self, m) 2511 2512 def visitShmemType(self, s): 2513 if s in self.visited: 2514 return 2515 self.visited.add(s) 2516 self.maybeTypedef("mozilla::ipc::Shmem", "Shmem") 2517 2518 def visitByteBufType(self, s): 2519 if s in self.visited: 2520 return 2521 self.visited.add(s) 2522 self.maybeTypedef("mozilla::ipc::ByteBuf", "ByteBuf") 2523 2524 def visitFDType(self, s): 2525 if s in self.visited: 2526 return 2527 self.visited.add(s) 2528 self.maybeTypedef("mozilla::ipc::FileDescriptor", "FileDescriptor") 2529 2530 def visitEndpointType(self, s): 2531 if s in self.visited: 2532 return 2533 self.visited.add(s) 2534 self.maybeTypedef("mozilla::ipc::Endpoint", "Endpoint", ["FooSide"]) 2535 self.visitActorType(s.actor) 2536 2537 def visitManagedEndpointType(self, s): 2538 if s in self.visited: 2539 return 2540 self.visited.add(s) 2541 self.maybeTypedef( 2542 "mozilla::ipc::ManagedEndpoint", "ManagedEndpoint", ["FooSide"] 2543 ) 2544 self.visitActorType(s.actor) 2545 2546 def visitUniquePtrType(self, s): 2547 if s in self.visited: 2548 return 2549 self.visited.add(s) 2550 2551 def visitVoidType(self, v): 2552 assert 0 2553 2554 def visitMessageType(self, v): 2555 assert 0 2556 2557 def visitProtocolType(self, v): 2558 assert 0 2559 2560 2561def _fieldStaticAssertions(sd): 2562 staticasserts = [] 2563 for (size, fields) in itertools.groupby( 2564 
sd.fields_member_order(), lambda f: pod_size(f.ipdltype) 2565 ): 2566 if size == pod_size_sentinel: 2567 continue 2568 2569 fields = list(fields) 2570 if len(fields) == 1: 2571 continue 2572 2573 staticasserts.append( 2574 StmtCode( 2575 """ 2576 static_assert( 2577 (offsetof(${struct}, ${last}) - offsetof(${struct}, ${first})) == ${expected}, 2578 "Bad assumptions about field layout!"); 2579 """, 2580 struct=sd.name, 2581 first=fields[0].memberVar(), 2582 last=fields[-1].memberVar(), 2583 expected=ExprLiteral.Int(size * (len(fields) - 1)), 2584 ) 2585 ) 2586 2587 return staticasserts 2588 2589 2590def _generateCxxStruct(sd): 2591 """ """ 2592 # compute all the typedefs and forward decls we need to make 2593 gettypedeps = _ComputeTypeDeps(sd.decl.type) 2594 for f in sd.fields: 2595 f.ipdltype.accept(gettypedeps) 2596 2597 usingTypedefs = gettypedeps.usingTypedefs 2598 forwarddeclstmts = gettypedeps.forwardDeclStmts 2599 fulldecltypes = gettypedeps.fullDeclTypes 2600 2601 struct = Class(sd.name, final=True) 2602 struct.addstmts([Label.PRIVATE] + usingTypedefs + [Whitespace.NL, Label.PUBLIC]) 2603 2604 constreftype = Type(sd.name, const=True, ref=True) 2605 2606 def fieldsAsParamList(): 2607 # FIXME Bug 1547019 inType() should do the right thing once 2608 # _cxxTypeCanOnlyMove is replaced with 2609 # _cxxTypeNeedsMoveForSend 2610 return [ 2611 Decl( 2612 f.forceMoveType() if _cxxTypeCanOnlyMove(f.ipdltype) else f.inType(), 2613 f.argVar().name, 2614 ) 2615 for f in sd.fields_ipdl_order() 2616 ] 2617 2618 # If this is an empty struct (no fields), then the default ctor 2619 # and "create-with-fields" ctors are equivalent. So don't bother 2620 # with the default ctor. 2621 if len(sd.fields): 2622 assert len(sd.fields) == len(sd.packed_field_order) 2623 2624 # Struct() 2625 defctor = ConstructorDefn(ConstructorDecl(sd.name, force_inline=True)) 2626 2627 # We want to explicitly default-construct every member of the struct. 
2628 # This will initialize all primitives which wouldn't be initialized 2629 # normally to their default values, and will initialize any actor member 2630 # pointers to the correct default value of `nullptr`. Other C++ types 2631 # with custom constructors must also provide a default constructor. 2632 defctor.memberinits = [ 2633 ExprMemberInit(f.memberVar()) for f in sd.fields_member_order() 2634 ] 2635 struct.addstmts([defctor, Whitespace.NL]) 2636 2637 # Struct(const field1& _f1, ...) 2638 valctor = ConstructorDefn( 2639 ConstructorDecl(sd.name, params=fieldsAsParamList(), force_inline=True) 2640 ) 2641 valctor.memberinits = [] 2642 for f in sd.fields_member_order(): 2643 arg = f.argVar() 2644 if _cxxTypeCanOnlyMove(f.ipdltype): 2645 arg = ExprMove(arg) 2646 valctor.memberinits.append(ExprMemberInit(f.memberVar(), args=[arg])) 2647 2648 struct.addstmts([valctor, Whitespace.NL]) 2649 2650 # The default copy, move, and assignment constructors, and the default 2651 # destructor, will do the right thing. 
2652 2653 if "Comparable" in sd.attributes: 2654 # bool operator==(const Struct& _o) 2655 ovar = ExprVar("_o") 2656 opeqeq = MethodDefn( 2657 MethodDecl( 2658 "operator==", 2659 params=[Decl(constreftype, ovar.name)], 2660 ret=Type.BOOL, 2661 const=True, 2662 ) 2663 ) 2664 for f in sd.fields_ipdl_order(): 2665 ifneq = StmtIf( 2666 ExprNot( 2667 ExprBinary( 2668 ExprCall(f.getMethod()), "==", ExprCall(f.getMethod(ovar)) 2669 ) 2670 ) 2671 ) 2672 ifneq.addifstmt(StmtReturn.FALSE) 2673 opeqeq.addstmt(ifneq) 2674 opeqeq.addstmt(StmtReturn.TRUE) 2675 struct.addstmts([opeqeq, Whitespace.NL]) 2676 2677 # bool operator!=(const Struct& _o) 2678 opneq = MethodDefn( 2679 MethodDecl( 2680 "operator!=", 2681 params=[Decl(constreftype, ovar.name)], 2682 ret=Type.BOOL, 2683 const=True, 2684 ) 2685 ) 2686 opneq.addstmt(StmtReturn(ExprNot(ExprCall(ExprVar("operator=="), args=[ovar])))) 2687 struct.addstmts([opneq, Whitespace.NL]) 2688 2689 # field1& f1() 2690 # const field1& f1() const 2691 for f in sd.fields_ipdl_order(): 2692 get = MethodDefn( 2693 MethodDecl( 2694 f.getMethod().name, params=[], ret=f.refType(), force_inline=True 2695 ) 2696 ) 2697 get.addstmt(StmtReturn(f.refExpr())) 2698 2699 getconstdecl = deepcopy(get.decl) 2700 getconstdecl.ret = f.constRefType() 2701 getconstdecl.const = True 2702 getconst = MethodDefn(getconstdecl) 2703 getconst.addstmt(StmtReturn(f.constRefExpr())) 2704 2705 struct.addstmts([get, getconst, Whitespace.NL]) 2706 2707 # private: 2708 struct.addstmt(Label.PRIVATE) 2709 2710 # Static assertions to ensure our assumptions about field layout match 2711 # what the compiler is actually producing. We define this as a member 2712 # function, rather than throwing the assertions in the constructor or 2713 # similar, because we don't want to evaluate the static assertions every 2714 # time the header file containing the structure is included. 
2715 staticasserts = _fieldStaticAssertions(sd) 2716 if staticasserts: 2717 method = MethodDefn( 2718 MethodDecl("StaticAssertions", params=[], ret=Type.VOID, const=True) 2719 ) 2720 method.addstmts(staticasserts) 2721 struct.addstmts([method]) 2722 2723 # members 2724 struct.addstmts( 2725 [ 2726 StmtDecl(Decl(_effectiveMemberType(f), f.memberVar().name)) 2727 for f in sd.fields_member_order() 2728 ] 2729 ) 2730 2731 return forwarddeclstmts, fulldecltypes, struct 2732 2733 2734def _effectiveMemberType(f): 2735 effective_type = f.bareType() 2736 # Structs must be copyable for backwards compatibility reasons, so we use 2737 # CopyableTArray<T> as their member type for arrays. This is not exposed 2738 # in the method signatures, these keep using nsTArray<T>, which is a base 2739 # class of CopyableTArray<T>. 2740 if effective_type.name == "nsTArray": 2741 effective_type.name = "CopyableTArray" 2742 return effective_type 2743 2744 2745# -------------------------------------------------- 2746 2747 2748def _generateCxxUnion(ud): 2749 # This Union class basically consists of a type (enum) and a 2750 # union for storage. The union can contain POD and non-POD 2751 # types. Each type needs a copy/move ctor, assignment operators, 2752 # and dtor. 2753 # 2754 # Rather than templating this class and only providing 2755 # specializations for the types we support, which is slightly 2756 # "unsafe" in that C++ code can add additional specializations 2757 # without the IPDL compiler's knowledge, we instead explicitly 2758 # implement non-templated methods for each supported type. 2759 # 2760 # The one complication that arises is that C++, for arcane 2761 # reasons, does not allow the placement destructor of a 2762 # builtin type, like int, to be directly invoked. So we need 2763 # to hack around this by internally typedef'ing all 2764 # constituent types. Sigh. 
2765 # 2766 # So, for each type, this "Union" class needs: 2767 # (private) 2768 # - entry in the type enum 2769 # - entry in the storage union 2770 # - [type]ptr() method to get a type* from the underlying union 2771 # - same as above to get a const type* 2772 # - typedef to hack around placement delete limitations 2773 # (public) 2774 # - placement delete case for dtor 2775 # - copy ctor 2776 # - move ctor 2777 # - case in generic copy ctor 2778 # - copy operator= impl 2779 # - move operator= impl 2780 # - case in generic operator= 2781 # - operator [type&] 2782 # - operator [const type&] const 2783 # - [type&] get_[type]() 2784 # - [const type&] get_[type]() const 2785 # 2786 cls = Class(ud.name, final=True) 2787 # const Union&, i.e., Union type with inparam semantics 2788 inClsType = Type(ud.name, const=True, ref=True) 2789 refClsType = Type(ud.name, ref=True) 2790 rvalueRefClsType = Type(ud.name, rvalref=True) 2791 typetype = Type("Type") 2792 valuetype = Type("Value") 2793 mtypevar = ExprVar("mType") 2794 mvaluevar = ExprVar("mValue") 2795 maybedtorvar = ExprVar("MaybeDestroy") 2796 assertsanityvar = ExprVar("AssertSanity") 2797 tnonevar = ExprVar("T__None") 2798 tlastvar = ExprVar("T__Last") 2799 2800 def callAssertSanity(uvar=None, expectTypeVar=None): 2801 func = assertsanityvar 2802 args = [] 2803 if uvar is not None: 2804 func = ExprSelect(uvar, ".", assertsanityvar.name) 2805 if expectTypeVar is not None: 2806 args.append(expectTypeVar) 2807 return ExprCall(func, args=args) 2808 2809 def callMaybeDestroy(newTypeVar): 2810 return ExprCall(maybedtorvar, args=[newTypeVar]) 2811 2812 def maybeReconstruct(memb, newTypeVar): 2813 ifdied = StmtIf(callMaybeDestroy(newTypeVar)) 2814 ifdied.addifstmt(StmtExpr(memb.callCtor())) 2815 return ifdied 2816 2817 def voidCast(expr): 2818 return ExprCast(expr, Type.VOID, static=True) 2819 2820 # compute all the typedefs and forward decls we need to make 2821 gettypedeps = _ComputeTypeDeps(ud.decl.type) 2822 for c in 
ud.components: 2823 c.ipdltype.accept(gettypedeps) 2824 2825 usingTypedefs = gettypedeps.usingTypedefs 2826 forwarddeclstmts = gettypedeps.forwardDeclStmts 2827 fulldecltypes = gettypedeps.fullDeclTypes 2828 2829 # the |Type| enum, used to switch on the discunion's real type 2830 cls.addstmt(Label.PUBLIC) 2831 typeenum = TypeEnum(typetype.name) 2832 typeenum.addId(tnonevar.name, 0) 2833 firstid = ud.components[0].enum() 2834 typeenum.addId(firstid, 1) 2835 for c in ud.components[1:]: 2836 typeenum.addId(c.enum()) 2837 typeenum.addId(tlastvar.name, ud.components[-1].enum()) 2838 cls.addstmts([StmtDecl(Decl(typeenum, "")), Whitespace.NL]) 2839 2840 cls.addstmt(Label.PRIVATE) 2841 cls.addstmts( 2842 usingTypedefs 2843 # hacky typedef's that allow placement dtors of builtins 2844 + [Typedef(c.internalType(), c.typedef()) for c in ud.components] 2845 ) 2846 cls.addstmt(Whitespace.NL) 2847 2848 # the C++ union the discunion use for storage 2849 valueunion = TypeUnion(valuetype.name) 2850 for c in ud.components: 2851 valueunion.addComponent(c.unionType(), c.name) 2852 cls.addstmts([StmtDecl(Decl(valueunion, "")), Whitespace.NL]) 2853 2854 # for each constituent type T, add private accessors that 2855 # return a pointer to the Value union storage casted to |T*| 2856 # and |const T*| 2857 for c in ud.components: 2858 getptr = MethodDefn( 2859 MethodDecl( 2860 c.getPtrName(), params=[], ret=c.ptrToInternalType(), force_inline=True 2861 ) 2862 ) 2863 getptr.addstmt(StmtReturn(c.ptrToSelfExpr())) 2864 2865 getptrconst = MethodDefn( 2866 MethodDecl( 2867 c.getConstPtrName(), 2868 params=[], 2869 ret=c.constPtrToType(), 2870 const=True, 2871 force_inline=True, 2872 ) 2873 ) 2874 getptrconst.addstmt(StmtReturn(c.constptrToSelfExpr())) 2875 2876 cls.addstmts([getptr, getptrconst]) 2877 cls.addstmt(Whitespace.NL) 2878 2879 # add a helper method that invokes the placement dtor on the 2880 # current underlying value, only if |aNewType| is different 2881 # than the current type, and 
returns true if the underlying 2882 # value needs to be re-constructed 2883 newtypevar = ExprVar("aNewType") 2884 maybedtor = MethodDefn( 2885 MethodDecl( 2886 maybedtorvar.name, params=[Decl(typetype, newtypevar.name)], ret=Type.BOOL 2887 ) 2888 ) 2889 # wasn't /actually/ dtor'd, but it needs to be re-constructed 2890 ifnone = StmtIf(ExprBinary(mtypevar, "==", tnonevar)) 2891 ifnone.addifstmt(StmtReturn.TRUE) 2892 # same type, nothing to see here 2893 ifnochange = StmtIf(ExprBinary(mtypevar, "==", newtypevar)) 2894 ifnochange.addifstmt(StmtReturn.FALSE) 2895 # need to destroy. switch on underlying type 2896 dtorswitch = StmtSwitch(mtypevar) 2897 for c in ud.components: 2898 dtorswitch.addcase( 2899 CaseLabel(c.enum()), StmtBlock([StmtExpr(c.callDtor()), StmtBreak()]) 2900 ) 2901 dtorswitch.addcase( 2902 DefaultLabel(), StmtBlock([_logicError("not reached"), StmtBreak()]) 2903 ) 2904 maybedtor.addstmts([ifnone, ifnochange, dtorswitch, StmtReturn.TRUE]) 2905 cls.addstmts([maybedtor, Whitespace.NL]) 2906 2907 # add helper methods that ensure the discunion has a 2908 # valid type 2909 sanity = MethodDefn( 2910 MethodDecl(assertsanityvar.name, ret=Type.VOID, const=True, force_inline=True) 2911 ) 2912 sanity.addstmts( 2913 [ 2914 _abortIfFalse(ExprBinary(tnonevar, "<=", mtypevar), "invalid type tag"), 2915 _abortIfFalse(ExprBinary(mtypevar, "<=", tlastvar), "invalid type tag"), 2916 ] 2917 ) 2918 cls.addstmt(sanity) 2919 2920 atypevar = ExprVar("aType") 2921 sanity2 = MethodDefn( 2922 MethodDecl( 2923 assertsanityvar.name, 2924 params=[Decl(typetype, atypevar.name)], 2925 ret=Type.VOID, 2926 const=True, 2927 force_inline=True, 2928 ) 2929 ) 2930 sanity2.addstmts( 2931 [ 2932 StmtExpr(ExprCall(assertsanityvar)), 2933 _abortIfFalse(ExprBinary(mtypevar, "==", atypevar), "unexpected type tag"), 2934 ] 2935 ) 2936 cls.addstmts([sanity2, Whitespace.NL]) 2937 2938 # ---- begin public methods ----- 2939 2940 # Union() default ctor 2941 cls.addstmts( 2942 [ 2943 Label.PUBLIC, 
2944 ConstructorDefn( 2945 ConstructorDecl(ud.name, force_inline=True), 2946 memberinits=[ExprMemberInit(mtypevar, [tnonevar])], 2947 ), 2948 Whitespace.NL, 2949 ] 2950 ) 2951 2952 # Union(const T&) copy & Union(T&&) move ctors 2953 othervar = ExprVar("aOther") 2954 for c in ud.components: 2955 if not _cxxTypeCanOnlyMove(c.ipdltype): 2956 copyctor = ConstructorDefn( 2957 ConstructorDecl(ud.name, params=[Decl(c.inType(), othervar.name)]) 2958 ) 2959 copyctor.addstmts( 2960 [ 2961 StmtExpr(c.callCtor(othervar)), 2962 StmtExpr(ExprAssn(mtypevar, c.enumvar())), 2963 ] 2964 ) 2965 cls.addstmts([copyctor, Whitespace.NL]) 2966 2967 if not _cxxTypeCanMove(c.ipdltype) or _cxxTypeNeedsMoveForSend(c.ipdltype): 2968 continue 2969 movector = ConstructorDefn( 2970 ConstructorDecl(ud.name, params=[Decl(c.forceMoveType(), othervar.name)]) 2971 ) 2972 movector.addstmts( 2973 [ 2974 StmtExpr(c.callCtor(ExprMove(othervar))), 2975 StmtExpr(ExprAssn(mtypevar, c.enumvar())), 2976 ] 2977 ) 2978 cls.addstmts([movector, Whitespace.NL]) 2979 2980 unionNeedsMove = any(_cxxTypeCanOnlyMove(c.ipdltype) for c in ud.components) 2981 2982 # Union(const Union&) copy ctor 2983 if not unionNeedsMove: 2984 copyctor = ConstructorDefn( 2985 ConstructorDecl(ud.name, params=[Decl(inClsType, othervar.name)]) 2986 ) 2987 othertype = ud.callType(othervar) 2988 copyswitch = StmtSwitch(othertype) 2989 for c in ud.components: 2990 copyswitch.addcase( 2991 CaseLabel(c.enum()), 2992 StmtBlock( 2993 [ 2994 StmtExpr( 2995 c.callCtor( 2996 ExprCall( 2997 ExprSelect(othervar, ".", c.getConstTypeName()) 2998 ) 2999 ) 3000 ), 3001 StmtBreak(), 3002 ] 3003 ), 3004 ) 3005 copyswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()])) 3006 copyswitch.addcase( 3007 DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()]) 3008 ) 3009 copyctor.addstmts( 3010 [ 3011 StmtExpr(callAssertSanity(uvar=othervar)), 3012 copyswitch, 3013 StmtExpr(ExprAssn(mtypevar, othertype)), 3014 ] 3015 ) 3016 
cls.addstmts([copyctor, Whitespace.NL]) 3017 3018 # Union(Union&&) move ctor 3019 movector = ConstructorDefn( 3020 ConstructorDecl(ud.name, params=[Decl(rvalueRefClsType, othervar.name)]) 3021 ) 3022 othertypevar = ExprVar("t") 3023 moveswitch = StmtSwitch(othertypevar) 3024 for c in ud.components: 3025 case = StmtBlock() 3026 if c.recursive: 3027 # This is sound as we set othervar.mTypeVar to T__None after the 3028 # switch. The pointer in the union will be left dangling. 3029 case.addstmts( 3030 [ 3031 # ptr_C() = other.ptr_C() 3032 StmtExpr( 3033 ExprAssn( 3034 c.callGetPtr(), 3035 ExprCall( 3036 ExprSelect(othervar, ".", ExprVar(c.getPtrName())) 3037 ), 3038 ) 3039 ) 3040 ] 3041 ) 3042 else: 3043 case.addstmts( 3044 [ 3045 # new ... (Move(other.get_C())) 3046 StmtExpr( 3047 c.callCtor( 3048 ExprMove( 3049 ExprCall(ExprSelect(othervar, ".", c.getTypeName())) 3050 ) 3051 ) 3052 ), 3053 # other.MaybeDestroy(T__None) 3054 StmtExpr( 3055 voidCast( 3056 ExprCall( 3057 ExprSelect(othervar, ".", maybedtorvar), args=[tnonevar] 3058 ) 3059 ) 3060 ), 3061 ] 3062 ) 3063 case.addstmts([StmtBreak()]) 3064 moveswitch.addcase(CaseLabel(c.enum()), case) 3065 moveswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()])) 3066 moveswitch.addcase( 3067 DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()]) 3068 ) 3069 movector.addstmts( 3070 [ 3071 StmtExpr(callAssertSanity(uvar=othervar)), 3072 StmtDecl(Decl(typetype, othertypevar.name), init=ud.callType(othervar)), 3073 moveswitch, 3074 StmtExpr(ExprAssn(ExprSelect(othervar, ".", mtypevar), tnonevar)), 3075 StmtExpr(ExprAssn(mtypevar, othertypevar)), 3076 ] 3077 ) 3078 cls.addstmts([movector, Whitespace.NL]) 3079 3080 # ~Union() 3081 dtor = DestructorDefn(DestructorDecl(ud.name)) 3082 # The void cast prevents Coverity from complaining about missing return 3083 # value checks. 
3084 dtor.addstmt(StmtExpr(voidCast(callMaybeDestroy(tnonevar)))) 3085 cls.addstmts([dtor, Whitespace.NL]) 3086 3087 # type() 3088 typemeth = MethodDefn( 3089 MethodDecl("type", ret=typetype, const=True, force_inline=True) 3090 ) 3091 typemeth.addstmt(StmtReturn(mtypevar)) 3092 cls.addstmts([typemeth, Whitespace.NL]) 3093 3094 # Union& operator= methods 3095 rhsvar = ExprVar("aRhs") 3096 for c in ud.components: 3097 if not _cxxTypeCanOnlyMove(c.ipdltype): 3098 # Union& operator=(const T&) 3099 opeq = MethodDefn( 3100 MethodDecl( 3101 "operator=", params=[Decl(c.inType(), rhsvar.name)], ret=refClsType 3102 ) 3103 ) 3104 opeq.addstmts( 3105 [ 3106 # might need to placement-delete old value first 3107 maybeReconstruct(c, c.enumvar()), 3108 StmtExpr(c.callOperatorEq(rhsvar)), 3109 StmtExpr(ExprAssn(mtypevar, c.enumvar())), 3110 StmtReturn(ExprDeref(ExprVar.THIS)), 3111 ] 3112 ) 3113 cls.addstmts([opeq, Whitespace.NL]) 3114 3115 # Union& operator=(T&&) 3116 if not _cxxTypeCanMove(c.ipdltype) or _cxxTypeNeedsMoveForSend(c.ipdltype): 3117 continue 3118 3119 opeq = MethodDefn( 3120 MethodDecl( 3121 "operator=", 3122 params=[Decl(c.forceMoveType(), rhsvar.name)], 3123 ret=refClsType, 3124 ) 3125 ) 3126 opeq.addstmts( 3127 [ 3128 # might need to placement-delete old value first 3129 maybeReconstruct(c, c.enumvar()), 3130 StmtExpr(c.callOperatorEq(ExprMove(rhsvar))), 3131 StmtExpr(ExprAssn(mtypevar, c.enumvar())), 3132 StmtReturn(ExprDeref(ExprVar.THIS)), 3133 ] 3134 ) 3135 cls.addstmts([opeq, Whitespace.NL]) 3136 3137 # Union& operator=(const Union&) 3138 if not unionNeedsMove: 3139 opeq = MethodDefn( 3140 MethodDecl( 3141 "operator=", params=[Decl(inClsType, rhsvar.name)], ret=refClsType 3142 ) 3143 ) 3144 rhstypevar = ExprVar("t") 3145 opeqswitch = StmtSwitch(rhstypevar) 3146 for c in ud.components: 3147 case = StmtBlock() 3148 case.addstmts( 3149 [ 3150 maybeReconstruct(c, rhstypevar), 3151 StmtExpr( 3152 c.callOperatorEq( 3153 ExprCall(ExprSelect(rhsvar, ".", 
c.getConstTypeName())) 3154 ) 3155 ), 3156 StmtBreak(), 3157 ] 3158 ) 3159 opeqswitch.addcase(CaseLabel(c.enum()), case) 3160 opeqswitch.addcase( 3161 CaseLabel(tnonevar.name), 3162 # The void cast prevents Coverity from complaining about missing return 3163 # value checks. 3164 StmtBlock( 3165 [ 3166 StmtExpr( 3167 ExprCast(callMaybeDestroy(rhstypevar), Type.VOID, static=True) 3168 ), 3169 StmtBreak(), 3170 ] 3171 ), 3172 ) 3173 opeqswitch.addcase( 3174 DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()]) 3175 ) 3176 opeq.addstmts( 3177 [ 3178 StmtExpr(callAssertSanity(uvar=rhsvar)), 3179 StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)), 3180 opeqswitch, 3181 StmtExpr(ExprAssn(mtypevar, rhstypevar)), 3182 StmtReturn(ExprDeref(ExprVar.THIS)), 3183 ] 3184 ) 3185 cls.addstmts([opeq, Whitespace.NL]) 3186 3187 # Union& operator=(Union&&) 3188 opeq = MethodDefn( 3189 MethodDecl( 3190 "operator=", params=[Decl(rvalueRefClsType, rhsvar.name)], ret=refClsType 3191 ) 3192 ) 3193 rhstypevar = ExprVar("t") 3194 opeqswitch = StmtSwitch(rhstypevar) 3195 for c in ud.components: 3196 case = StmtBlock() 3197 if c.recursive: 3198 case.addstmts( 3199 [ 3200 StmtExpr(voidCast(callMaybeDestroy(tnonevar))), 3201 StmtExpr( 3202 ExprAssn( 3203 c.callGetPtr(), 3204 ExprCall(ExprSelect(rhsvar, ".", ExprVar(c.getPtrName()))), 3205 ) 3206 ), 3207 ] 3208 ) 3209 else: 3210 case.addstmts( 3211 [ 3212 maybeReconstruct(c, rhstypevar), 3213 StmtExpr( 3214 c.callOperatorEq( 3215 ExprMove(ExprCall(ExprSelect(rhsvar, ".", c.getTypeName()))) 3216 ) 3217 ), 3218 # other.MaybeDestroy(T__None) 3219 StmtExpr( 3220 voidCast( 3221 ExprCall( 3222 ExprSelect(rhsvar, ".", maybedtorvar), args=[tnonevar] 3223 ) 3224 ) 3225 ), 3226 ] 3227 ) 3228 case.addstmts([StmtBreak()]) 3229 opeqswitch.addcase(CaseLabel(c.enum()), case) 3230 opeqswitch.addcase( 3231 CaseLabel(tnonevar.name), 3232 # The void cast prevents Coverity from complaining about missing return 3233 # value checks. 
3234 StmtBlock([StmtExpr(voidCast(callMaybeDestroy(rhstypevar))), StmtBreak()]), 3235 ) 3236 opeqswitch.addcase( 3237 DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()]) 3238 ) 3239 opeq.addstmts( 3240 [ 3241 StmtExpr(callAssertSanity(uvar=rhsvar)), 3242 StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)), 3243 opeqswitch, 3244 StmtExpr(ExprAssn(ExprSelect(rhsvar, ".", mtypevar), tnonevar)), 3245 StmtExpr(ExprAssn(mtypevar, rhstypevar)), 3246 StmtReturn(ExprDeref(ExprVar.THIS)), 3247 ] 3248 ) 3249 cls.addstmts([opeq, Whitespace.NL]) 3250 3251 if "Comparable" in ud.attributes: 3252 # bool operator==(const T&) 3253 for c in ud.components: 3254 opeqeq = MethodDefn( 3255 MethodDecl( 3256 "operator==", 3257 params=[Decl(c.inType(), rhsvar.name)], 3258 ret=Type.BOOL, 3259 const=True, 3260 ) 3261 ) 3262 opeqeq.addstmt( 3263 StmtReturn(ExprBinary(ExprCall(ExprVar(c.getTypeName())), "==", rhsvar)) 3264 ) 3265 cls.addstmts([opeqeq, Whitespace.NL]) 3266 3267 # bool operator==(const Union&) 3268 opeqeq = MethodDefn( 3269 MethodDecl( 3270 "operator==", 3271 params=[Decl(inClsType, rhsvar.name)], 3272 ret=Type.BOOL, 3273 const=True, 3274 ) 3275 ) 3276 iftypesmismatch = StmtIf(ExprBinary(ud.callType(), "!=", ud.callType(rhsvar))) 3277 iftypesmismatch.addifstmt(StmtReturn.FALSE) 3278 opeqeq.addstmts([iftypesmismatch, Whitespace.NL]) 3279 3280 opeqeqswitch = StmtSwitch(ud.callType()) 3281 for c in ud.components: 3282 case = StmtBlock() 3283 case.addstmt( 3284 StmtReturn( 3285 ExprBinary( 3286 ExprCall(ExprVar(c.getTypeName())), 3287 "==", 3288 ExprCall(ExprSelect(rhsvar, ".", c.getTypeName())), 3289 ) 3290 ) 3291 ) 3292 opeqeqswitch.addcase(CaseLabel(c.enum()), case) 3293 opeqeqswitch.addcase( 3294 DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn.FALSE]) 3295 ) 3296 opeqeq.addstmt(opeqeqswitch) 3297 3298 cls.addstmts([opeqeq, Whitespace.NL]) 3299 3300 # accessors for each type: operator T&, operator const T&, 3301 # T& get(), const T& 
get() 3302 for c in ud.components: 3303 getValueVar = ExprVar(c.getTypeName()) 3304 getConstValueVar = ExprVar(c.getConstTypeName()) 3305 3306 getvalue = MethodDefn( 3307 MethodDecl(getValueVar.name, ret=c.refType(), force_inline=True) 3308 ) 3309 getvalue.addstmts( 3310 [ 3311 StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())), 3312 StmtReturn(ExprDeref(c.callGetPtr())), 3313 ] 3314 ) 3315 3316 getconstvalue = MethodDefn( 3317 MethodDecl( 3318 getConstValueVar.name, 3319 ret=c.constRefType(), 3320 const=True, 3321 force_inline=True, 3322 ) 3323 ) 3324 getconstvalue.addstmts( 3325 [ 3326 StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())), 3327 StmtReturn(c.getConstValue()), 3328 ] 3329 ) 3330 3331 cls.addstmts([getvalue, getconstvalue]) 3332 3333 optype = MethodDefn(MethodDecl("", typeop=c.refType(), force_inline=True)) 3334 optype.addstmt(StmtReturn(ExprCall(getValueVar))) 3335 opconsttype = MethodDefn( 3336 MethodDecl("", const=True, typeop=c.constRefType(), force_inline=True) 3337 ) 3338 opconsttype.addstmt(StmtReturn(ExprCall(getConstValueVar))) 3339 3340 cls.addstmts([optype, opconsttype, Whitespace.NL]) 3341 # private vars 3342 cls.addstmts( 3343 [ 3344 Label.PRIVATE, 3345 StmtDecl(Decl(valuetype, mvaluevar.name)), 3346 StmtDecl(Decl(typetype, mtypevar.name)), 3347 ] 3348 ) 3349 3350 return forwarddeclstmts, fulldecltypes, cls 3351 3352 3353# ----------------------------------------------------------------------------- 3354 3355 3356class _FindFriends(ipdl.ast.Visitor): 3357 def __init__(self): 3358 self.mytype = None # ProtocolType 3359 self.vtype = None # ProtocolType 3360 self.friends = set() # set<ProtocolType> 3361 3362 def findFriends(self, ptype): 3363 self.mytype = ptype 3364 for toplvl in ptype.toplevels(): 3365 self.walkDownTheProtocolTree(toplvl) 3366 return self.friends 3367 3368 # TODO could make this into a _iterProtocolTreeHelper ... 
3369 def walkDownTheProtocolTree(self, ptype): 3370 if ptype != self.mytype: 3371 # don't want to |friend| ourself! 3372 self.visit(ptype) 3373 for mtype in ptype.manages: 3374 if mtype is not ptype: 3375 self.walkDownTheProtocolTree(mtype) 3376 3377 def visit(self, ptype): 3378 # |vtype| is the type currently being visited 3379 savedptype = self.vtype 3380 self.vtype = ptype 3381 ptype._ast.accept(self) 3382 self.vtype = savedptype 3383 3384 def visitMessageDecl(self, md): 3385 for it in self.iterActorParams(md): 3386 if it.protocol == self.mytype: 3387 self.friends.add(self.vtype) 3388 3389 def iterActorParams(self, md): 3390 for param in md.inParams: 3391 for actor in ipdl.type.iteractortypes(param.type): 3392 yield actor 3393 for ret in md.outParams: 3394 for actor in ipdl.type.iteractortypes(ret.type): 3395 yield actor 3396 3397 3398class _GenerateProtocolActorCode(ipdl.ast.Visitor): 3399 def __init__(self, myside): 3400 self.side = myside # "parent" or "child" 3401 self.prettyside = myside.title() 3402 self.clsname = None 3403 self.protocol = None 3404 self.hdrfile = None 3405 self.cppfile = None 3406 self.ns = None 3407 self.cls = None 3408 self.includedActorTypedefs = [] 3409 self.protocolCxxIncludes = [] 3410 self.actorForwardDecls = [] 3411 self.usingDecls = [] 3412 self.externalIncludes = set() 3413 self.nonForwardDeclaredHeaders = set() 3414 3415 def lower(self, tu, clsname, cxxHeaderFile, cxxFile): 3416 self.clsname = clsname 3417 self.hdrfile = cxxHeaderFile 3418 self.cppfile = cxxFile 3419 tu.accept(self) 3420 3421 def standardTypedefs(self): 3422 return [ 3423 Typedef(Type("mozilla::ipc::IProtocol"), "IProtocol"), 3424 Typedef(Type("IPC::Message"), "Message"), 3425 Typedef(Type("base::ProcessHandle"), "ProcessHandle"), 3426 Typedef(Type("mozilla::ipc::MessageChannel"), "MessageChannel"), 3427 Typedef(Type("mozilla::ipc::SharedMemory"), "SharedMemory"), 3428 ] 3429 3430 def visitTranslationUnit(self, tu): 3431 self.protocol = tu.protocol 3432 3433 hf 
= self.hdrfile 3434 cf = self.cppfile 3435 3436 # make the C++ header 3437 hf.addthings( 3438 [_DISCLAIMER] 3439 + _includeGuardStart(hf) 3440 + [ 3441 Whitespace.NL, 3442 CppDirective("include", '"' + _protocolHeaderName(tu.protocol) + '.h"'), 3443 ] 3444 ) 3445 3446 for inc in tu.includes: 3447 inc.accept(self) 3448 for inc in tu.cxxIncludes: 3449 inc.accept(self) 3450 3451 for using in tu.using: 3452 using.accept(self) 3453 3454 # this generates the actor's full impl in self.cls 3455 tu.protocol.accept(self) 3456 3457 clsdecl, clsdefn = _splitClassDeclDefn(self.cls) 3458 3459 # XXX damn C++ ... return types in the method defn aren't in 3460 # class scope 3461 for stmt in clsdefn.stmts: 3462 if isinstance(stmt, MethodDefn): 3463 if stmt.decl.ret and stmt.decl.ret.name == "Result": 3464 stmt.decl.ret.name = clsdecl.name + "::" + stmt.decl.ret.name 3465 3466 def setToIncludes(s): 3467 return [CppDirective("include", '"%s"' % i) for i in sorted(iter(s))] 3468 3469 def makeNamespace(p, file): 3470 if 0 == len(p.namespaces): 3471 return file 3472 ns = Namespace(p.namespaces[-1].name) 3473 outerns = _putInNamespaces(ns, p.namespaces[:-1]) 3474 file.addthing(outerns) 3475 return ns 3476 3477 if len(self.nonForwardDeclaredHeaders) != 0: 3478 self.hdrfile.addthings( 3479 [ 3480 Whitespace("// Headers for things that cannot be forward declared"), 3481 Whitespace.NL, 3482 ] 3483 + setToIncludes(self.nonForwardDeclaredHeaders) 3484 + [Whitespace.NL] 3485 ) 3486 self.hdrfile.addthings(self.actorForwardDecls) 3487 self.hdrfile.addthings(self.usingDecls) 3488 3489 hdrns = makeNamespace(self.protocol, self.hdrfile) 3490 hdrns.addstmts( 3491 [Whitespace.NL, Whitespace.NL, clsdecl, Whitespace.NL, Whitespace.NL] 3492 ) 3493 3494 actortype = ActorType(tu.protocol.decl.type) 3495 traitsdecl, traitsdefn = _ParamTraits.actorPickling(actortype, self.side) 3496 3497 self.hdrfile.addthings([traitsdecl, Whitespace.NL] + _includeGuardEnd(hf)) 3498 3499 # make the .cpp file 3500 if 
(self.protocol.name, self.side) not in VIRTUAL_CALL_CLASSES: 3501 if (self.protocol.name, self.side) in DIRECT_CALL_OVERRIDES: 3502 (_, header_file) = DIRECT_CALL_OVERRIDES[self.protocol.name, self.side] 3503 else: 3504 assert self.protocol.name.startswith("P") 3505 header_file = "{}/{}{}.h".format( 3506 "/".join(n.name for n in self.protocol.namespaces), 3507 self.protocol.name[1:], 3508 self.side.capitalize(), 3509 ) 3510 self.externalIncludes.add(header_file) 3511 3512 cf.addthings( 3513 [ 3514 _DISCLAIMER, 3515 Whitespace.NL, 3516 CppDirective( 3517 "include", 3518 '"' + _protocolHeaderName(self.protocol, self.side) + '.h"', 3519 ), 3520 ] 3521 + setToIncludes(self.externalIncludes) 3522 ) 3523 3524 cf.addthings( 3525 ( 3526 [Whitespace.NL] 3527 + [ 3528 CppDirective("include", '"%s.h"' % (inc)) 3529 for inc in self.protocolCxxIncludes 3530 ] 3531 + [Whitespace.NL] 3532 + [ 3533 CppDirective("include", '"%s"' % filename) 3534 for filename in ipdl.builtin.CppIncludes 3535 ] 3536 + [Whitespace.NL] 3537 ) 3538 ) 3539 3540 cppns = makeNamespace(self.protocol, cf) 3541 cppns.addstmts( 3542 [Whitespace.NL, Whitespace.NL, clsdefn, Whitespace.NL, Whitespace.NL] 3543 ) 3544 3545 cf.addthing(traitsdefn) 3546 3547 def visitUsingStmt(self, using): 3548 if using.header is None: 3549 return 3550 3551 if using.canBeForwardDeclared() and not using.decl.type.isUniquePtr(): 3552 spec = using.type.spec 3553 3554 self.usingDecls.extend( 3555 [ 3556 _makeForwardDeclForQClass( 3557 spec.baseid, 3558 spec.quals, 3559 cls=using.isClass(), 3560 struct=using.isStruct(), 3561 ), 3562 Whitespace.NL, 3563 ] 3564 ) 3565 self.externalIncludes.add(using.header) 3566 else: 3567 self.nonForwardDeclaredHeaders.add(using.header) 3568 3569 def visitCxxInclude(self, inc): 3570 self.externalIncludes.add(inc.file) 3571 3572 def visitInclude(self, inc): 3573 ip = inc.tu.protocol 3574 if not ip: 3575 return 3576 3577 self.actorForwardDecls.extend( 3578 [ 3579 _makeForwardDeclForActor(ip.decl.type, 
self.side), 3580 _makeForwardDeclForActor(ip.decl.type, _otherSide(self.side)), 3581 Whitespace.NL, 3582 ] 3583 ) 3584 self.protocolCxxIncludes.append(_protocolHeaderName(ip, self.side)) 3585 3586 if ip.decl.fullname is not None: 3587 self.includedActorTypedefs.append( 3588 Typedef( 3589 Type(_actorName(ip.decl.fullname, self.side.title())), 3590 _actorName(ip.decl.shortname, self.side.title()), 3591 ) 3592 ) 3593 3594 self.includedActorTypedefs.append( 3595 Typedef( 3596 Type(_actorName(ip.decl.fullname, _otherSide(self.side).title())), 3597 _actorName(ip.decl.shortname, _otherSide(self.side).title()), 3598 ) 3599 ) 3600 3601 def visitProtocol(self, p): 3602 self.hdrfile.addcode( 3603 """ 3604 #ifdef DEBUG 3605 #include "prenv.h" 3606 #endif // DEBUG 3607 3608 #include "mozilla/Tainting.h" 3609 #include "mozilla/ipc/MessageChannel.h" 3610 #include "mozilla/ipc/ProtocolUtils.h" 3611 """ 3612 ) 3613 3614 self.protocol = p 3615 ptype = p.decl.type 3616 toplevel = p.decl.type.toplevel() 3617 3618 hasAsyncReturns = False 3619 for md in p.messageDecls: 3620 if md.hasAsyncReturns(): 3621 hasAsyncReturns = True 3622 break 3623 3624 inherits = [] 3625 if ptype.isToplevel(): 3626 inherits.append(Inherit(p.openedProtocolInterfaceType(), viz="public")) 3627 else: 3628 inherits.append(Inherit(p.managerInterfaceType(), viz="public")) 3629 3630 if ptype.isToplevel() and self.side == "parent": 3631 self.hdrfile.addthings( 3632 [_makeForwardDeclForQClass("nsIFile", []), Whitespace.NL] 3633 ) 3634 3635 self.cls = Class(self.clsname, inherits=inherits, abstract=True) 3636 3637 self.cls.addstmt(Label.PRIVATE) 3638 friends = _FindFriends().findFriends(ptype) 3639 if ptype.isManaged(): 3640 friends.update(ptype.managers) 3641 3642 # |friend| managed actors so that they can call our Dealloc*() 3643 friends.update(ptype.manages) 3644 3645 # don't friend ourself if we're a self-managed protocol 3646 friends.discard(ptype) 3647 3648 for friend in sorted(friends, key=lambda f: 
f.fullname()): 3649 self.actorForwardDecls.extend( 3650 [_makeForwardDeclForActor(friend, self.prettyside), Whitespace.NL] 3651 ) 3652 self.cls.addstmt( 3653 FriendClassDecl(_actorName(friend.fullname(), self.prettyside)) 3654 ) 3655 3656 self.cls.addstmt(Label.PROTECTED) 3657 for typedef in p.cxxTypedefs(): 3658 self.cls.addstmt(typedef) 3659 for typedef in self.includedActorTypedefs: 3660 self.cls.addstmt(typedef) 3661 3662 self.cls.addstmt(Whitespace.NL) 3663 3664 if hasAsyncReturns: 3665 self.cls.addstmt(Label.PUBLIC) 3666 for md in p.messageDecls: 3667 if self.sendsMessage(md) and md.hasAsyncReturns(): 3668 self.cls.addstmt( 3669 Typedef(_makePromise(md.returns, self.side), md.promiseName()) 3670 ) 3671 if self.receivesMessage(md) and md.hasAsyncReturns(): 3672 self.cls.addstmt( 3673 Typedef(_makeResolver(md.returns, self.side), md.resolverName()) 3674 ) 3675 self.cls.addstmt(Whitespace.NL) 3676 3677 self.cls.addstmt(Label.PROTECTED) 3678 # interface methods that the concrete subclass has to impl 3679 for md in p.messageDecls: 3680 isctor, isdtor = md.decl.type.isCtor(), md.decl.type.isDtor() 3681 3682 if self.receivesMessage(md): 3683 # generate Recv/Answer* interface 3684 implicit = not isdtor 3685 returnsems = "resolver" if md.decl.type.isAsync() else "out" 3686 recvDecl = MethodDecl( 3687 md.recvMethod(), 3688 params=md.makeCxxParams( 3689 paramsems="move", 3690 returnsems=returnsems, 3691 side=self.side, 3692 implicit=implicit, 3693 direction="recv", 3694 ), 3695 ret=Type("mozilla::ipc::IPCResult"), 3696 methodspec=MethodSpec.VIRTUAL, 3697 ) 3698 3699 # These method implementations cause problems when trying to 3700 # override them with different types in a direct call class. 3701 # 3702 # For the `isdtor` case there's a simple solution: it doesn't 3703 # make much sense to specify arguments and then completely 3704 # ignore them, and the no-arg case isn't a problem for 3705 # overriding. 
3706 if isctor or (isdtor and not md.inParams): 3707 defaultRecv = MethodDefn(recvDecl) 3708 defaultRecv.addcode("return IPC_OK();\n") 3709 self.cls.addstmt(defaultRecv) 3710 elif (self.protocol.name, self.side) in VIRTUAL_CALL_CLASSES: 3711 # If we're using virtual calls, we need the methods to be 3712 # declared on the base class. 3713 recvDecl.methodspec = MethodSpec.PURE 3714 self.cls.addstmt(StmtDecl(recvDecl)) 3715 3716 # If we're using virtual calls, we need the methods to be declared on 3717 # the base class. 3718 if (self.protocol.name, self.side) in VIRTUAL_CALL_CLASSES: 3719 for md in p.messageDecls: 3720 managed = md.decl.type.constructedType() 3721 if not ptype.isManagerOf(managed) or md.decl.type.isDtor(): 3722 continue 3723 3724 # add the Alloc interface for managed actors 3725 actortype = md.actorDecl().bareType(self.side) 3726 3727 if managed.isRefcounted(): 3728 if not self.receivesMessage(md): 3729 continue 3730 3731 actortype.ptr = False 3732 actortype = _alreadyaddrefed(actortype) 3733 3734 self.cls.addstmt( 3735 StmtDecl( 3736 MethodDecl( 3737 _allocMethod(managed, self.side), 3738 params=md.makeCxxParams( 3739 side=self.side, implicit=False, direction="recv" 3740 ), 3741 ret=actortype, 3742 methodspec=MethodSpec.PURE, 3743 ) 3744 ) 3745 ) 3746 3747 # add the Dealloc interface for all managed non-refcounted actors, 3748 # even without ctors. This is useful for protocols which use 3749 # ManagedEndpoint for construction. 
3750 for managed in ptype.manages: 3751 if managed.isRefcounted(): 3752 continue 3753 3754 self.cls.addstmt( 3755 StmtDecl( 3756 MethodDecl( 3757 _deallocMethod(managed, self.side), 3758 params=[ 3759 Decl(p.managedCxxType(managed, self.side), "aActor") 3760 ], 3761 ret=Type.BOOL, 3762 methodspec=MethodSpec.PURE, 3763 ) 3764 ) 3765 ) 3766 3767 if ptype.isToplevel(): 3768 # void ProcessingError(code); default to no-op 3769 processingerror = MethodDefn( 3770 MethodDecl( 3771 p.processingErrorVar().name, 3772 params=[ 3773 Param(_Result.Type(), "aCode"), 3774 Param(Type("char", const=True, ptr=True), "aReason"), 3775 ], 3776 methodspec=MethodSpec.OVERRIDE, 3777 ) 3778 ) 3779 3780 # bool ShouldContinueFromReplyTimeout(); default to |true| 3781 shouldcontinue = MethodDefn( 3782 MethodDecl( 3783 p.shouldContinueFromTimeoutVar().name, 3784 ret=Type.BOOL, 3785 methodspec=MethodSpec.OVERRIDE, 3786 ) 3787 ) 3788 shouldcontinue.addcode("return true;\n") 3789 3790 # void Entered*()/Exited*(); default to no-op 3791 entered = MethodDefn( 3792 MethodDecl(p.enteredCxxStackVar().name, methodspec=MethodSpec.OVERRIDE) 3793 ) 3794 exited = MethodDefn( 3795 MethodDecl(p.exitedCxxStackVar().name, methodspec=MethodSpec.OVERRIDE) 3796 ) 3797 enteredcall = MethodDefn( 3798 MethodDecl(p.enteredCallVar().name, methodspec=MethodSpec.OVERRIDE) 3799 ) 3800 exitedcall = MethodDefn( 3801 MethodDecl(p.exitedCallVar().name, methodspec=MethodSpec.OVERRIDE) 3802 ) 3803 3804 self.cls.addstmts( 3805 [ 3806 processingerror, 3807 shouldcontinue, 3808 entered, 3809 exited, 3810 enteredcall, 3811 exitedcall, 3812 Whitespace.NL, 3813 ] 3814 ) 3815 3816 self.cls.addstmts(([Label.PUBLIC] + self.standardTypedefs() + [Whitespace.NL])) 3817 3818 self.cls.addstmt(Label.PUBLIC) 3819 # Actor() 3820 ctor = ConstructorDefn(ConstructorDecl(self.clsname)) 3821 side = ExprVar("mozilla::ipc::" + self.side.title() + "Side") 3822 if ptype.isToplevel(): 3823 name = ExprLiteral.String(_actorName(p.name, self.side)) 3824 
ctor.memberinits = [ 3825 ExprMemberInit( 3826 ExprVar("mozilla::ipc::IToplevelProtocol"), 3827 [name, _protocolId(ptype), side], 3828 ) 3829 ] 3830 else: 3831 ctor.memberinits = [ 3832 ExprMemberInit( 3833 ExprVar("mozilla::ipc::IProtocol"), [_protocolId(ptype), side] 3834 ) 3835 ] 3836 3837 ctor.addcode("MOZ_COUNT_CTOR(${clsname});\n", clsname=self.clsname) 3838 self.cls.addstmts([ctor, Whitespace.NL]) 3839 3840 # ~Actor() 3841 dtor = DestructorDefn( 3842 DestructorDecl(self.clsname, methodspec=MethodSpec.VIRTUAL) 3843 ) 3844 dtor.addcode("MOZ_COUNT_DTOR(${clsname});\n", clsname=self.clsname) 3845 3846 self.cls.addstmts([dtor, Whitespace.NL]) 3847 3848 if ptype.isRefcounted(): 3849 self.cls.addcode( 3850 """ 3851 NS_INLINE_DECL_PURE_VIRTUAL_REFCOUNTING 3852 """ 3853 ) 3854 self.cls.addstmt(Label.PROTECTED) 3855 self.cls.addcode( 3856 """ 3857 void ActorAlloc() final { AddRef(); } 3858 void ActorDealloc() final { Release(); } 3859 """ 3860 ) 3861 3862 self.cls.addstmt(Label.PUBLIC) 3863 if not ptype.isToplevel(): 3864 if 1 == len(p.managers): 3865 # manager() const 3866 managertype = p.managerActorType(self.side, ptr=True) 3867 managermeth = MethodDefn( 3868 MethodDecl("Manager", ret=managertype, const=True) 3869 ) 3870 managermeth.addcode( 3871 """ 3872 return static_cast<${type}>(IProtocol::Manager()); 3873 """, 3874 type=managertype, 3875 ) 3876 3877 self.cls.addstmts([managermeth, Whitespace.NL]) 3878 3879 def actorFromIter(itervar): 3880 return ExprCode("${iter}.Get()->GetKey()", iter=itervar) 3881 3882 def forLoopOverHashtable(hashtable, itervar, const=False): 3883 itermeth = "ConstIter" if const else "Iter" 3884 return StmtFor( 3885 init=ExprCode( 3886 "auto ${itervar} = ${hashtable}.${itermeth}()", 3887 itervar=itervar, 3888 hashtable=hashtable, 3889 itermeth=itermeth, 3890 ), 3891 cond=ExprCode("!${itervar}.Done()", itervar=itervar), 3892 update=ExprCode("${itervar}.Next()", itervar=itervar), 3893 ) 3894 3895 # Managed[T](Array& inout) const 3896 # const 
Array<T>& Managed() const 3897 for managed in ptype.manages: 3898 container = p.managedVar(managed, self.side) 3899 3900 meth = MethodDefn( 3901 MethodDecl( 3902 p.managedMethod(managed, self.side).name, 3903 params=[ 3904 Decl( 3905 _cxxArrayType( 3906 p.managedCxxType(managed, self.side), ref=True 3907 ), 3908 "aArr", 3909 ) 3910 ], 3911 const=True, 3912 ) 3913 ) 3914 meth.addcode("${container}.ToArray(aArr);\n", container=container) 3915 3916 refmeth = MethodDefn( 3917 MethodDecl( 3918 p.managedMethod(managed, self.side).name, 3919 params=[], 3920 ret=p.managedVarType(managed, self.side, const=True, ref=True), 3921 const=True, 3922 ) 3923 ) 3924 refmeth.addcode("return ${container};\n", container=container) 3925 3926 self.cls.addstmts([meth, refmeth, Whitespace.NL]) 3927 3928 # AllManagedActors(Array& inout) const 3929 arrvar = ExprVar("arr__") 3930 managedmeth = MethodDefn( 3931 MethodDecl( 3932 "AllManagedActors", 3933 params=[ 3934 Decl( 3935 _cxxArrayType(_refptr(_cxxLifecycleProxyType()), ref=True), 3936 arrvar.name, 3937 ) 3938 ], 3939 methodspec=MethodSpec.OVERRIDE, 3940 const=True, 3941 ) 3942 ) 3943 3944 # Count the number of managed actors, and allocate space in the output array. 3945 managedmeth.addcode( 3946 """ 3947 uint32_t total = 0; 3948 """ 3949 ) 3950 for managed in ptype.manages: 3951 managedmeth.addcode( 3952 """ 3953 total += ${container}.Count(); 3954 """, 3955 container=p.managedVar(managed, self.side), 3956 ) 3957 managedmeth.addcode( 3958 """ 3959 arr__.SetCapacity(total); 3960 3961 """ 3962 ) 3963 3964 for managed in ptype.manages: 3965 managedmeth.addcode( 3966 """ 3967 for (auto* key : ${container}) { 3968 arr__.AppendElement(key->GetLifecycleProxy()); 3969 } 3970 3971 """, 3972 container=p.managedVar(managed, self.side), 3973 ) 3974 3975 self.cls.addstmts([managedmeth, Whitespace.NL]) 3976 3977 # OpenPEndpoint(...)/BindPEndpoint(...) 
3978 for managed in ptype.manages: 3979 self.genManagedEndpoint(managed) 3980 3981 # OnMessageReceived()/OnCallReceived() 3982 3983 # save these away for use in message handler case stmts 3984 msgvar = ExprVar("msg__") 3985 self.msgvar = msgvar 3986 replyvar = ExprVar("reply__") 3987 self.replyvar = replyvar 3988 itervar = ExprVar("iter__") 3989 self.itervar = itervar 3990 var = ExprVar("v__") 3991 self.var = var 3992 # for ctor recv cases, we can't read the actor ID into a PFoo* 3993 # because it doesn't exist on this side yet. Use a "special" 3994 # actor handle instead 3995 handlevar = ExprVar("handle__") 3996 self.handlevar = handlevar 3997 3998 msgtype = ExprCode("msg__.type()") 3999 self.asyncSwitch = StmtSwitch(msgtype) 4000 self.syncSwitch = None 4001 self.interruptSwitch = None 4002 if toplevel.isSync() or toplevel.isInterrupt(): 4003 self.syncSwitch = StmtSwitch(msgtype) 4004 if toplevel.isInterrupt(): 4005 self.interruptSwitch = StmtSwitch(msgtype) 4006 4007 # implement Send*() methods and add dispatcher cases to 4008 # message switch()es 4009 for md in p.messageDecls: 4010 self.visitMessageDecl(md) 4011 4012 # add default cases 4013 default = StmtCode( 4014 """ 4015 return MsgNotKnown; 4016 """ 4017 ) 4018 self.asyncSwitch.addcase(DefaultLabel(), default) 4019 if toplevel.isSync() or toplevel.isInterrupt(): 4020 self.syncSwitch.addcase(DefaultLabel(), default) 4021 if toplevel.isInterrupt(): 4022 self.interruptSwitch.addcase(DefaultLabel(), default) 4023 4024 self.cls.addstmts(self.implementManagerIface()) 4025 4026 def makeHandlerMethod(name, switch, hasReply, dispatches=False): 4027 params = [Decl(Type("Message", const=True, ref=True), msgvar.name)] 4028 if hasReply: 4029 params.append(Decl(Type("Message", ref=True, ptr=True), replyvar.name)) 4030 4031 method = MethodDefn( 4032 MethodDecl( 4033 name, 4034 methodspec=MethodSpec.OVERRIDE, 4035 params=params, 4036 ret=_Result.Type(), 4037 ) 4038 ) 4039 4040 if not switch: 4041 method.addcode( 4042 """ 
4043 MOZ_ASSERT_UNREACHABLE("message protocol not supported"); 4044 return MsgNotKnown; 4045 """ 4046 ) 4047 return method 4048 4049 if dispatches: 4050 if hasReply: 4051 ondeadactor = [StmtReturn(_Result.RouteError)] 4052 else: 4053 ondeadactor = [ 4054 self.logMessage( 4055 None, ExprAddrOf(msgvar), "Ignored message for dead actor" 4056 ), 4057 StmtReturn(_Result.Processed), 4058 ] 4059 4060 method.addcode( 4061 """ 4062 int32_t route__ = ${msgvar}.routing_id(); 4063 if (MSG_ROUTING_CONTROL != route__) { 4064 IProtocol* routed__ = Lookup(route__); 4065 if (!routed__ || !routed__->GetLifecycleProxy()) { 4066 $*{ondeadactor} 4067 } 4068 4069 RefPtr<mozilla::ipc::ActorLifecycleProxy> proxy__ = 4070 routed__->GetLifecycleProxy(); 4071 return proxy__->Get()->${name}($,{args}); 4072 } 4073 4074 """, 4075 msgvar=msgvar, 4076 ondeadactor=ondeadactor, 4077 name=name, 4078 args=[p.name for p in params], 4079 ) 4080 4081 # bug 509581: don't generate the switch stmt if there 4082 # is only the default case; MSVC doesn't like that 4083 if switch.nr_cases > 1: 4084 method.addstmt(switch) 4085 else: 4086 method.addstmt(StmtReturn(_Result.NotKnown)) 4087 4088 return method 4089 4090 dispatches = ptype.isToplevel() and ptype.isManager() 4091 self.cls.addstmts( 4092 [ 4093 makeHandlerMethod( 4094 "OnMessageReceived", 4095 self.asyncSwitch, 4096 hasReply=False, 4097 dispatches=dispatches, 4098 ), 4099 Whitespace.NL, 4100 ] 4101 ) 4102 self.cls.addstmts( 4103 [ 4104 makeHandlerMethod( 4105 "OnMessageReceived", 4106 self.syncSwitch, 4107 hasReply=True, 4108 dispatches=dispatches, 4109 ), 4110 Whitespace.NL, 4111 ] 4112 ) 4113 self.cls.addstmts( 4114 [ 4115 makeHandlerMethod( 4116 "OnCallReceived", 4117 self.interruptSwitch, 4118 hasReply=True, 4119 dispatches=dispatches, 4120 ), 4121 Whitespace.NL, 4122 ] 4123 ) 4124 4125 clearsubtreevar = ExprVar("ClearSubtree") 4126 4127 if ptype.isToplevel(): 4128 # OnChannelClose() 4129 onclose = MethodDefn( 4130 MethodDecl("OnChannelClose", 
methodspec=MethodSpec.OVERRIDE) 4131 ) 4132 onclose.addcode( 4133 """ 4134 DestroySubtree(NormalShutdown); 4135 ClearSubtree(); 4136 DeallocShmems(); 4137 if (GetLifecycleProxy()) { 4138 GetLifecycleProxy()->Release(); 4139 } 4140 """ 4141 ) 4142 self.cls.addstmts([onclose, Whitespace.NL]) 4143 4144 # OnChannelError() 4145 onerror = MethodDefn( 4146 MethodDecl("OnChannelError", methodspec=MethodSpec.OVERRIDE) 4147 ) 4148 onerror.addcode( 4149 """ 4150 DestroySubtree(AbnormalShutdown); 4151 ClearSubtree(); 4152 DeallocShmems(); 4153 if (GetLifecycleProxy()) { 4154 GetLifecycleProxy()->Release(); 4155 } 4156 """ 4157 ) 4158 self.cls.addstmts([onerror, Whitespace.NL]) 4159 4160 if ptype.isToplevel() and ptype.isInterrupt(): 4161 processnative = MethodDefn( 4162 MethodDecl("ProcessNativeEventsInInterruptCall", ret=Type.VOID) 4163 ) 4164 processnative.addcode( 4165 """ 4166 #ifdef OS_WIN 4167 GetIPCChannel()->ProcessNativeEventsInInterruptCall(); 4168 #else 4169 FatalError("This method is Windows-only"); 4170 #endif 4171 """ 4172 ) 4173 4174 self.cls.addstmts([processnative, Whitespace.NL]) 4175 4176 # private methods 4177 self.cls.addstmt(Label.PRIVATE) 4178 4179 # ClearSubtree() 4180 clearsubtree = MethodDefn(MethodDecl(clearsubtreevar.name)) 4181 for managed in ptype.manages: 4182 clearsubtree.addcode( 4183 """ 4184 for (auto* key : ${container}) { 4185 key->ClearSubtree(); 4186 } 4187 for (auto* key : ${container}) { 4188 // Recursively releasing ${container} kids. 
4189 auto* proxy = key->GetLifecycleProxy(); 4190 NS_IF_RELEASE(proxy); 4191 } 4192 ${container}.Clear(); 4193 4194 """, 4195 container=p.managedVar(managed, self.side), 4196 ) 4197 4198 # don't release our own IPC reference: either the manager will do it, 4199 # or we're toplevel 4200 self.cls.addstmts([clearsubtree, Whitespace.NL]) 4201 4202 for managed in ptype.manages: 4203 self.cls.addstmts( 4204 [ 4205 StmtDecl( 4206 Decl( 4207 p.managedVarType(managed, self.side), 4208 p.managedVar(managed, self.side).name, 4209 ) 4210 ) 4211 ] 4212 ) 4213 4214 def genManagedEndpoint(self, managed): 4215 hereEp = "ManagedEndpoint<%s>" % _actorName(managed.name(), self.side) 4216 thereEp = "ManagedEndpoint<%s>" % _actorName( 4217 managed.name(), _otherSide(self.side) 4218 ) 4219 4220 actor = _HybridDecl(ipdl.type.ActorType(managed), "aActor") 4221 4222 # ManagedEndpoint<PThere> OpenPEndpoint(PHere* aActor) 4223 openmeth = MethodDefn( 4224 MethodDecl( 4225 "Open%sEndpoint" % managed.name(), 4226 params=[ 4227 Decl(self.protocol.managedCxxType(managed, self.side), actor.name) 4228 ], 4229 ret=Type(thereEp), 4230 ) 4231 ) 4232 openmeth.addcode( 4233 """ 4234 $*{bind} 4235 return ${thereEp}(mozilla::ipc::PrivateIPDLInterface(), aActor->Id()); 4236 """, 4237 bind=self.bindManagedActor(actor, errfn=ExprCall(ExprVar(thereEp))), 4238 thereEp=thereEp, 4239 ) 4240 4241 # void BindPEndpoint(ManagedEndpoint<PHere>&& aEndpoint, PHere* aActor) 4242 bindmeth = MethodDefn( 4243 MethodDecl( 4244 "Bind%sEndpoint" % managed.name(), 4245 params=[ 4246 Decl(Type(hereEp), "aEndpoint"), 4247 Decl(self.protocol.managedCxxType(managed, self.side), actor.name), 4248 ], 4249 ret=Type.BOOL, 4250 ) 4251 ) 4252 bindmeth.addcode( 4253 """ 4254 MOZ_RELEASE_ASSERT(aEndpoint.ActorId(), "Invalid Endpoint!"); 4255 $*{bind} 4256 return true; 4257 """, 4258 bind=self.bindManagedActor( 4259 actor, errfn=ExprLiteral.FALSE, idexpr=ExprCode("*aEndpoint.ActorId()") 4260 ), 4261 ) 4262 4263 self.cls.addstmts([openmeth, 
bindmeth, Whitespace.NL]) 4264 4265 def implementManagerIface(self): 4266 p = self.protocol 4267 protocolbase = Type("IProtocol", ptr=True) 4268 4269 methods = [] 4270 4271 if p.decl.type.isToplevel(): 4272 4273 # "private" message that passes shmem mappings from one process 4274 # to the other 4275 if p.subtreeUsesShmem(): 4276 self.asyncSwitch.addcase( 4277 CaseLabel("SHMEM_CREATED_MESSAGE_TYPE"), 4278 self.genShmemCreatedHandler(), 4279 ) 4280 self.asyncSwitch.addcase( 4281 CaseLabel("SHMEM_DESTROYED_MESSAGE_TYPE"), 4282 self.genShmemDestroyedHandler(), 4283 ) 4284 else: 4285 abort = StmtBlock() 4286 abort.addstmts( 4287 [ 4288 _fatalError("this protocol tree does not use shmem"), 4289 StmtReturn(_Result.NotKnown), 4290 ] 4291 ) 4292 self.asyncSwitch.addcase(CaseLabel("SHMEM_CREATED_MESSAGE_TYPE"), abort) 4293 self.asyncSwitch.addcase( 4294 CaseLabel("SHMEM_DESTROYED_MESSAGE_TYPE"), abort 4295 ) 4296 4297 # Keep track of types created with an INOUT ctor. We need to call 4298 # Register() or RegisterID() for them depending on the side the managee 4299 # is created. 
4300 inoutCtorTypes = [] 4301 for msg in p.messageDecls: 4302 msgtype = msg.decl.type 4303 if msgtype.isCtor() and msgtype.isInout(): 4304 inoutCtorTypes.append(msgtype.constructedType()) 4305 4306 # all protocols share the "same" RemoveManagee() implementation 4307 pvar = ExprVar("aProtocolId") 4308 listenervar = ExprVar("aListener") 4309 removemanagee = MethodDefn( 4310 MethodDecl( 4311 p.removeManageeMethod().name, 4312 params=[ 4313 Decl(_protocolIdType(), pvar.name), 4314 Decl(protocolbase, listenervar.name), 4315 ], 4316 methodspec=MethodSpec.OVERRIDE, 4317 ) 4318 ) 4319 4320 if not len(p.managesStmts): 4321 removemanagee.addcode( 4322 """ 4323 FatalError("unreached"); 4324 return; 4325 """ 4326 ) 4327 else: 4328 switchontype = StmtSwitch(pvar) 4329 for managee in p.managesStmts: 4330 manageeipdltype = managee.decl.type 4331 manageecxxtype = _cxxBareType( 4332 ipdl.type.ActorType(manageeipdltype), self.side 4333 ) 4334 case = ExprCode( 4335 """ 4336 { 4337 ${manageecxxtype} actor = static_cast<${manageecxxtype}>(aListener); 4338 4339 const bool removed = ${container}.EnsureRemoved(actor); 4340 MOZ_RELEASE_ASSERT(removed, "actor not managed by this!"); 4341 4342 auto* proxy = actor->GetLifecycleProxy(); 4343 NS_IF_RELEASE(proxy); 4344 return; 4345 } 4346 """, 4347 manageecxxtype=manageecxxtype, 4348 container=p.managedVar(manageeipdltype, self.side), 4349 ) 4350 switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), case) 4351 switchontype.addcase( 4352 DefaultLabel(), 4353 ExprCode( 4354 """ 4355 FatalError("unreached"); 4356 return; 4357 """ 4358 ), 4359 ) 4360 removemanagee.addstmt(switchontype) 4361 4362 # The `DeallocManagee` method is called for managed actors to trigger 4363 # deallocation when ActorLifecycleProxy is freed. 
4364 deallocmanagee = MethodDefn( 4365 MethodDecl( 4366 p.deallocManageeMethod().name, 4367 params=[ 4368 Decl(_protocolIdType(), pvar.name), 4369 Decl(protocolbase, listenervar.name), 4370 ], 4371 methodspec=MethodSpec.OVERRIDE, 4372 ) 4373 ) 4374 4375 if not len(p.managesStmts): 4376 deallocmanagee.addcode( 4377 """ 4378 FatalError("unreached"); 4379 return; 4380 """ 4381 ) 4382 else: 4383 switchontype = StmtSwitch(pvar) 4384 for managee in p.managesStmts: 4385 manageeipdltype = managee.decl.type 4386 # Reference counted actor types don't have corresponding 4387 # `Dealloc` methods, as they are deallocated by releasing the 4388 # IPDL-held reference. 4389 if manageeipdltype.isRefcounted(): 4390 continue 4391 4392 case = StmtCode( 4393 """ 4394 ${concrete}->${dealloc}(static_cast<${type}>(aListener)); 4395 return; 4396 """, 4397 concrete=self.concreteThis(), 4398 dealloc=_deallocMethod(manageeipdltype, self.side), 4399 type=_cxxBareType(ipdl.type.ActorType(manageeipdltype), self.side), 4400 ) 4401 switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), case) 4402 switchontype.addcase( 4403 DefaultLabel(), 4404 StmtCode( 4405 """ 4406 FatalError("unreached"); 4407 return; 4408 """ 4409 ), 4410 ) 4411 deallocmanagee.addstmt(switchontype) 4412 4413 return methods + [removemanagee, deallocmanagee, Whitespace.NL] 4414 4415 def genShmemCreatedHandler(self): 4416 assert self.protocol.decl.type.isToplevel() 4417 4418 return StmtCode( 4419 """ 4420 { 4421 if (!ShmemCreated(${msgvar})) { 4422 return MsgPayloadError; 4423 } 4424 return MsgProcessed; 4425 } 4426 """, 4427 msgvar=self.msgvar, 4428 ) 4429 4430 def genShmemDestroyedHandler(self): 4431 assert self.protocol.decl.type.isToplevel() 4432 4433 return StmtCode( 4434 """ 4435 { 4436 if (!ShmemDestroyed(${msgvar})) { 4437 return MsgPayloadError; 4438 } 4439 return MsgProcessed; 4440 } 4441 """, 4442 msgvar=self.msgvar, 4443 ) 4444 4445 # ------------------------------------------------------------------------- 
4446 # The next few functions are the crux of the IPDL code generator. 4447 # They generate code for all the nasty work of message 4448 # serialization/deserialization and dispatching handlers for 4449 # received messages. 4450 ## 4451 4452 def concreteThis(self): 4453 if (self.protocol.name, self.side) in VIRTUAL_CALL_CLASSES: 4454 return ExprVar.THIS 4455 4456 if (self.protocol.name, self.side) in DIRECT_CALL_OVERRIDES: 4457 (class_name, _) = DIRECT_CALL_OVERRIDES[self.protocol.name, self.side] 4458 else: 4459 assert self.protocol.name.startswith("P") 4460 class_name = "{}{}".format(self.protocol.name[1:], self.side.capitalize()) 4461 4462 return ExprCode("static_cast<${class_name}*>(this)", class_name=class_name) 4463 4464 def thisCall(self, function, args): 4465 return ExprCall(ExprSelect(self.concreteThis(), "->", function), args=args) 4466 4467 def visitMessageDecl(self, md): 4468 isctor = md.decl.type.isCtor() 4469 isdtor = md.decl.type.isDtor() 4470 decltype = md.decl.type 4471 sendmethod = None 4472 movesendmethod = None 4473 promisesendmethod = None 4474 recvlbl, recvcase = None, None 4475 4476 def addRecvCase(lbl, case): 4477 if decltype.isAsync(): 4478 self.asyncSwitch.addcase(lbl, case) 4479 elif decltype.isSync(): 4480 self.syncSwitch.addcase(lbl, case) 4481 elif decltype.isInterrupt(): 4482 self.interruptSwitch.addcase(lbl, case) 4483 else: 4484 assert 0 4485 4486 if self.sendsMessage(md): 4487 isasync = decltype.isAsync() 4488 4489 # NOTE: Don't generate helper ctors for refcounted types. 4490 # 4491 # Safety concerns around providing your own actor to a ctor (namely 4492 # that the return value won't be checked, and the argument will be 4493 # `delete`-ed) are less critical with refcounted actors, due to the 4494 # actor being held alive by the callsite. 4495 # 4496 # This allows refcounted actors to not implement crashing AllocPFoo 4497 # methods on the sending side. 
4498 if isctor and not md.decl.type.constructedType().isRefcounted(): 4499 self.cls.addstmts([self.genHelperCtor(md), Whitespace.NL]) 4500 4501 if isctor and isasync: 4502 sendmethod, (recvlbl, recvcase) = self.genAsyncCtor(md) 4503 elif isctor: 4504 sendmethod = self.genBlockingCtorMethod(md) 4505 elif isdtor and isasync: 4506 sendmethod, (recvlbl, recvcase) = self.genAsyncDtor(md) 4507 elif isdtor: 4508 sendmethod = self.genBlockingDtorMethod(md) 4509 elif isasync: 4510 ( 4511 sendmethod, 4512 movesendmethod, 4513 promisesendmethod, 4514 (recvlbl, recvcase), 4515 ) = self.genAsyncSendMethod(md) 4516 else: 4517 sendmethod, movesendmethod = self.genBlockingSendMethod(md) 4518 4519 # XXX figure out what to do here 4520 if isdtor and md.decl.type.constructedType().isToplevel(): 4521 sendmethod = None 4522 4523 if sendmethod is not None: 4524 self.cls.addstmts([sendmethod, Whitespace.NL]) 4525 if movesendmethod is not None: 4526 self.cls.addstmts([movesendmethod, Whitespace.NL]) 4527 if promisesendmethod is not None: 4528 self.cls.addstmts([promisesendmethod, Whitespace.NL]) 4529 if recvcase is not None: 4530 addRecvCase(recvlbl, recvcase) 4531 recvlbl, recvcase = None, None 4532 4533 if self.receivesMessage(md): 4534 if isctor: 4535 recvlbl, recvcase = self.genCtorRecvCase(md) 4536 elif isdtor: 4537 recvlbl, recvcase = self.genDtorRecvCase(md) 4538 else: 4539 recvlbl, recvcase = self.genRecvCase(md) 4540 4541 # XXX figure out what to do here 4542 if isdtor and md.decl.type.constructedType().isToplevel(): 4543 return 4544 4545 addRecvCase(recvlbl, recvcase) 4546 4547 def genAsyncCtor(self, md): 4548 actor = md.actorDecl() 4549 method = MethodDefn(self.makeSendMethodDecl(md)) 4550 4551 msgvar, stmts = self.makeMessage(md, errfnSendCtor) 4552 sendok, sendstmts = self.sendAsync(md, msgvar) 4553 4554 method.addcode( 4555 """ 4556 $*{bind} 4557 4558 // Build our constructor message. 4559 $*{stmts} 4560 4561 // Notify the other side about the newly created actor. 
This can 4562 // fail if our manager has already been destroyed. 4563 // 4564 // NOTE: If the send call fails due to toplevel channel teardown, 4565 // the `IProtocol::ChannelSend` wrapper absorbs the error for us, 4566 // so we don't tear down actors unexpectedly. 4567 $*{sendstmts} 4568 4569 // Warn, destroy the actor, and return null if the message failed to 4570 // send. Otherwise, return the successfully created actor reference. 4571 if (!${sendok}) { 4572 NS_WARNING("Error sending ${actorname} constructor"); 4573 $*{destroy} 4574 return nullptr; 4575 } 4576 return ${actor}; 4577 """, 4578 bind=self.bindManagedActor(actor), 4579 stmts=stmts, 4580 sendstmts=sendstmts, 4581 sendok=sendok, 4582 destroy=self.destroyActor( 4583 md, actor.var(), why=_DestroyReason.FailedConstructor 4584 ), 4585 actor=actor.var(), 4586 actorname=actor.ipdltype.protocol.name() + self.side.capitalize(), 4587 ) 4588 4589 lbl = CaseLabel(md.pqReplyId()) 4590 case = StmtBlock() 4591 case.addstmt(StmtReturn(_Result.Processed)) 4592 # TODO not really sure what to do with async ctor "replies" yet. 4593 # destroy actor if there was an error? tricky ... 4594 4595 return method, (lbl, case) 4596 4597 def genBlockingCtorMethod(self, md): 4598 actor = md.actorDecl() 4599 method = MethodDefn(self.makeSendMethodDecl(md)) 4600 4601 msgvar, stmts = self.makeMessage(md, errfnSendCtor) 4602 4603 replyvar = self.replyvar 4604 sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar) 4605 replystmts = self.deserializeReply( 4606 md, 4607 ExprAddrOf(replyvar), 4608 self.side, 4609 errfnSendCtor, 4610 errfnSentinel(ExprLiteral.NULL), 4611 ) 4612 4613 method.addcode( 4614 """ 4615 $*{bind} 4616 4617 // Build our constructor message. 4618 $*{stmts} 4619 4620 // Synchronously send the constructor message to the other side. If 4621 // the send fails, e.g. due to the remote side shutting down, the 4622 // actor will be destroyed and potentially freed. 
4623 Message ${replyvar}; 4624 $*{sendstmts} 4625 4626 if (!(${sendok})) { 4627 // Warn, destroy the actor and return null if the message 4628 // failed to send. 4629 NS_WARNING("Error sending constructor"); 4630 $*{destroy} 4631 return nullptr; 4632 } 4633 4634 $*{replystmts} 4635 return ${actor}; 4636 """, 4637 bind=self.bindManagedActor(actor), 4638 stmts=stmts, 4639 replyvar=replyvar, 4640 sendstmts=sendstmts, 4641 sendok=sendok, 4642 destroy=self.destroyActor( 4643 md, actor.var(), why=_DestroyReason.FailedConstructor 4644 ), 4645 replystmts=replystmts, 4646 actor=actor.var(), 4647 actorname=actor.ipdltype.protocol.name() + self.side.capitalize(), 4648 ) 4649 4650 return method 4651 4652 def bindManagedActor(self, actordecl, errfn=ExprLiteral.NULL, idexpr=None): 4653 actorproto = actordecl.ipdltype.protocol 4654 4655 if idexpr is None: 4656 setManagerArgs = [ExprVar.THIS] 4657 else: 4658 setManagerArgs = [ExprVar.THIS, idexpr] 4659 4660 return [ 4661 StmtCode( 4662 """ 4663 if (!${actor}) { 4664 NS_WARNING("Cannot bind null ${actorname} actor"); 4665 return ${errfn}; 4666 } 4667 4668 ${actor}->SetManagerAndRegister($,{setManagerArgs}); 4669 ${container}.Insert(${actor}); 4670 """, 4671 actor=actordecl.var(), 4672 actorname=actorproto.name() + self.side.capitalize(), 4673 errfn=errfn, 4674 setManagerArgs=setManagerArgs, 4675 container=self.protocol.managedVar(actorproto, self.side), 4676 ) 4677 ] 4678 4679 def genHelperCtor(self, md): 4680 helperdecl = self.makeSendMethodDecl(md) 4681 helperdecl.params = helperdecl.params[1:] 4682 helper = MethodDefn(helperdecl) 4683 4684 helper.addstmts( 4685 [ 4686 self.callAllocActor(md, retsems="out", side=self.side), 4687 StmtReturn(ExprCall(ExprVar(helperdecl.name), args=md.makeCxxArgs())), 4688 ] 4689 ) 4690 return helper 4691 4692 def genAsyncDtor(self, md): 4693 actor = md.actorDecl() 4694 actorvar = actor.var() 4695 method = MethodDefn(self.makeDtorMethodDecl(md)) 4696 4697 method.addstmt(self.dtorPrologue(actorvar)) 
4698 4699 msgvar, stmts = self.makeMessage(md, errfnSendDtor, actorvar) 4700 sendok, sendstmts = self.sendAsync(md, msgvar, actorvar) 4701 method.addstmts( 4702 stmts 4703 + sendstmts 4704 + [Whitespace.NL] 4705 + self.dtorEpilogue(md, actor.var()) 4706 + [StmtReturn(sendok)] 4707 ) 4708 4709 lbl = CaseLabel(md.pqReplyId()) 4710 case = StmtBlock() 4711 case.addstmt(StmtReturn(_Result.Processed)) 4712 # TODO if the dtor is "inherently racy", keep the actor alive 4713 # until the other side acks 4714 4715 return method, (lbl, case) 4716 4717 def genBlockingDtorMethod(self, md): 4718 actor = md.actorDecl() 4719 actorvar = actor.var() 4720 method = MethodDefn(self.makeDtorMethodDecl(md)) 4721 4722 method.addstmt(self.dtorPrologue(actorvar)) 4723 4724 msgvar, stmts = self.makeMessage(md, errfnSendDtor, actorvar) 4725 4726 replyvar = self.replyvar 4727 sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar, actorvar) 4728 method.addstmts( 4729 stmts 4730 + [Whitespace.NL, StmtDecl(Decl(Type("Message"), replyvar.name))] 4731 + sendstmts 4732 ) 4733 4734 destmts = self.deserializeReply( 4735 md, ExprAddrOf(replyvar), self.side, errfnSend, errfnSentinel(), actorvar 4736 ) 4737 ifsendok = StmtIf(ExprLiteral.FALSE) 4738 ifsendok.addifstmts(destmts) 4739 ifsendok.addifstmts( 4740 [Whitespace.NL, StmtExpr(ExprAssn(sendok, ExprLiteral.FALSE, "&="))] 4741 ) 4742 4743 method.addstmt(ifsendok) 4744 4745 method.addstmts( 4746 self.dtorEpilogue(md, actor.var()) + [Whitespace.NL, StmtReturn(sendok)] 4747 ) 4748 4749 return method 4750 4751 def destroyActor(self, md, actorexpr, why=_DestroyReason.Deletion): 4752 if md.decl.type.isCtor(): 4753 destroyedType = md.decl.type.constructedType() 4754 else: 4755 destroyedType = self.protocol.decl.type 4756 4757 return [ 4758 StmtCode( 4759 """ 4760 IProtocol* mgr = ${actor}->Manager(); 4761 ${actor}->DestroySubtree(${why}); 4762 ${actor}->ClearSubtree(); 4763 mgr->RemoveManagee(${protoId}, ${actor}); 4764 """, 4765 actor=actorexpr, 4766 
why=why, 4767 protoId=_protocolId(destroyedType), 4768 ) 4769 ] 4770 4771 def dtorPrologue(self, actorexpr): 4772 return StmtCode( 4773 """ 4774 if (!${actor} || !${actor}->CanSend()) { 4775 NS_WARNING("Attempt to __delete__ missing or closed actor"); 4776 return false; 4777 } 4778 """, 4779 actor=actorexpr, 4780 ) 4781 4782 def dtorEpilogue(self, md, actorexpr): 4783 return self.destroyActor(md, actorexpr) 4784 4785 def genRecvAsyncReplyCase(self, md): 4786 lbl = CaseLabel(md.pqReplyId()) 4787 case = StmtBlock() 4788 resolve, reason, prologue, desrej, desstmts = self.deserializeAsyncReply( 4789 md, self.side, errfnRecv, errfnSentinel(_Result.ValuError) 4790 ) 4791 4792 if len(md.returns) > 1: 4793 resolvetype = _tuple([d.bareType(self.side) for d in md.returns]) 4794 resolvearg = ExprCall( 4795 ExprVar("MakeTuple"), args=[ExprMove(p.var()) for p in md.returns] 4796 ) 4797 else: 4798 resolvetype = md.returns[0].bareType(self.side) 4799 resolvearg = ExprMove(md.returns[0].var()) 4800 4801 case.addcode( 4802 """ 4803 $*{prologue} 4804 4805 UniquePtr<MessageChannel::UntypedCallbackHolder> untypedCallback = 4806 GetIPCChannel()->PopCallback(${msgvar}); 4807 4808 typedef MessageChannel::CallbackHolder<${resolvetype}> CallbackHolder; 4809 auto* callback = static_cast<CallbackHolder*>(untypedCallback.get()); 4810 if (!callback) { 4811 FatalError("Error unknown callback"); 4812 return MsgProcessingError; 4813 } 4814 4815 if (${resolve}) { 4816 $*{desstmts} 4817 callback->Resolve(${resolvearg}); 4818 } else { 4819 $*{desrej} 4820 callback->Reject(std::move(${reason})); 4821 } 4822 return MsgProcessed; 4823 """, 4824 prologue=prologue, 4825 msgvar=self.msgvar, 4826 resolve=resolve, 4827 resolvetype=resolvetype, 4828 desstmts=desstmts, 4829 resolvearg=resolvearg, 4830 desrej=desrej, 4831 reason=reason, 4832 ) 4833 4834 return (lbl, case) 4835 4836 def genAsyncSendMethod(self, md): 4837 method = MethodDefn(self.makeSendMethodDecl(md)) 4838 msgvar, stmts = self.makeMessage(md, 
errfnSend) 4839 retvar, sendstmts = self.sendAsync(md, msgvar) 4840 4841 method.addstmts(stmts + [Whitespace.NL] + sendstmts + [StmtReturn(retvar)]) 4842 4843 movemethod = None 4844 4845 # Add the promise overload if we need one. 4846 if md.returns: 4847 promisemethod = MethodDefn(self.makeSendMethodDecl(md, promise=True)) 4848 stmts = self.sendAsyncWithPromise(md) 4849 promisemethod.addstmts(stmts) 4850 4851 (lbl, case) = self.genRecvAsyncReplyCase(md) 4852 else: 4853 (promisemethod, lbl, case) = (None, None, None) 4854 4855 return method, movemethod, promisemethod, (lbl, case) 4856 4857 def genBlockingSendMethod(self, md, fromActor=None): 4858 method = MethodDefn(self.makeSendMethodDecl(md)) 4859 4860 msgvar, serstmts = self.makeMessage(md, errfnSend, fromActor) 4861 replyvar = self.replyvar 4862 4863 sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar) 4864 failif = StmtIf(ExprNot(sendok)) 4865 failif.addifstmt(StmtReturn.FALSE) 4866 4867 desstmts = self.deserializeReply( 4868 md, ExprAddrOf(replyvar), self.side, errfnSend, errfnSentinel() 4869 ) 4870 4871 method.addstmts( 4872 serstmts 4873 + [Whitespace.NL, StmtDecl(Decl(Type("Message"), replyvar.name))] 4874 + sendstmts 4875 + [failif] 4876 + desstmts 4877 + [Whitespace.NL, StmtReturn.TRUE] 4878 ) 4879 4880 movemethod = None 4881 4882 return method, movemethod 4883 4884 def genCtorRecvCase(self, md): 4885 lbl = CaseLabel(md.pqMsgId()) 4886 case = StmtBlock() 4887 actorhandle = self.handlevar 4888 4889 stmts = self.deserializeMessage( 4890 md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError) 4891 ) 4892 4893 idvar, saveIdStmts = self.saveActorId(md) 4894 case.addstmts( 4895 stmts 4896 + [ 4897 StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[]) 4898 for r in md.returns 4899 ] 4900 # alloc the actor, register it under the foreign ID 4901 + [self.callAllocActor(md, retsems="in", side=self.side)] 4902 + self.bindManagedActor( 4903 md.actorDecl(), errfn=_Result.ValuError, 
idexpr=_actorHId(actorhandle) 4904 ) 4905 + [Whitespace.NL] 4906 + saveIdStmts 4907 + self.invokeRecvHandler(md) 4908 + self.makeReply(md, errfnRecv, idvar) 4909 + [Whitespace.NL, StmtReturn(_Result.Processed)] 4910 ) 4911 4912 return lbl, case 4913 4914 def genDtorRecvCase(self, md): 4915 lbl = CaseLabel(md.pqMsgId()) 4916 case = StmtBlock() 4917 4918 stmts = self.deserializeMessage( 4919 md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError) 4920 ) 4921 4922 idvar, saveIdStmts = self.saveActorId(md) 4923 case.addstmts( 4924 stmts 4925 + [ 4926 StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[]) 4927 for r in md.returns 4928 ] 4929 + self.invokeRecvHandler(md, implicit=False) 4930 + [Whitespace.NL] 4931 + saveIdStmts 4932 + self.makeReply(md, errfnRecv, routingId=idvar) 4933 + [Whitespace.NL] 4934 + self.dtorEpilogue(md, md.actorDecl().var()) 4935 + [Whitespace.NL, StmtReturn(_Result.Processed)] 4936 ) 4937 4938 return lbl, case 4939 4940 def genRecvCase(self, md): 4941 lbl = CaseLabel(md.pqMsgId()) 4942 case = StmtBlock() 4943 4944 stmts = self.deserializeMessage( 4945 md, self.side, errfn=errfnRecv, errfnSent=errfnSentinel(_Result.ValuError) 4946 ) 4947 4948 idvar, saveIdStmts = self.saveActorId(md) 4949 declstmts = [ 4950 StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[]) 4951 for r in md.returns 4952 ] 4953 if md.decl.type.isAsync() and md.returns: 4954 declstmts = self.makeResolver(md, errfnRecv, routingId=idvar) 4955 case.addstmts( 4956 stmts 4957 + saveIdStmts 4958 + declstmts 4959 + self.invokeRecvHandler(md) 4960 + [Whitespace.NL] 4961 + self.makeReply(md, errfnRecv, routingId=idvar) 4962 + [StmtReturn(_Result.Processed)] 4963 ) 4964 4965 return lbl, case 4966 4967 # helper methods 4968 4969 def makeMessage(self, md, errfn, fromActor=None): 4970 msgvar = self.msgvar 4971 routingId = self.protocol.routingId(fromActor) 4972 this = ExprVar.THIS 4973 if md.decl.type.isDtor(): 4974 this = md.actorDecl().var() 4975 4976 
stmts = ( 4977 [ 4978 StmtDecl( 4979 Decl(Type("IPC::Message", ptr=True), msgvar.name), 4980 init=ExprCall(ExprVar(md.pqMsgCtorFunc()), args=[routingId]), 4981 ) 4982 ] 4983 + [Whitespace.NL] 4984 + [ 4985 _ParamTraits.checkedWrite( 4986 p.ipdltype, p.var(), msgvar, sentinelKey=p.name, actor=this 4987 ) 4988 for p in md.params 4989 ] 4990 + [Whitespace.NL] 4991 + self.setMessageFlags(md, msgvar) 4992 ) 4993 return msgvar, stmts 4994 4995 def makeResolver(self, md, errfn, routingId): 4996 if routingId is None: 4997 routingId = self.protocol.routingId() 4998 if not md.decl.type.isAsync() or not md.hasReply(): 4999 return [] 5000 5001 def paramValue(idx): 5002 assert idx < len(md.returns) 5003 if len(md.returns) > 1: 5004 return ExprCode("mozilla::Get<${idx}>(aParam)", idx=idx) 5005 return ExprVar("aParam") 5006 5007 serializeParams = [ 5008 _ParamTraits.checkedWrite( 5009 p.ipdltype, 5010 paramValue(idx), 5011 self.replyvar, 5012 sentinelKey=p.name, 5013 actor=ExprVar("self__"), 5014 ) 5015 for idx, p in enumerate(md.returns) 5016 ] 5017 5018 return [ 5019 StmtCode( 5020 """ 5021 UniquePtr<IPC::Message> ${replyvar}(${replyCtor}(${routingId})); 5022 ${replyvar}->set_seqno(${msgvar}.seqno()); 5023 5024 RefPtr<mozilla::ipc::IPDLResolverInner> resolver__ = 5025 new mozilla::ipc::IPDLResolverInner(std::move(${replyvar}), this); 5026 5027 ${resolvertype} resolver = [resolver__ = std::move(resolver__)](${resolveType} aParam) { 5028 resolver__->Resolve([&] (IPC::Message* ${replyvar}, IProtocol* self__) { 5029 $*{serializeParams} 5030 ${logSendingReply} 5031 }); 5032 }; 5033 """, 5034 msgvar=self.msgvar, 5035 resolvertype=Type(md.resolverName()), 5036 routingId=routingId, 5037 resolveType=_resolveType(md.returns, self.side), 5038 replyvar=self.replyvar, 5039 replyCtor=ExprVar(md.pqReplyCtorFunc()), 5040 serializeParams=serializeParams, 5041 logSendingReply=self.logMessage( 5042 md, 5043 self.replyvar, 5044 "Sending reply ", 5045 actor=ExprVar("self__"), 5046 ), 5047 ) 5048 ] 
5049 5050 def makeReply(self, md, errfn, routingId): 5051 if routingId is None: 5052 routingId = self.protocol.routingId() 5053 # TODO special cases for async ctor/dtor replies 5054 if not md.decl.type.hasReply(): 5055 return [] 5056 if md.decl.type.isAsync() and md.decl.type.hasReply(): 5057 return [] 5058 5059 replyvar = self.replyvar 5060 return ( 5061 [ 5062 StmtExpr( 5063 ExprAssn( 5064 replyvar, 5065 ExprCall(ExprVar(md.pqReplyCtorFunc()), args=[routingId]), 5066 ) 5067 ), 5068 Whitespace.NL, 5069 ] 5070 + [ 5071 _ParamTraits.checkedWrite( 5072 r.ipdltype, 5073 r.var(), 5074 replyvar, 5075 sentinelKey=r.name, 5076 actor=ExprVar.THIS, 5077 ) 5078 for r in md.returns 5079 ] 5080 + self.setMessageFlags(md, replyvar) 5081 + [self.logMessage(md, replyvar, "Sending reply ")] 5082 ) 5083 5084 def setMessageFlags(self, md, var, seqno=None): 5085 stmts = [] 5086 5087 if seqno: 5088 stmts.append( 5089 StmtExpr(ExprCall(ExprSelect(var, "->", "set_seqno"), args=[seqno])) 5090 ) 5091 5092 return stmts + [Whitespace.NL] 5093 5094 def deserializeMessage(self, md, side, errfn, errfnSent): 5095 msgvar = self.msgvar 5096 itervar = self.itervar 5097 msgexpr = ExprAddrOf(msgvar) 5098 isctor = md.decl.type.isCtor() 5099 stmts = [ 5100 self.logMessage(md, msgexpr, "Received ", receiving=True), 5101 self.profilerLabel(md), 5102 Whitespace.NL, 5103 ] 5104 5105 if 0 == len(md.params): 5106 return stmts 5107 5108 start, decls, reads = 0, [], [] 5109 if isctor: 5110 # return the raw actor handle so that its ID can be used 5111 # to construct the "real" actor 5112 handlevar = self.handlevar 5113 handletype = Type("ActorHandle") 5114 decls = [StmtDecl(Decl(handletype, handlevar.name), initargs=[])] 5115 reads = [ 5116 _ParamTraits.checkedRead( 5117 None, 5118 ExprAddrOf(handlevar), 5119 msgexpr, 5120 ExprAddrOf(self.itervar), 5121 errfn, 5122 "'%s'" % handletype.name, 5123 sentinelKey="actor", 5124 errfnSentinel=errfnSent, 5125 actor=ExprVar.THIS, 5126 ) 5127 ] 5128 start = 1 5129 5130 
decls.extend( 5131 [ 5132 StmtDecl( 5133 Decl( 5134 ( 5135 Type("Tainted", T=p.bareType(side)) 5136 if md.decl.type.tainted and "NoTaint" not in p.attributes 5137 else p.bareType(side) 5138 ), 5139 p.var().name, 5140 ), 5141 initargs=[], 5142 ) 5143 for p in md.params[start:] 5144 ] 5145 ) 5146 reads.extend( 5147 [ 5148 _ParamTraits.checkedRead( 5149 p.ipdltype, 5150 ExprAddrOf(p.var()), 5151 msgexpr, 5152 ExprAddrOf(itervar), 5153 errfn, 5154 "'%s'" % p.ipdltype.name(), 5155 sentinelKey=p.name, 5156 errfnSentinel=errfnSent, 5157 actor=ExprVar.THIS, 5158 ) 5159 for p in md.params[start:] 5160 ] 5161 ) 5162 5163 stmts.extend( 5164 ( 5165 [ 5166 StmtDecl( 5167 Decl(_iterType(ptr=False), self.itervar.name), initargs=[msgvar] 5168 ) 5169 ] 5170 + decls 5171 + [Whitespace.NL] 5172 + reads 5173 + [self.endRead(msgvar, itervar)] 5174 ) 5175 ) 5176 5177 return stmts 5178 5179 def deserializeAsyncReply(self, md, side, errfn, errfnSent): 5180 msgvar = self.msgvar 5181 itervar = self.itervar 5182 msgexpr = ExprAddrOf(msgvar) 5183 isctor = md.decl.type.isCtor() 5184 resolve = ExprVar("resolve__") 5185 reason = ExprVar("reason__") 5186 5187 # NOTE: The `resolve__` and `reason__` parameters don't have sentinels, 5188 # as they are serialized by the IPDLResolverInner type in 5189 # ProtocolUtils.cpp rather than by generated code. 
5190 desresolve = [ 5191 StmtCode( 5192 """ 5193 bool resolve__ = false; 5194 if (!ReadIPDLParam(${msgexpr}, &${itervar}, this, &resolve__)) { 5195 FatalError("Error deserializing bool"); 5196 return MsgValueError; 5197 } 5198 """, 5199 msgexpr=msgexpr, 5200 itervar=itervar, 5201 ), 5202 ] 5203 desrej = [ 5204 StmtCode( 5205 """ 5206 ResponseRejectReason reason__{}; 5207 if (!ReadIPDLParam(${msgexpr}, &${itervar}, this, &reason__)) { 5208 FatalError("Error deserializing ResponseRejectReason"); 5209 return MsgValueError; 5210 } 5211 """, 5212 msgexpr=msgexpr, 5213 itervar=itervar, 5214 ), 5215 self.endRead(msgvar, itervar), 5216 ] 5217 prologue = [ 5218 self.logMessage(md, msgexpr, "Received ", receiving=True), 5219 self.profilerLabel(md), 5220 Whitespace.NL, 5221 ] 5222 5223 if not md.returns: 5224 return prologue 5225 5226 prologue.extend( 5227 [StmtDecl(Decl(_iterType(ptr=False), itervar.name), initargs=[msgvar])] 5228 + desresolve 5229 ) 5230 5231 start, decls, reads = 0, [], [] 5232 if isctor: 5233 # return the raw actor handle so that its ID can be used 5234 # to construct the "real" actor 5235 handlevar = self.handlevar 5236 handletype = Type("ActorHandle") 5237 decls = [StmtDecl(Decl(handletype, handlevar.name), initargs=[])] 5238 reads = [ 5239 _ParamTraits.checkedRead( 5240 None, 5241 ExprAddrOf(handlevar), 5242 msgexpr, 5243 ExprAddrOf(itervar), 5244 errfn, 5245 "'%s'" % handletype.name, 5246 sentinelKey="actor", 5247 errfnSentinel=errfnSent, 5248 actor=ExprVar.THIS, 5249 ) 5250 ] 5251 start = 1 5252 5253 stmts = ( 5254 decls 5255 + [ 5256 StmtDecl(Decl(p.bareType(side), p.var().name), initargs=[]) 5257 for p in md.returns 5258 ] 5259 + [Whitespace.NL] 5260 + reads 5261 + [ 5262 _ParamTraits.checkedRead( 5263 p.ipdltype, 5264 ExprAddrOf(p.var()), 5265 msgexpr, 5266 ExprAddrOf(itervar), 5267 errfn, 5268 "'%s'" % p.ipdltype.name(), 5269 sentinelKey=p.name, 5270 errfnSentinel=errfnSent, 5271 actor=ExprVar.THIS, 5272 ) 5273 for p in md.returns[start:] 5274 ] 
5275 + [self.endRead(msgvar, itervar)] 5276 ) 5277 5278 return resolve, reason, prologue, desrej, stmts 5279 5280 def deserializeReply( 5281 self, md, replyexpr, side, errfn, errfnSentinel, actor=None, decls=False 5282 ): 5283 stmts = [ 5284 Whitespace.NL, 5285 self.logMessage(md, replyexpr, "Received reply ", actor, receiving=True), 5286 ] 5287 if 0 == len(md.returns): 5288 return stmts 5289 5290 itervar = self.itervar 5291 declstmts = [] 5292 if decls: 5293 declstmts = [ 5294 StmtDecl(Decl(p.bareType(side), p.var().name), initargs=[]) 5295 for p in md.returns 5296 ] 5297 stmts.extend( 5298 [ 5299 Whitespace.NL, 5300 StmtDecl( 5301 Decl(_iterType(ptr=False), itervar.name), initargs=[self.replyvar] 5302 ), 5303 ] 5304 + declstmts 5305 + [Whitespace.NL] 5306 + [ 5307 _ParamTraits.checkedRead( 5308 r.ipdltype, 5309 r.var(), 5310 ExprAddrOf(self.replyvar), 5311 ExprAddrOf(self.itervar), 5312 errfn, 5313 "'%s'" % r.ipdltype.name(), 5314 sentinelKey=r.name, 5315 errfnSentinel=errfnSentinel, 5316 actor=ExprVar.THIS, 5317 ) 5318 for r in md.returns 5319 ] 5320 + [self.endRead(self.replyvar, itervar)] 5321 ) 5322 5323 return stmts 5324 5325 def sendAsync(self, md, msgexpr, actor=None): 5326 sendok = ExprVar("sendok__") 5327 resolvefn = ExprVar("aResolve") 5328 rejectfn = ExprVar("aReject") 5329 5330 stmts = [ 5331 Whitespace.NL, 5332 self.logMessage(md, msgexpr, "Sending ", actor), 5333 self.profilerLabel(md), 5334 ] 5335 stmts.append(Whitespace.NL) 5336 5337 # Generate the actual call expression. 
5338 send = ExprVar("ChannelSend") 5339 if actor is not None: 5340 send = ExprSelect(actor, "->", send.name) 5341 if md.returns: 5342 stmts.append( 5343 StmtExpr( 5344 ExprCall( 5345 send, args=[msgexpr, ExprMove(resolvefn), ExprMove(rejectfn)] 5346 ) 5347 ) 5348 ) 5349 retvar = None 5350 else: 5351 stmts.append( 5352 StmtDecl( 5353 Decl(Type.BOOL, sendok.name), init=ExprCall(send, args=[msgexpr]) 5354 ) 5355 ) 5356 retvar = sendok 5357 5358 return (retvar, stmts) 5359 5360 def sendBlocking(self, md, msgexpr, replyexpr, actor=None): 5361 send = ExprVar("ChannelSend") 5362 if md.decl.type.isInterrupt(): 5363 send = ExprVar("ChannelCall") 5364 if actor is not None: 5365 send = ExprSelect(actor, "->", send.name) 5366 5367 sendok = ExprVar("sendok__") 5368 self.externalIncludes.add("mozilla/ProfilerMarkers.h") 5369 return ( 5370 sendok, 5371 ( 5372 [ 5373 Whitespace.NL, 5374 self.logMessage(md, msgexpr, "Sending ", actor), 5375 self.profilerLabel(md), 5376 ] 5377 + [ 5378 Whitespace.NL, 5379 StmtDecl(Decl(Type.BOOL, sendok.name), init=ExprLiteral.FALSE), 5380 StmtBlock( 5381 [ 5382 StmtExpr( 5383 ExprCall( 5384 ExprVar("AUTO_PROFILER_TRACING_MARKER"), 5385 [ 5386 ExprLiteral.String("Sync IPC"), 5387 ExprLiteral.String( 5388 self.protocol.name 5389 + "::" 5390 + md.prettyMsgName() 5391 ), 5392 ExprVar("IPC"), 5393 ], 5394 ) 5395 ), 5396 StmtExpr( 5397 ExprAssn( 5398 sendok, 5399 ExprCall( 5400 send, args=[msgexpr, ExprAddrOf(replyexpr)] 5401 ), 5402 ) 5403 ), 5404 ] 5405 ), 5406 ] 5407 ), 5408 ) 5409 5410 def sendAsyncWithPromise(self, md): 5411 # Create a new promise, and forward to the callback send overload. 
5412 promise = _makePromise(md.returns, self.side, resolver=True) 5413 5414 if len(md.returns) > 1: 5415 resolvetype = _tuple([d.bareType(self.side) for d in md.returns]) 5416 else: 5417 resolvetype = md.returns[0].bareType(self.side) 5418 5419 resolve = ExprCode( 5420 """ 5421 [promise__](${resolvetype}&& aValue) { 5422 promise__->Resolve(std::move(aValue), __func__); 5423 } 5424 """, 5425 resolvetype=resolvetype, 5426 ) 5427 reject = ExprCode( 5428 """ 5429 [promise__](ResponseRejectReason&& aReason) { 5430 promise__->Reject(std::move(aReason), __func__); 5431 } 5432 """, 5433 resolvetype=resolvetype, 5434 ) 5435 5436 args = [ExprMove(p.var()) for p in md.params] + [resolve, reject] 5437 stmt = StmtCode( 5438 """ 5439 RefPtr<${promise}> promise__ = new ${promise}(__func__); 5440 promise__->UseDirectTaskDispatch(__func__); 5441 ${send}($,{args}); 5442 return promise__; 5443 """, 5444 promise=promise, 5445 send=md.sendMethod(), 5446 args=args, 5447 ) 5448 return [stmt] 5449 5450 def callAllocActor(self, md, retsems, side): 5451 actortype = md.actorDecl().bareType(self.side) 5452 if md.decl.type.constructedType().isRefcounted(): 5453 actortype.ptr = False 5454 actortype = _refptr(actortype) 5455 5456 callalloc = self.thisCall( 5457 _allocMethod(md.decl.type.constructedType(), side), 5458 args=md.makeCxxArgs(retsems=retsems, retcallsems="out", implicit=False), 5459 ) 5460 5461 return StmtDecl(Decl(actortype, md.actorDecl().var().name), init=callalloc) 5462 5463 def invokeRecvHandler(self, md, implicit=True): 5464 retsems = "in" 5465 if md.decl.type.isAsync() and md.returns: 5466 retsems = "resolver" 5467 failif = StmtIf( 5468 ExprNot( 5469 self.thisCall( 5470 md.recvMethod(), 5471 md.makeCxxArgs( 5472 paramsems="move", 5473 retsems=retsems, 5474 retcallsems="out", 5475 implicit=implicit, 5476 ), 5477 ) 5478 ) 5479 ) 5480 failif.addifstmts( 5481 [ 5482 _protocolErrorBreakpoint("Handler returned error code!"), 5483 Whitespace( 5484 "// Error handled in 
mozilla::ipc::IPCResult\n", indent=True 5485 ), 5486 StmtReturn(_Result.ProcessingError), 5487 ] 5488 ) 5489 return [failif] 5490 5491 def makeDtorMethodDecl(self, md): 5492 decl = self.makeSendMethodDecl(md) 5493 decl.methodspec = MethodSpec.STATIC 5494 return decl 5495 5496 def makeSendMethodDecl(self, md, promise=False, paramsems="in"): 5497 implicit = md.decl.type.hasImplicitActorParam() 5498 if md.decl.type.isAsync() and md.returns: 5499 if promise: 5500 returnsems = "promise" 5501 rettype = _refptr(Type(md.promiseName())) 5502 else: 5503 returnsems = "callback" 5504 rettype = Type.VOID 5505 else: 5506 assert not promise 5507 returnsems = "out" 5508 rettype = Type.BOOL 5509 decl = MethodDecl( 5510 md.sendMethod(), 5511 params=md.makeCxxParams( 5512 paramsems, 5513 returnsems=returnsems, 5514 side=self.side, 5515 implicit=implicit, 5516 direction="send", 5517 ), 5518 warn_unused=( 5519 (self.side == "parent" and returnsems != "callback") 5520 or (md.decl.type.isCtor() and not md.decl.type.isAsync()) 5521 ), 5522 ret=rettype, 5523 ) 5524 if md.decl.type.isCtor(): 5525 decl.ret = md.actorDecl().bareType(self.side) 5526 return decl 5527 5528 def logMessage(self, md, msgptr, pfx, actor=None, receiving=False): 5529 actorname = _actorName(self.protocol.name, self.side) 5530 return StmtCode( 5531 """ 5532 if (mozilla::ipc::LoggingEnabledFor(${actorname})) { 5533 mozilla::ipc::LogMessageForProtocol( 5534 ${actorname}, 5535 ${otherpid}, 5536 ${pfx}, 5537 ${msgptr}->type(), 5538 mozilla::ipc::MessageDirection::${direction}); 5539 } 5540 """, 5541 actorname=ExprLiteral.String(actorname), 5542 otherpid=self.protocol.callOtherPid(actor), 5543 pfx=ExprLiteral.String(pfx), 5544 msgptr=msgptr, 5545 direction="eReceiving" if receiving else "eSending", 5546 ) 5547 5548 def profilerLabel(self, md): 5549 self.externalIncludes.add("mozilla/ProfilerLabels.h") 5550 return StmtCode( 5551 """ 5552 AUTO_PROFILER_LABEL("${name}::${msgname}", OTHER); 5553 """, 5554 
name=self.protocol.name, 5555 msgname=md.prettyMsgName(), 5556 ) 5557 5558 def saveActorId(self, md): 5559 idvar = ExprVar("id__") 5560 if md.decl.type.hasReply(): 5561 # only save the ID if we're actually going to use it, to 5562 # avoid unused-variable warnings 5563 saveIdStmts = [ 5564 StmtDecl(Decl(_actorIdType(), idvar.name), self.protocol.routingId()) 5565 ] 5566 else: 5567 saveIdStmts = [] 5568 return idvar, saveIdStmts 5569 5570 def endRead(self, msgexpr, iterexpr): 5571 return StmtCode( 5572 """ 5573 ${msg}.EndRead(${iter}, ${msg}.type()); 5574 """, 5575 msg=msgexpr, 5576 iter=iterexpr, 5577 ) 5578 5579 5580class _GenerateProtocolParentCode(_GenerateProtocolActorCode): 5581 def __init__(self): 5582 _GenerateProtocolActorCode.__init__(self, "parent") 5583 5584 def sendsMessage(self, md): 5585 return not md.decl.type.isIn() 5586 5587 def receivesMessage(self, md): 5588 return md.decl.type.isInout() or md.decl.type.isIn() 5589 5590 5591class _GenerateProtocolChildCode(_GenerateProtocolActorCode): 5592 def __init__(self): 5593 _GenerateProtocolActorCode.__init__(self, "child") 5594 5595 def sendsMessage(self, md): 5596 return not md.decl.type.isOut() 5597 5598 def receivesMessage(self, md): 5599 return md.decl.type.isInout() or md.decl.type.isOut() 5600 5601 5602# ----------------------------------------------------------------------------- 5603# Utility passes 5604## 5605 5606 5607def _splitClassDeclDefn(cls): 5608 """Destructively split |cls| methods into declarations and 5609 definitions (if |not methodDecl.force_inline|). Return classDecl, 5610 methodDefns.""" 5611 defns = Block() 5612 5613 for i, stmt in enumerate(cls.stmts): 5614 if isinstance(stmt, MethodDefn) and not stmt.decl.force_inline: 5615 decl, defn = _splitMethodDeclDefn(stmt, cls) 5616 cls.stmts[i] = StmtDecl(decl) 5617 if defn: 5618 defns.addstmts([defn, Whitespace.NL]) 5619 5620 return cls, defns 5621 5622 5623def _splitMethodDeclDefn(md, cls): 5624 # Pure methods have decls but no defns. 
5625 if md.decl.methodspec == MethodSpec.PURE: 5626 return md.decl, None 5627 5628 saveddecl = deepcopy(md.decl) 5629 md.decl.cls = cls 5630 # Don't emit method specifiers on method defns. 5631 md.decl.methodspec = MethodSpec.NONE 5632 md.decl.warn_unused = False 5633 md.decl.only_for_definition = True 5634 for param in md.decl.params: 5635 if isinstance(param, Param): 5636 param.default = None 5637 return saveddecl, md 5638 5639 5640def _splitFuncDeclDefn(fun): 5641 assert not fun.decl.force_inline 5642 return StmtDecl(fun.decl), fun 5643