1 //===--- TransBlockObjCVariable.cpp - Transformations to ARC mode ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // rewriteBlockObjCVariable:
10 //
11 // Adding __block to an obj-c variable could be either because the variable
12 // is used for output storage or the user wanted to break a retain cycle.
// This transformation checks whether the block actually needs a reference to
// the variable (because the variable is assigned to or its address is taken)
// or only its value. If the reference is not needed, it assumes __block was
// added to break a retain cycle, so it removes '__block' and adds
// __weak/__unsafe_unretained instead.
// E.g.:
18 //
19 //   __block Foo *x;
20 //   bar(^ { [x cake]; });
21 // ---->
22 //   __weak Foo *x;
23 //   bar(^ { [x cake]; });
24 //
25 //===----------------------------------------------------------------------===//
26 
27 #include "Transforms.h"
28 #include "Internals.h"
29 #include "clang/AST/ASTContext.h"
30 #include "clang/AST/Attr.h"
31 #include "clang/Basic/SourceManager.h"
32 
33 using namespace clang;
34 using namespace arcmt;
35 using namespace trans;
36 
37 namespace {
38 
39 class RootBlockObjCVarRewriter :
40                           public RecursiveASTVisitor<RootBlockObjCVarRewriter> {
41   llvm::DenseSet<VarDecl *> &VarsToChange;
42 
43   class BlockVarChecker : public RecursiveASTVisitor<BlockVarChecker> {
44     VarDecl *Var;
45 
46     typedef RecursiveASTVisitor<BlockVarChecker> base;
47   public:
48     BlockVarChecker(VarDecl *var) : Var(var) { }
49 
50     bool TraverseImplicitCastExpr(ImplicitCastExpr *castE) {
51       if (DeclRefExpr *
52             ref = dyn_cast<DeclRefExpr>(castE->getSubExpr())) {
53         if (ref->getDecl() == Var) {
54           if (castE->getCastKind() == CK_LValueToRValue)
55             return true; // Using the value of the variable.
56           if (castE->getCastKind() == CK_NoOp && castE->isLValue() &&
57               Var->getASTContext().getLangOpts().CPlusPlus)
58             return true; // Binding to const C++ reference.
59         }
60       }
61 
62       return base::TraverseImplicitCastExpr(castE);
63     }
64 
65     bool VisitDeclRefExpr(DeclRefExpr *E) {
66       if (E->getDecl() == Var)
67         return false; // The reference of the variable, and not just its value,
68                       //  is needed.
69       return true;
70     }
71   };
72 
73 public:
74   RootBlockObjCVarRewriter(llvm::DenseSet<VarDecl *> &VarsToChange)
75     : VarsToChange(VarsToChange) { }
76 
77   bool VisitBlockDecl(BlockDecl *block) {
78     SmallVector<VarDecl *, 4> BlockVars;
79 
80     for (const auto &I : block->captures()) {
81       VarDecl *var = I.getVariable();
82       if (I.isByRef() &&
83           var->getType()->isObjCObjectPointerType() &&
84           isImplicitStrong(var->getType())) {
85         BlockVars.push_back(var);
86       }
87     }
88 
89     for (unsigned i = 0, e = BlockVars.size(); i != e; ++i) {
90       VarDecl *var = BlockVars[i];
91 
92       BlockVarChecker checker(var);
93       bool onlyValueOfVarIsNeeded = checker.TraverseStmt(block->getBody());
94       if (onlyValueOfVarIsNeeded)
95         VarsToChange.insert(var);
96       else
97         VarsToChange.erase(var);
98     }
99 
100     return true;
101   }
102 
103 private:
104   bool isImplicitStrong(QualType ty) {
105     if (isa<AttributedType>(ty.getTypePtr()))
106       return false;
107     return ty.getLocalQualifiers().getObjCLifetime() == Qualifiers::OCL_Strong;
108   }
109 };
110 
111 class BlockObjCVarRewriter : public RecursiveASTVisitor<BlockObjCVarRewriter> {
112   llvm::DenseSet<VarDecl *> &VarsToChange;
113 
114 public:
115   BlockObjCVarRewriter(llvm::DenseSet<VarDecl *> &VarsToChange)
116     : VarsToChange(VarsToChange) { }
117 
118   bool TraverseBlockDecl(BlockDecl *block) {
119     RootBlockObjCVarRewriter(VarsToChange).TraverseDecl(block);
120     return true;
121   }
122 };
123 
124 } // anonymous namespace
125 
126 void BlockObjCVariableTraverser::traverseBody(BodyContext &BodyCtx) {
127   MigrationPass &Pass = BodyCtx.getMigrationContext().Pass;
128   llvm::DenseSet<VarDecl *> VarsToChange;
129 
130   BlockObjCVarRewriter trans(VarsToChange);
131   trans.TraverseStmt(BodyCtx.getTopStmt());
132 
133   for (llvm::DenseSet<VarDecl *>::iterator
134          I = VarsToChange.begin(), E = VarsToChange.end(); I != E; ++I) {
135     VarDecl *var = *I;
136     BlocksAttr *attr = var->getAttr<BlocksAttr>();
137     if(!attr)
138       continue;
139     bool useWeak = canApplyWeak(Pass.Ctx, var->getType());
140     SourceManager &SM = Pass.Ctx.getSourceManager();
141     Transaction Trans(Pass.TA);
142     Pass.TA.replaceText(SM.getExpansionLoc(attr->getLocation()),
143                         "__block",
144                         useWeak ? "__weak" : "__unsafe_unretained");
145   }
146 }
147