1 //===- AMDGPURewriteOutArgumentsPass.cpp - Create struct returns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file This pass attempts to replace out argument usage with a return of a
10 /// struct.
11 ///
12 /// We can support returning a lot of values directly in registers, but
13 /// idiomatic C code frequently uses a pointer argument to return a second value
14 /// rather than returning a struct by value. GPU stack access is also quite
15 /// painful, so we want to avoid that if possible. Passing a stack object
16 /// pointer to a function also requires an additional address expansion code
17 /// sequence to convert the pointer to be relative to the kernel's scratch wave
18 /// offset register since the callee doesn't know what stack frame the incoming
19 /// pointer is relative to.
20 ///
21 /// The goal is to try rewriting code that looks like this:
22 ///
23 ///  int foo(int a, int b, int* out) {
24 ///     *out = bar();
25 ///     return a + b;
26 /// }
27 ///
28 /// into something like this:
29 ///
30 ///  std::pair<int, int> foo(int a, int b) {
31 ///     return std::make_pair(a + b, bar());
32 /// }
33 ///
/// Typically the incoming pointer is a simple alloca for a temporary variable
/// used only to call the API; once it is replaced with a struct return, it is
/// easily SROA'd out when the stub function we create is inlined.
37 ///
38 /// This pass introduces the struct return, but leaves the unused pointer
39 /// arguments and introduces a new stub function calling the struct returning
40 /// body. DeadArgumentElimination should be run after this to clean these up.
41 //
42 //===----------------------------------------------------------------------===//
43 
44 #include "AMDGPU.h"
45 #include "Utils/AMDGPUBaseInfo.h"
46 #include "llvm/ADT/SmallSet.h"
47 #include "llvm/ADT/Statistic.h"
48 #include "llvm/Analysis/MemoryDependenceAnalysis.h"
49 #include "llvm/IR/IRBuilder.h"
50 #include "llvm/IR/Instructions.h"
51 #include "llvm/InitializePasses.h"
52 #include "llvm/Pass.h"
53 #include "llvm/Support/CommandLine.h"
54 #include "llvm/Support/Debug.h"
55 #include "llvm/Support/raw_ostream.h"
56 
57 #define DEBUG_TYPE "amdgpu-rewrite-out-arguments"
58 
59 using namespace llvm;
60 
61 static cl::opt<bool> AnyAddressSpace(
62   "amdgpu-any-address-space-out-arguments",
63   cl::desc("Replace pointer out arguments with "
64            "struct returns for non-private address space"),
65   cl::Hidden,
66   cl::init(false));
67 
68 static cl::opt<unsigned> MaxNumRetRegs(
69   "amdgpu-max-return-arg-num-regs",
70   cl::desc("Approximately limit number of return registers for replacing out arguments"),
71   cl::Hidden,
72   cl::init(16));
73 
74 STATISTIC(NumOutArgumentsReplaced,
75           "Number out arguments moved to struct return values");
76 STATISTIC(NumOutArgumentFunctionsReplaced,
77           "Number of functions with out arguments moved to struct return values");
78 
79 namespace {
80 
81 class AMDGPURewriteOutArguments : public FunctionPass {
82 private:
83   const DataLayout *DL = nullptr;
84   MemoryDependenceResults *MDA = nullptr;
85 
86   bool checkArgumentUses(Value &Arg) const;
87   bool isOutArgumentCandidate(Argument &Arg) const;
88 
89 #ifndef NDEBUG
90   bool isVec3ToVec4Shuffle(Type *Ty0, Type* Ty1) const;
91 #endif
92 
93 public:
94   static char ID;
95 
96   AMDGPURewriteOutArguments() : FunctionPass(ID) {}
97 
98   void getAnalysisUsage(AnalysisUsage &AU) const override {
99     AU.addRequired<MemoryDependenceWrapperPass>();
100     FunctionPass::getAnalysisUsage(AU);
101   }
102 
103   bool doInitialization(Module &M) override;
104   bool runOnFunction(Function &F) override;
105 };
106 
107 } // end anonymous namespace
108 
109 INITIALIZE_PASS_BEGIN(AMDGPURewriteOutArguments, DEBUG_TYPE,
110                       "AMDGPU Rewrite Out Arguments", false, false)
111 INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)
112 INITIALIZE_PASS_END(AMDGPURewriteOutArguments, DEBUG_TYPE,
113                     "AMDGPU Rewrite Out Arguments", false, false)
114 
115 char AMDGPURewriteOutArguments::ID = 0;
116 
117 bool AMDGPURewriteOutArguments::checkArgumentUses(Value &Arg) const {
118   const int MaxUses = 10;
119   int UseCount = 0;
120 
121   for (Use &U : Arg.uses()) {
122     StoreInst *SI = dyn_cast<StoreInst>(U.getUser());
123     if (UseCount > MaxUses)
124       return false;
125 
126     if (!SI) {
127       auto *BCI = dyn_cast<BitCastInst>(U.getUser());
128       if (!BCI || !BCI->hasOneUse())
129         return false;
130 
131       // We don't handle multiple stores currently, so stores to aggregate
132       // pointers aren't worth the trouble since they are canonically split up.
133       Type *DestEltTy = BCI->getType()->getPointerElementType();
134       if (DestEltTy->isAggregateType())
135         return false;
136 
137       // We could handle these if we had a convenient way to bitcast between
138       // them.
139       Type *SrcEltTy = Arg.getType()->getPointerElementType();
140       if (SrcEltTy->isArrayTy())
141         return false;
142 
143       // Special case handle structs with single members. It is useful to handle
144       // some casts between structs and non-structs, but we can't bitcast
145       // directly between them.  directly bitcast between them.  Blender uses
146       // some casts that look like { <3 x float> }* to <4 x float>*
147       if ((SrcEltTy->isStructTy() && (SrcEltTy->getStructNumElements() != 1)))
148         return false;
149 
150       // Clang emits OpenCL 3-vector type accesses with a bitcast to the
151       // equivalent 4-element vector and accesses that, and we're looking for
152       // this pointer cast.
153       if (DL->getTypeAllocSize(SrcEltTy) != DL->getTypeAllocSize(DestEltTy))
154         return false;
155 
156       return checkArgumentUses(*BCI);
157     }
158 
159     if (!SI->isSimple() ||
160         U.getOperandNo() != StoreInst::getPointerOperandIndex())
161       return false;
162 
163     ++UseCount;
164   }
165 
166   // Skip unused arguments.
167   return UseCount > 0;
168 }
169 
170 bool AMDGPURewriteOutArguments::isOutArgumentCandidate(Argument &Arg) const {
171   const unsigned MaxOutArgSizeBytes = 4 * MaxNumRetRegs;
172   PointerType *ArgTy = dyn_cast<PointerType>(Arg.getType());
173 
174   // TODO: It might be useful for any out arguments, not just privates.
175   if (!ArgTy || (ArgTy->getAddressSpace() != DL->getAllocaAddrSpace() &&
176                  !AnyAddressSpace) ||
177       Arg.hasByValAttr() || Arg.hasStructRetAttr() ||
178       DL->getTypeStoreSize(ArgTy->getPointerElementType()) > MaxOutArgSizeBytes) {
179     return false;
180   }
181 
182   return checkArgumentUses(Arg);
183 }
184 
185 bool AMDGPURewriteOutArguments::doInitialization(Module &M) {
186   DL = &M.getDataLayout();
187   return false;
188 }
189 
190 #ifndef NDEBUG
191 bool AMDGPURewriteOutArguments::isVec3ToVec4Shuffle(Type *Ty0, Type* Ty1) const {
192   auto *VT0 = dyn_cast<FixedVectorType>(Ty0);
193   auto *VT1 = dyn_cast<FixedVectorType>(Ty1);
194   if (!VT0 || !VT1)
195     return false;
196 
197   if (VT0->getNumElements() != 3 ||
198       VT1->getNumElements() != 4)
199     return false;
200 
201   return DL->getTypeSizeInBits(VT0->getElementType()) ==
202          DL->getTypeSizeInBits(VT1->getElementType());
203 }
204 #endif
205 
206 bool AMDGPURewriteOutArguments::runOnFunction(Function &F) {
207   if (skipFunction(F))
208     return false;
209 
210   // TODO: Could probably handle variadic functions.
211   if (F.isVarArg() || F.hasStructRetAttr() ||
212       AMDGPU::isEntryFunctionCC(F.getCallingConv()))
213     return false;
214 
215   MDA = &getAnalysis<MemoryDependenceWrapperPass>().getMemDep();
216 
217   unsigned ReturnNumRegs = 0;
218   SmallSet<int, 4> OutArgIndexes;
219   SmallVector<Type *, 4> ReturnTypes;
220   Type *RetTy = F.getReturnType();
221   if (!RetTy->isVoidTy()) {
222     ReturnNumRegs = DL->getTypeStoreSize(RetTy) / 4;
223 
224     if (ReturnNumRegs >= MaxNumRetRegs)
225       return false;
226 
227     ReturnTypes.push_back(RetTy);
228   }
229 
230   SmallVector<Argument *, 4> OutArgs;
231   for (Argument &Arg : F.args()) {
232     if (isOutArgumentCandidate(Arg)) {
233       LLVM_DEBUG(dbgs() << "Found possible out argument " << Arg
234                         << " in function " << F.getName() << '\n');
235       OutArgs.push_back(&Arg);
236     }
237   }
238 
239   if (OutArgs.empty())
240     return false;
241 
242   using ReplacementVec = SmallVector<std::pair<Argument *, Value *>, 4>;
243 
244   DenseMap<ReturnInst *, ReplacementVec> Replacements;
245 
246   SmallVector<ReturnInst *, 4> Returns;
247   for (BasicBlock &BB : F) {
248     if (ReturnInst *RI = dyn_cast<ReturnInst>(&BB.back()))
249       Returns.push_back(RI);
250   }
251 
252   if (Returns.empty())
253     return false;
254 
255   bool Changing;
256 
257   do {
258     Changing = false;
259 
260     // Keep retrying if we are able to successfully eliminate an argument. This
261     // helps with cases with multiple arguments which may alias, such as in a
262     // sincos implemntation. If we have 2 stores to arguments, on the first
263     // attempt the MDA query will succeed for the second store but not the
264     // first. On the second iteration we've removed that out clobbering argument
265     // (by effectively moving it into another function) and will find the second
266     // argument is OK to move.
267     for (Argument *OutArg : OutArgs) {
268       bool ThisReplaceable = true;
269       SmallVector<std::pair<ReturnInst *, StoreInst *>, 4> ReplaceableStores;
270 
271       Type *ArgTy = OutArg->getType()->getPointerElementType();
272 
273       // Skip this argument if converting it will push us over the register
274       // count to return limit.
275 
276       // TODO: This is an approximation. When legalized this could be more. We
277       // can ask TLI for exactly how many.
278       unsigned ArgNumRegs = DL->getTypeStoreSize(ArgTy) / 4;
279       if (ArgNumRegs + ReturnNumRegs > MaxNumRetRegs)
280         continue;
281 
282       // An argument is convertible only if all exit blocks are able to replace
283       // it.
284       for (ReturnInst *RI : Returns) {
285         BasicBlock *BB = RI->getParent();
286 
287         MemDepResult Q = MDA->getPointerDependencyFrom(
288             MemoryLocation::getBeforeOrAfter(OutArg), true, BB->end(), BB, RI);
289         StoreInst *SI = nullptr;
290         if (Q.isDef())
291           SI = dyn_cast<StoreInst>(Q.getInst());
292 
293         if (SI) {
294           LLVM_DEBUG(dbgs() << "Found out argument store: " << *SI << '\n');
295           ReplaceableStores.emplace_back(RI, SI);
296         } else {
297           ThisReplaceable = false;
298           break;
299         }
300       }
301 
302       if (!ThisReplaceable)
303         continue; // Try the next argument candidate.
304 
305       for (std::pair<ReturnInst *, StoreInst *> Store : ReplaceableStores) {
306         Value *ReplVal = Store.second->getValueOperand();
307 
308         auto &ValVec = Replacements[Store.first];
309         if (llvm::any_of(ValVec,
310                          [OutArg](const std::pair<Argument *, Value *> &Entry) {
311                            return Entry.first == OutArg;
312                          })) {
313           LLVM_DEBUG(dbgs()
314                      << "Saw multiple out arg stores" << *OutArg << '\n');
315           // It is possible to see stores to the same argument multiple times,
316           // but we expect these would have been optimized out already.
317           ThisReplaceable = false;
318           break;
319         }
320 
321         ValVec.emplace_back(OutArg, ReplVal);
322         Store.second->eraseFromParent();
323       }
324 
325       if (ThisReplaceable) {
326         ReturnTypes.push_back(ArgTy);
327         OutArgIndexes.insert(OutArg->getArgNo());
328         ++NumOutArgumentsReplaced;
329         Changing = true;
330       }
331     }
332   } while (Changing);
333 
334   if (Replacements.empty())
335     return false;
336 
337   LLVMContext &Ctx = F.getParent()->getContext();
338   StructType *NewRetTy = StructType::create(Ctx, ReturnTypes, F.getName());
339 
340   FunctionType *NewFuncTy = FunctionType::get(NewRetTy,
341                                               F.getFunctionType()->params(),
342                                               F.isVarArg());
343 
344   LLVM_DEBUG(dbgs() << "Computed new return type: " << *NewRetTy << '\n');
345 
346   Function *NewFunc = Function::Create(NewFuncTy, Function::PrivateLinkage,
347                                        F.getName() + ".body");
348   F.getParent()->getFunctionList().insert(F.getIterator(), NewFunc);
349   NewFunc->copyAttributesFrom(&F);
350   NewFunc->setComdat(F.getComdat());
351 
352   // We want to preserve the function and param attributes, but need to strip
353   // off any return attributes, e.g. zeroext doesn't make sense with a struct.
354   NewFunc->stealArgumentListFrom(F);
355 
356   AttrBuilder RetAttrs;
357   RetAttrs.addAttribute(Attribute::SExt);
358   RetAttrs.addAttribute(Attribute::ZExt);
359   RetAttrs.addAttribute(Attribute::NoAlias);
360   NewFunc->removeAttributes(AttributeList::ReturnIndex, RetAttrs);
361   // TODO: How to preserve metadata?
362 
363   // Move the body of the function into the new rewritten function, and replace
364   // this function with a stub.
365   NewFunc->getBasicBlockList().splice(NewFunc->begin(), F.getBasicBlockList());
366 
367   for (std::pair<ReturnInst *, ReplacementVec> &Replacement : Replacements) {
368     ReturnInst *RI = Replacement.first;
369     IRBuilder<> B(RI);
370     B.SetCurrentDebugLocation(RI->getDebugLoc());
371 
372     int RetIdx = 0;
373     Value *NewRetVal = UndefValue::get(NewRetTy);
374 
375     Value *RetVal = RI->getReturnValue();
376     if (RetVal)
377       NewRetVal = B.CreateInsertValue(NewRetVal, RetVal, RetIdx++);
378 
379     for (std::pair<Argument *, Value *> ReturnPoint : Replacement.second) {
380       Argument *Arg = ReturnPoint.first;
381       Value *Val = ReturnPoint.second;
382       Type *EltTy = Arg->getType()->getPointerElementType();
383       if (Val->getType() != EltTy) {
384         Type *EffectiveEltTy = EltTy;
385         if (StructType *CT = dyn_cast<StructType>(EltTy)) {
386           assert(CT->getNumElements() == 1);
387           EffectiveEltTy = CT->getElementType(0);
388         }
389 
390         if (DL->getTypeSizeInBits(EffectiveEltTy) !=
391             DL->getTypeSizeInBits(Val->getType())) {
392           assert(isVec3ToVec4Shuffle(EffectiveEltTy, Val->getType()));
393           Val = B.CreateShuffleVector(Val, ArrayRef<int>{0, 1, 2});
394         }
395 
396         Val = B.CreateBitCast(Val, EffectiveEltTy);
397 
398         // Re-create single element composite.
399         if (EltTy != EffectiveEltTy)
400           Val = B.CreateInsertValue(UndefValue::get(EltTy), Val, 0);
401       }
402 
403       NewRetVal = B.CreateInsertValue(NewRetVal, Val, RetIdx++);
404     }
405 
406     if (RetVal)
407       RI->setOperand(0, NewRetVal);
408     else {
409       B.CreateRet(NewRetVal);
410       RI->eraseFromParent();
411     }
412   }
413 
414   SmallVector<Value *, 16> StubCallArgs;
415   for (Argument &Arg : F.args()) {
416     if (OutArgIndexes.count(Arg.getArgNo())) {
417       // It's easier to preserve the type of the argument list. We rely on
418       // DeadArgumentElimination to take care of these.
419       StubCallArgs.push_back(UndefValue::get(Arg.getType()));
420     } else {
421       StubCallArgs.push_back(&Arg);
422     }
423   }
424 
425   BasicBlock *StubBB = BasicBlock::Create(Ctx, "", &F);
426   IRBuilder<> B(StubBB);
427   CallInst *StubCall = B.CreateCall(NewFunc, StubCallArgs);
428 
429   int RetIdx = RetTy->isVoidTy() ? 0 : 1;
430   for (Argument &Arg : F.args()) {
431     if (!OutArgIndexes.count(Arg.getArgNo()))
432       continue;
433 
434     PointerType *ArgType = cast<PointerType>(Arg.getType());
435 
436     auto *EltTy = ArgType->getElementType();
437     const auto Align =
438         DL->getValueOrABITypeAlignment(Arg.getParamAlign(), EltTy);
439 
440     Value *Val = B.CreateExtractValue(StubCall, RetIdx++);
441     Type *PtrTy = Val->getType()->getPointerTo(ArgType->getAddressSpace());
442 
443     // We can peek through bitcasts, so the type may not match.
444     Value *PtrVal = B.CreateBitCast(&Arg, PtrTy);
445 
446     B.CreateAlignedStore(Val, PtrVal, Align);
447   }
448 
449   if (!RetTy->isVoidTy()) {
450     B.CreateRet(B.CreateExtractValue(StubCall, 0));
451   } else {
452     B.CreateRetVoid();
453   }
454 
455   // The function is now a stub we want to inline.
456   F.addFnAttr(Attribute::AlwaysInline);
457 
458   ++NumOutArgumentFunctionsReplaced;
459   return true;
460 }
461 
462 FunctionPass *llvm::createAMDGPURewriteOutArgumentsPass() {
463   return new AMDGPURewriteOutArguments();
464 }
465