//===-- AMDGPUPromoteKernelArguments.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass recursively promotes generic pointer arguments of a kernel
/// into the global address space.
///
/// The pass walks the kernel's pointer arguments and the loads from them. If a
/// loaded value is itself a pointer and that pointer is not modified in the
/// kernel before the load, the loaded pointer is promoted to global. The
/// process then continues recursively with the promoted pointer.
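///
/// For example (an illustrative sketch; the value names are hypothetical),
/// given a kernel argument %arg holding a flat pointer that the kernel never
/// overwrites:
///
///   %p = load float*, float* addrspace(1)* %arg
///
/// the pass inserts a round-trip cast after the load
///
///   %p.global = addrspacecast float* %p to float addrspace(1)*
///   %p.flat = addrspacecast float addrspace(1)* %p.global to float*
///
/// and rewrites the uses of %p to %p.flat, leaving the InferAddressSpaces
/// pass to fold the casts into the users.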
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/InitializePasses.h"

#define DEBUG_TYPE "amdgpu-promote-kernel-arguments"

using namespace llvm;

namespace {

class AMDGPUPromoteKernelArguments : public FunctionPass {
  MemorySSA *MSSA;

  Instruction *ArgCastInsertPt;

  SmallVector<Value *> Ptrs;

  void enqueueUsers(Value *Ptr);

  bool promotePointer(Value *Ptr);

public:
  static char ID;

  AMDGPUPromoteKernelArguments() : FunctionPass(ID) {}

  bool run(Function &F, MemorySSA &MSSA);

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<MemorySSAWrapperPass>();
    AU.setPreservesAll();
  }
};

} // end anonymous namespace

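// Walk the users of \p Ptr through GEPs, addrspacecasts and bitcasts, and
// queue for promotion each loaded flat, global, or constant pointer that is
// not clobbered between function entry and the load.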
void AMDGPUPromoteKernelArguments::enqueueUsers(Value *Ptr) {
  SmallVector<User *> PtrUsers(Ptr->users());

  while (!PtrUsers.empty()) {
    Instruction *U = dyn_cast<Instruction>(PtrUsers.pop_back_val());
    if (!U)
      continue;

    switch (U->getOpcode()) {
    default:
      break;
    case Instruction::Load: {
      LoadInst *LD = cast<LoadInst>(U);
      PointerType *PT = dyn_cast<PointerType>(LD->getType());
      if (!PT ||
          (PT->getAddressSpace() != AMDGPUAS::FLAT_ADDRESS &&
           PT->getAddressSpace() != AMDGPUAS::GLOBAL_ADDRESS &&
           PT->getAddressSpace() != AMDGPUAS::CONSTANT_ADDRESS) ||
          LD->getPointerOperand()->stripInBoundsOffsets() != Ptr)
        break;
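      // If the load's defining access is the live-on-entry def, no store in
      // the kernel may modify the loaded pointer before this load, so it is
      // safe to promote.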
      const MemoryAccess *MA = MSSA->getWalker()->getClobberingMemoryAccess(LD);
      // TODO: This load probably can be promoted to constant address space.
      if (MSSA->isLiveOnEntryDef(MA))
        Ptrs.push_back(LD);
      break;
    }
    case Instruction::GetElementPtr:
    case Instruction::AddrSpaceCast:
    case Instruction::BitCast:
      if (U->getOperand(0)->stripInBoundsOffsets() == Ptr)
        PtrUsers.append(U->user_begin(), U->user_end());
      break;
    }
  }
}

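// Enqueue the candidate pointers loaded through \p Ptr, then, if \p Ptr itself
// is a flat pointer, promote it to the global address space. Returns true if
// \p Ptr was promoted.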
bool AMDGPUPromoteKernelArguments::promotePointer(Value *Ptr) {
  enqueueUsers(Ptr);

  PointerType *PT = cast<PointerType>(Ptr->getType());
  if (PT->getAddressSpace() != AMDGPUAS::FLAT_ADDRESS)
    return false;

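  // Casts for an argument are inserted at the precomputed point in the entry
  // block; casts for a loaded pointer are inserted right after the load.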
  bool IsArg = isa<Argument>(Ptr);
  IRBuilder<> B(IsArg ? ArgCastInsertPt
                      : &*std::next(cast<Instruction>(Ptr)->getIterator()));

  // Cast the pointer to the global address space and back to flat, and let
  // the InferAddressSpaces pass do all of the necessary rewriting.
  PointerType *NewPT =
      PointerType::getWithSamePointeeType(PT, AMDGPUAS::GLOBAL_ADDRESS);
  Value *Cast =
      B.CreateAddrSpaceCast(Ptr, NewPT, Twine(Ptr->getName(), ".global"));
  Value *CastBack =
      B.CreateAddrSpaceCast(Cast, PT, Twine(Ptr->getName(), ".flat"));
  Ptr->replaceUsesWithIf(CastBack,
                         [Cast](Use &U) { return U.getUser() != Cast; });

  return true;
}

// Skip the static allocas at the start of the entry block when choosing the
// insertion point for the argument casts.
static BasicBlock::iterator getInsertPt(BasicBlock &BB) {
  BasicBlock::iterator InsPt = BB.getFirstInsertionPt();
  for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) {
    AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt);

    // If this is a dynamic alloca, its operands may depend on the loaded
    // kernargs, so the casts need to be inserted before it.
    if (!AI || !AI->isStaticAlloca())
      break;
  }

  return InsPt;
}

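// Seed the worklist with the kernel's qualifying pointer arguments and promote
// pointers until the worklist is empty.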
bool AMDGPUPromoteKernelArguments::run(Function &F, MemorySSA &MSSA) {
  if (skipFunction(F))
    return false;

  CallingConv::ID CC = F.getCallingConv();
  if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
    return false;

  ArgCastInsertPt = &*getInsertPt(*F.begin());
  this->MSSA = &MSSA;

  for (Argument &Arg : F.args()) {
    if (Arg.use_empty())
      continue;

    PointerType *PT = dyn_cast<PointerType>(Arg.getType());
    if (!PT || (PT->getAddressSpace() != AMDGPUAS::FLAT_ADDRESS &&
                PT->getAddressSpace() != AMDGPUAS::GLOBAL_ADDRESS &&
                PT->getAddressSpace() != AMDGPUAS::CONSTANT_ADDRESS))
      continue;

    Ptrs.push_back(&Arg);
  }

  bool Changed = false;
  while (!Ptrs.empty()) {
    Value *Ptr = Ptrs.pop_back_val();
    Changed |= promotePointer(Ptr);
  }

  return Changed;
}

bool AMDGPUPromoteKernelArguments::runOnFunction(Function &F) {
  MemorySSA &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA();
  return run(F, MSSA);
}

INITIALIZE_PASS_BEGIN(AMDGPUPromoteKernelArguments, DEBUG_TYPE,
                      "AMDGPU Promote Kernel Arguments", false, false)
INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
INITIALIZE_PASS_END(AMDGPUPromoteKernelArguments, DEBUG_TYPE,
                    "AMDGPU Promote Kernel Arguments", false, false)

char AMDGPUPromoteKernelArguments::ID = 0;

FunctionPass *llvm::createAMDGPUPromoteKernelArgumentsPass() {
  return new AMDGPUPromoteKernelArguments();
}

PreservedAnalyses
AMDGPUPromoteKernelArgumentsPass::run(Function &F,
                                      FunctionAnalysisManager &AM) {
  MemorySSA &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
  if (AMDGPUPromoteKernelArguments().run(F, MSSA)) {
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    PA.preserve<MemorySSAAnalysis>();
    return PA;
  }
  return PreservedAnalyses::all();
}