1 //===---- ManagedMemoryRewrite.cpp - Rewrite global & malloc'd memory -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Take a module and rewrite:
10 // 1. `malloc` -> `polly_mallocManaged`
11 // 2. `free` -> `polly_freeManaged`
12 // 3. global arrays with initializers -> global arrays that are initialized
13 // with a constructor call to
14 // `polly_mallocManaged`.
15 //
16 //===----------------------------------------------------------------------===//
17
18 #include "polly/CodeGen/IRBuilder.h"
19 #include "polly/CodeGen/PPCGCodeGeneration.h"
20 #include "polly/DependenceInfo.h"
21 #include "polly/LinkAllPasses.h"
22 #include "polly/Options.h"
23 #include "polly/ScopDetection.h"
24 #include "llvm/ADT/SmallSet.h"
25 #include "llvm/Analysis/CaptureTracking.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Transforms/Utils/ModuleUtils.h"
28
29 using namespace llvm;
30 using namespace polly;
31
32 static cl::opt<bool> RewriteAllocas(
33 "polly-acc-rewrite-allocas",
34 cl::desc(
35 "Ask the managed memory rewriter to also rewrite alloca instructions"),
36 cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
37
38 static cl::opt<bool> IgnoreLinkageForGlobals(
39 "polly-acc-rewrite-ignore-linkage-for-globals",
40 cl::desc(
41 "By default, we only rewrite globals with internal linkage. This flag "
42 "enables rewriting of globals regardless of linkage"),
43 cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
44
45 #define DEBUG_TYPE "polly-acc-rewrite-managed-memory"
46 namespace {
47
getOrCreatePollyMallocManaged(Module & M)48 static llvm::Function *getOrCreatePollyMallocManaged(Module &M) {
49 const char *Name = "polly_mallocManaged";
50 Function *F = M.getFunction(Name);
51
52 // If F is not available, declare it.
53 if (!F) {
54 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
55 PollyIRBuilder Builder(M.getContext());
56 // TODO: How do I get `size_t`? I assume from DataLayout?
57 FunctionType *Ty = FunctionType::get(Builder.getInt8PtrTy(),
58 {Builder.getInt64Ty()}, false);
59 F = Function::Create(Ty, Linkage, Name, &M);
60 }
61
62 return F;
63 }
64
getOrCreatePollyFreeManaged(Module & M)65 static llvm::Function *getOrCreatePollyFreeManaged(Module &M) {
66 const char *Name = "polly_freeManaged";
67 Function *F = M.getFunction(Name);
68
69 // If F is not available, declare it.
70 if (!F) {
71 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
72 PollyIRBuilder Builder(M.getContext());
73 // TODO: How do I get `size_t`? I assume from DataLayout?
74 FunctionType *Ty =
75 FunctionType::get(Builder.getVoidTy(), {Builder.getInt8PtrTy()}, false);
76 F = Function::Create(Ty, Linkage, Name, &M);
77 }
78
79 return F;
80 }
81
82 // Expand a constant expression `Cur`, which is used at instruction `Parent`
83 // at index `index`.
84 // Since a constant expression can expand to multiple instructions, store all
85 // the expands into a set called `Expands`.
86 // Note that this goes inorder on the constant expression tree.
87 // A * ((B * D) + C)
88 // will be processed with first A, then B * D, then B, then D, and then C.
89 // Though ConstantExprs are not treated as "trees" but as DAGs, since you can
90 // have something like this:
91 // *
92 // / \
93 // \ /
94 // (D)
95 //
96 // For the purposes of this expansion, we expand the two occurences of D
97 // separately. Therefore, we expand the DAG into the tree:
98 // *
99 // / \
100 // D D
101 // TODO: We don't _have_to do this, but this is the simplest solution.
102 // We can write a solution that keeps track of which constants have been
103 // already expanded.
expandConstantExpr(ConstantExpr * Cur,PollyIRBuilder & Builder,Instruction * Parent,int index,SmallPtrSet<Instruction *,4> & Expands)104 static void expandConstantExpr(ConstantExpr *Cur, PollyIRBuilder &Builder,
105 Instruction *Parent, int index,
106 SmallPtrSet<Instruction *, 4> &Expands) {
107 assert(Cur && "invalid constant expression passed");
108 Instruction *I = Cur->getAsInstruction();
109 assert(I && "unable to convert ConstantExpr to Instruction");
110
111 LLVM_DEBUG(dbgs() << "Expanding ConstantExpression: (" << *Cur
112 << ") in Instruction: (" << *I << ")\n";);
113
114 // Invalidate `Cur` so that no one after this point uses `Cur`. Rather,
115 // they should mutate `I`.
116 Cur = nullptr;
117
118 Expands.insert(I);
119 Parent->setOperand(index, I);
120
121 // The things that `Parent` uses (its operands) should be created
122 // before `Parent`.
123 Builder.SetInsertPoint(Parent);
124 Builder.Insert(I);
125
126 for (unsigned i = 0; i < I->getNumOperands(); i++) {
127 Value *Op = I->getOperand(i);
128 assert(isa<Constant>(Op) && "constant must have a constant operand");
129
130 if (ConstantExpr *CExprOp = dyn_cast<ConstantExpr>(Op))
131 expandConstantExpr(CExprOp, Builder, I, i, Expands);
132 }
133 }
134
135 // Edit all uses of `OldVal` to NewVal` in `Inst`. This will rewrite
136 // `ConstantExpr`s that are used in the `Inst`.
137 // Note that `replaceAllUsesWith` is insufficient for this purpose because it
138 // does not rewrite values in `ConstantExpr`s.
rewriteOldValToNew(Instruction * Inst,Value * OldVal,Value * NewVal,PollyIRBuilder & Builder)139 static void rewriteOldValToNew(Instruction *Inst, Value *OldVal, Value *NewVal,
140 PollyIRBuilder &Builder) {
141
142 // This contains a set of instructions in which OldVal must be replaced.
143 // We start with `Inst`, and we fill it up with the expanded `ConstantExpr`s
144 // from `Inst`s arguments.
145 // We need to go through this process because `replaceAllUsesWith` does not
146 // actually edit `ConstantExpr`s.
147 SmallPtrSet<Instruction *, 4> InstsToVisit = {Inst};
148
149 // Expand all `ConstantExpr`s and place it in `InstsToVisit`.
150 for (unsigned i = 0; i < Inst->getNumOperands(); i++) {
151 Value *Operand = Inst->getOperand(i);
152 if (ConstantExpr *ValueConstExpr = dyn_cast<ConstantExpr>(Operand))
153 expandConstantExpr(ValueConstExpr, Builder, Inst, i, InstsToVisit);
154 }
155
156 // Now visit each instruction and use `replaceUsesOfWith`. We know that
157 // will work because `I` cannot have any `ConstantExpr` within it.
158 for (Instruction *I : InstsToVisit)
159 I->replaceUsesOfWith(OldVal, NewVal);
160 }
161
162 // Given a value `Current`, return all Instructions that may contain `Current`
163 // in an expression.
164 // We need this auxiliary function, because if we have a
165 // `Constant` that is a user of `V`, we need to recurse into the
166 // `Constant`s uses to gather the root instruciton.
getInstructionUsersOfValue(Value * V,SmallVector<Instruction *,4> & Owners)167 static void getInstructionUsersOfValue(Value *V,
168 SmallVector<Instruction *, 4> &Owners) {
169 if (auto *I = dyn_cast<Instruction>(V)) {
170 Owners.push_back(I);
171 } else {
172 // Anything that is a `User` must be a constant or an instruction.
173 auto *C = cast<Constant>(V);
174 for (Use &CUse : C->uses())
175 getInstructionUsersOfValue(CUse.getUser(), Owners);
176 }
177 }
178
179 static void
replaceGlobalArray(Module & M,const DataLayout & DL,GlobalVariable & Array,SmallPtrSet<GlobalVariable *,4> & ReplacedGlobals)180 replaceGlobalArray(Module &M, const DataLayout &DL, GlobalVariable &Array,
181 SmallPtrSet<GlobalVariable *, 4> &ReplacedGlobals) {
182 // We only want arrays.
183 ArrayType *ArrayTy = dyn_cast<ArrayType>(Array.getType()->getElementType());
184 if (!ArrayTy)
185 return;
186 Type *ElemTy = ArrayTy->getElementType();
187 PointerType *ElemPtrTy = ElemTy->getPointerTo();
188
189 // We only wish to replace arrays that are visible in the module they
190 // inhabit. Otherwise, our type edit from [T] to T* would be illegal across
191 // modules.
192 const bool OnlyVisibleInsideModule = Array.hasPrivateLinkage() ||
193 Array.hasInternalLinkage() ||
194 IgnoreLinkageForGlobals;
195 if (!OnlyVisibleInsideModule) {
196 LLVM_DEBUG(
197 dbgs() << "Not rewriting (" << Array
198 << ") to managed memory "
199 "because it could be visible externally. To force rewrite, "
200 "use -polly-acc-rewrite-ignore-linkage-for-globals.\n");
201 return;
202 }
203
204 if (!Array.hasInitializer() ||
205 !isa<ConstantAggregateZero>(Array.getInitializer())) {
206 LLVM_DEBUG(dbgs() << "Not rewriting (" << Array
207 << ") to managed memory "
208 "because it has an initializer which is "
209 "not a zeroinitializer.\n");
210 return;
211 }
212
213 // At this point, we have committed to replacing this array.
214 ReplacedGlobals.insert(&Array);
215
216 std::string NewName = Array.getName().str();
217 NewName += ".toptr";
218 GlobalVariable *ReplacementToArr =
219 cast<GlobalVariable>(M.getOrInsertGlobal(NewName, ElemPtrTy));
220 ReplacementToArr->setInitializer(ConstantPointerNull::get(ElemPtrTy));
221
222 Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
223 std::string FnName = Array.getName().str();
224 FnName += ".constructor";
225 PollyIRBuilder Builder(M.getContext());
226 FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false);
227 const GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
228 Function *F = Function::Create(Ty, Linkage, FnName, &M);
229 BasicBlock *Start = BasicBlock::Create(M.getContext(), "entry", F);
230 Builder.SetInsertPoint(Start);
231
232 const uint64_t ArraySizeInt = DL.getTypeAllocSize(ArrayTy);
233 Value *ArraySize = Builder.getInt64(ArraySizeInt);
234 ArraySize->setName("array.size");
235
236 Value *AllocatedMemRaw =
237 Builder.CreateCall(PollyMallocManaged, {ArraySize}, "mem.raw");
238 Value *AllocatedMemTyped =
239 Builder.CreatePointerCast(AllocatedMemRaw, ElemPtrTy, "mem.typed");
240 Builder.CreateStore(AllocatedMemTyped, ReplacementToArr);
241 Builder.CreateRetVoid();
242
243 const int Priority = 0;
244 appendToGlobalCtors(M, F, Priority, ReplacementToArr);
245
246 SmallVector<Instruction *, 4> ArrayUserInstructions;
247 // Get all instructions that use array. We need to do this weird thing
248 // because `Constant`s that contain this array neeed to be expanded into
249 // instructions so that we can replace their parameters. `Constant`s cannot
250 // be edited easily, so we choose to convert all `Constant`s to
251 // `Instruction`s and handle all of the uses of `Array` uniformly.
252 for (Use &ArrayUse : Array.uses())
253 getInstructionUsersOfValue(ArrayUse.getUser(), ArrayUserInstructions);
254
255 for (Instruction *UserOfArrayInst : ArrayUserInstructions) {
256
257 Builder.SetInsertPoint(UserOfArrayInst);
258 // <ty>** -> <ty>*
259 Value *ArrPtrLoaded =
260 Builder.CreateLoad(ElemPtrTy, ReplacementToArr, "arrptr.load");
261 // <ty>* -> [ty]*
262 Value *ArrPtrLoadedBitcasted = Builder.CreateBitCast(
263 ArrPtrLoaded, ArrayTy->getPointerTo(), "arrptr.bitcast");
264 rewriteOldValToNew(UserOfArrayInst, &Array, ArrPtrLoadedBitcasted, Builder);
265 }
266 }
267
268 // We return all `allocas` that may need to be converted to a call to
269 // cudaMallocManaged.
getAllocasToBeManaged(Function & F,SmallSet<AllocaInst *,4> & Allocas)270 static void getAllocasToBeManaged(Function &F,
271 SmallSet<AllocaInst *, 4> &Allocas) {
272 for (BasicBlock &BB : F) {
273 for (Instruction &I : BB) {
274 auto *Alloca = dyn_cast<AllocaInst>(&I);
275 if (!Alloca)
276 continue;
277 LLVM_DEBUG(dbgs() << "Checking if (" << *Alloca << ") may be captured: ");
278
279 if (PointerMayBeCaptured(Alloca, /* ReturnCaptures */ false,
280 /* StoreCaptures */ true)) {
281 Allocas.insert(Alloca);
282 LLVM_DEBUG(dbgs() << "YES (captured).\n");
283 } else {
284 LLVM_DEBUG(dbgs() << "NO (not captured).\n");
285 }
286 }
287 }
288 }
289
rewriteAllocaAsManagedMemory(AllocaInst * Alloca,const DataLayout & DL)290 static void rewriteAllocaAsManagedMemory(AllocaInst *Alloca,
291 const DataLayout &DL) {
292 LLVM_DEBUG(dbgs() << "rewriting: (" << *Alloca << ") to managed mem.\n");
293 Module *M = Alloca->getModule();
294 assert(M && "Alloca does not have a module");
295
296 PollyIRBuilder Builder(M->getContext());
297 Builder.SetInsertPoint(Alloca);
298
299 Function *MallocManagedFn =
300 getOrCreatePollyMallocManaged(*Alloca->getModule());
301 const uint64_t Size =
302 DL.getTypeAllocSize(Alloca->getType()->getElementType());
303 Value *SizeVal = Builder.getInt64(Size);
304 Value *RawManagedMem = Builder.CreateCall(MallocManagedFn, {SizeVal});
305 Value *Bitcasted = Builder.CreateBitCast(RawManagedMem, Alloca->getType());
306
307 Function *F = Alloca->getFunction();
308 assert(F && "Alloca has invalid function");
309
310 Bitcasted->takeName(Alloca);
311 Alloca->replaceAllUsesWith(Bitcasted);
312 Alloca->eraseFromParent();
313
314 for (BasicBlock &BB : *F) {
315 ReturnInst *Return = dyn_cast<ReturnInst>(BB.getTerminator());
316 if (!Return)
317 continue;
318 Builder.SetInsertPoint(Return);
319
320 Function *FreeManagedFn = getOrCreatePollyFreeManaged(*M);
321 Builder.CreateCall(FreeManagedFn, {RawManagedMem});
322 }
323 }
324
325 // Replace all uses of `Old` with `New`, even inside `ConstantExpr`.
326 //
327 // `replaceAllUsesWith` does replace values in `ConstantExpr`. This function
328 // actually does replace it in `ConstantExpr`. The caveat is that if there is
329 // a use that is *outside* a function (say, at global declarations), we fail.
330 // So, this is meant to be used on values which we know will only be used
331 // within functions.
332 //
333 // This process works by looking through the uses of `Old`. If it finds a
334 // `ConstantExpr`, it recursively looks for the owning instruction.
335 // Then, it expands all the `ConstantExpr` to instructions and replaces
336 // `Old` with `New` in the expanded instructions.
replaceAllUsesAndConstantUses(Value * Old,Value * New,PollyIRBuilder & Builder)337 static void replaceAllUsesAndConstantUses(Value *Old, Value *New,
338 PollyIRBuilder &Builder) {
339 SmallVector<Instruction *, 4> UserInstructions;
340 // Get all instructions that use array. We need to do this weird thing
341 // because `Constant`s that contain this array neeed to be expanded into
342 // instructions so that we can replace their parameters. `Constant`s cannot
343 // be edited easily, so we choose to convert all `Constant`s to
344 // `Instruction`s and handle all of the uses of `Array` uniformly.
345 for (Use &ArrayUse : Old->uses())
346 getInstructionUsersOfValue(ArrayUse.getUser(), UserInstructions);
347
348 for (Instruction *I : UserInstructions)
349 rewriteOldValToNew(I, Old, New, Builder);
350 }
351
352 class ManagedMemoryRewritePass : public ModulePass {
353 public:
354 static char ID;
355 GPUArch Architecture;
356 GPURuntime Runtime;
357
ManagedMemoryRewritePass()358 ManagedMemoryRewritePass() : ModulePass(ID) {}
runOnModule(Module & M)359 bool runOnModule(Module &M) override {
360 const DataLayout &DL = M.getDataLayout();
361
362 Function *Malloc = M.getFunction("malloc");
363
364 if (Malloc) {
365 PollyIRBuilder Builder(M.getContext());
366 Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
367 assert(PollyMallocManaged && "unable to create polly_mallocManaged");
368
369 replaceAllUsesAndConstantUses(Malloc, PollyMallocManaged, Builder);
370 Malloc->eraseFromParent();
371 }
372
373 Function *Free = M.getFunction("free");
374
375 if (Free) {
376 PollyIRBuilder Builder(M.getContext());
377 Function *PollyFreeManaged = getOrCreatePollyFreeManaged(M);
378 assert(PollyFreeManaged && "unable to create polly_freeManaged");
379
380 replaceAllUsesAndConstantUses(Free, PollyFreeManaged, Builder);
381 Free->eraseFromParent();
382 }
383
384 SmallPtrSet<GlobalVariable *, 4> GlobalsToErase;
385 for (GlobalVariable &Global : M.globals())
386 replaceGlobalArray(M, DL, Global, GlobalsToErase);
387 for (GlobalVariable *G : GlobalsToErase)
388 G->eraseFromParent();
389
390 // Rewrite allocas to cudaMallocs if we are asked to do so.
391 if (RewriteAllocas) {
392 SmallSet<AllocaInst *, 4> AllocasToBeManaged;
393 for (Function &F : M.functions())
394 getAllocasToBeManaged(F, AllocasToBeManaged);
395
396 for (AllocaInst *Alloca : AllocasToBeManaged)
397 rewriteAllocaAsManagedMemory(Alloca, DL);
398 }
399
400 return true;
401 }
402 };
403 } // namespace
404 char ManagedMemoryRewritePass::ID = 42;
405
createManagedMemoryRewritePassPass(GPUArch Arch,GPURuntime Runtime)406 Pass *polly::createManagedMemoryRewritePassPass(GPUArch Arch,
407 GPURuntime Runtime) {
408 ManagedMemoryRewritePass *pass = new ManagedMemoryRewritePass();
409 pass->Runtime = Runtime;
410 pass->Architecture = Arch;
411 return pass;
412 }
413
414 INITIALIZE_PASS_BEGIN(
415 ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
416 "Polly - Rewrite all allocations in heap & data section to managed memory",
417 false, false)
418 INITIALIZE_PASS_DEPENDENCY(PPCGCodeGeneration);
419 INITIALIZE_PASS_DEPENDENCY(DependenceInfo);
420 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
421 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass);
422 INITIALIZE_PASS_DEPENDENCY(RegionInfoPass);
423 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass);
424 INITIALIZE_PASS_DEPENDENCY(ScopDetectionWrapperPass);
425 INITIALIZE_PASS_END(
426 ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
427 "Polly - Rewrite all allocations in heap & data section to managed memory",
428 false, false)
429