1 //===-- AMDGPULowerKernelAttributes.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file This pass attempts to make use of reqd_work_group_size metadata
10 /// to eliminate loads from the dispatch packet and to constant fold OpenCL
11 /// get_local_size-like functions.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "AMDGPU.h"
16 #include "Utils/AMDGPUBaseInfo.h"
17 #include "llvm/Analysis/ValueTracking.h"
18 #include "llvm/CodeGen/Passes.h"
19 #include "llvm/CodeGen/TargetPassConfig.h"
20 #include "llvm/IR/Constants.h"
21 #include "llvm/IR/Function.h"
22 #include "llvm/IR/InstIterator.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/IntrinsicsAMDGPU.h"
25 #include "llvm/IR/PatternMatch.h"
26 #include "llvm/Pass.h"
27 
28 #define DEBUG_TYPE "amdgpu-lower-kernel-attributes"
29 
30 using namespace llvm;
31 
32 namespace {
33 
// Byte offsets of the hsa_kernel_dispatch_packet_t fields that this pass
// recognizes when they are loaded through the dispatch pointer.
enum DispatchPackedOffsets {
  // Work-group size per dimension; matched as 2-byte loads.
  WORKGROUP_SIZE_X = 4,
  WORKGROUP_SIZE_Y = 6,
  WORKGROUP_SIZE_Z = 8,

  // Grid size per dimension; matched as 4-byte loads.
  GRID_SIZE_X = 12,
  GRID_SIZE_Y = 16,
  GRID_SIZE_Z = 20,
};
44 
// Byte offsets, relative to the implicit kernel argument pointer, of the
// hidden arguments that this pass recognizes.
enum ImplicitArgOffsets {
  // Number of work-groups per dimension; matched as 4-byte loads.
  HIDDEN_BLOCK_COUNT_X = 0,
  HIDDEN_BLOCK_COUNT_Y = 4,
  HIDDEN_BLOCK_COUNT_Z = 8,

  // Work-group size per dimension; matched as 2-byte loads.
  HIDDEN_GROUP_SIZE_X = 12,
  HIDDEN_GROUP_SIZE_Y = 14,
  HIDDEN_GROUP_SIZE_Z = 16,

  // Size of the trailing partial work-group per dimension; matched as 2-byte
  // loads.
  HIDDEN_REMAINDER_X = 18,
  HIDDEN_REMAINDER_Y = 20,
  HIDDEN_REMAINDER_Z = 22,
};
59 
60 class AMDGPULowerKernelAttributes : public ModulePass {
61 public:
62   static char ID;
63 
64   AMDGPULowerKernelAttributes() : ModulePass(ID) {}
65 
66   bool runOnModule(Module &M) override;
67 
68   StringRef getPassName() const override {
69     return "AMDGPU Kernel Attributes";
70   }
71 
72   void getAnalysisUsage(AnalysisUsage &AU) const override {
73     AU.setPreservesAll();
74  }
75 };
76 
77 Function *getBasePtrIntrinsic(Module &M, bool IsV5OrAbove) {
78   auto IntrinsicId = IsV5OrAbove ? Intrinsic::amdgcn_implicitarg_ptr
79                                  : Intrinsic::amdgcn_dispatch_ptr;
80   StringRef Name = Intrinsic::getName(IntrinsicId);
81   return M.getFunction(Name);
82 }
83 
84 } // end anonymous namespace
85 
86 static bool processUse(CallInst *CI, bool IsV5OrAbove) {
87   Function *F = CI->getParent()->getParent();
88 
89   auto MD = F->getMetadata("reqd_work_group_size");
90   const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;
91 
92   const bool HasUniformWorkGroupSize =
93     F->getFnAttribute("uniform-work-group-size").getValueAsBool();
94 
95   if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize)
96     return false;
97 
98   Value *BlockCounts[3] = {nullptr, nullptr, nullptr};
99   Value *GroupSizes[3]  = {nullptr, nullptr, nullptr};
100   Value *Remainders[3]  = {nullptr, nullptr, nullptr};
101   Value *GridSizes[3]   = {nullptr, nullptr, nullptr};
102 
103   const DataLayout &DL = F->getParent()->getDataLayout();
104 
105   // We expect to see several GEP users, casted to the appropriate type and
106   // loaded.
107   for (User *U : CI->users()) {
108     if (!U->hasOneUse())
109       continue;
110 
111     int64_t Offset = 0;
112     auto *Load = dyn_cast<LoadInst>(U); // Load from ImplicitArgPtr/DispatchPtr?
113     auto *BCI = dyn_cast<BitCastInst>(U);
114     if (!Load && !BCI) {
115       if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI)
116         continue;
117       Load = dyn_cast<LoadInst>(*U->user_begin()); // Load from GEP?
118       BCI = dyn_cast<BitCastInst>(*U->user_begin());
119     }
120 
121     if (BCI) {
122       if (!BCI->hasOneUse())
123         continue;
124       Load = dyn_cast<LoadInst>(*BCI->user_begin()); // Load from BCI?
125     }
126 
127     if (!Load || !Load->isSimple())
128       continue;
129 
130     unsigned LoadSize = DL.getTypeStoreSize(Load->getType());
131 
132     // TODO: Handle merged loads.
133     if (IsV5OrAbove) { // Base is ImplicitArgPtr.
134       switch (Offset) {
135       case HIDDEN_BLOCK_COUNT_X:
136         if (LoadSize == 4)
137           BlockCounts[0] = Load;
138         break;
139       case HIDDEN_BLOCK_COUNT_Y:
140         if (LoadSize == 4)
141           BlockCounts[1] = Load;
142         break;
143       case HIDDEN_BLOCK_COUNT_Z:
144         if (LoadSize == 4)
145           BlockCounts[2] = Load;
146         break;
147       case HIDDEN_GROUP_SIZE_X:
148         if (LoadSize == 2)
149           GroupSizes[0] = Load;
150         break;
151       case HIDDEN_GROUP_SIZE_Y:
152         if (LoadSize == 2)
153           GroupSizes[1] = Load;
154         break;
155       case HIDDEN_GROUP_SIZE_Z:
156         if (LoadSize == 2)
157           GroupSizes[2] = Load;
158         break;
159       case HIDDEN_REMAINDER_X:
160         if (LoadSize == 2)
161           Remainders[0] = Load;
162         break;
163       case HIDDEN_REMAINDER_Y:
164         if (LoadSize == 2)
165           Remainders[1] = Load;
166         break;
167       case HIDDEN_REMAINDER_Z:
168         if (LoadSize == 2)
169           Remainders[2] = Load;
170         break;
171       default:
172         break;
173       }
174     } else { // Base is DispatchPtr.
175       switch (Offset) {
176       case WORKGROUP_SIZE_X:
177         if (LoadSize == 2)
178           GroupSizes[0] = Load;
179         break;
180       case WORKGROUP_SIZE_Y:
181         if (LoadSize == 2)
182           GroupSizes[1] = Load;
183         break;
184       case WORKGROUP_SIZE_Z:
185         if (LoadSize == 2)
186           GroupSizes[2] = Load;
187         break;
188       case GRID_SIZE_X:
189         if (LoadSize == 4)
190           GridSizes[0] = Load;
191         break;
192       case GRID_SIZE_Y:
193         if (LoadSize == 4)
194           GridSizes[1] = Load;
195         break;
196       case GRID_SIZE_Z:
197         if (LoadSize == 4)
198           GridSizes[2] = Load;
199         break;
200       default:
201         break;
202       }
203     }
204   }
205 
206   bool MadeChange = false;
207   if (IsV5OrAbove && HasUniformWorkGroupSize) {
208     // Under v5  __ockl_get_local_size returns the value computed by the expression:
209     //
210     //   workgroup_id < hidden_block_count ? hidden_group_size : hidden_remainder
211     //
212     // For functions with the attribute uniform-work-group-size=true. we can evaluate
213     // workgroup_id < hidden_block_count as true, and thus hidden_group_size is returned
214     // for __ockl_get_local_size.
215     for (int I = 0; I < 3; ++I) {
216       Value *BlockCount = BlockCounts[I];
217       if (!BlockCount)
218         continue;
219 
220       using namespace llvm::PatternMatch;
221       auto GroupIDIntrin =
222           I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
223                  : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
224                            : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());
225 
226       for (User *ICmp : BlockCount->users()) {
227         ICmpInst::Predicate Pred;
228         if (match(ICmp, m_ICmp(Pred, GroupIDIntrin, m_Specific(BlockCount)))) {
229           if (Pred != ICmpInst::ICMP_ULT)
230             continue;
231           ICmp->replaceAllUsesWith(llvm::ConstantInt::getTrue(ICmp->getType()));
232           MadeChange = true;
233         }
234       }
235     }
236 
237     // All remainders should be 0 with uniform work group size.
238     for (Value *Remainder : Remainders) {
239       if (!Remainder)
240         continue;
241       Remainder->replaceAllUsesWith(Constant::getNullValue(Remainder->getType()));
242       MadeChange = true;
243     }
244   } else if (HasUniformWorkGroupSize) { // Pre-V5.
245     // Pattern match the code used to handle partial workgroup dispatches in the
246     // library implementation of get_local_size, so the entire function can be
247     // constant folded with a known group size.
248     //
249     // uint r = grid_size - group_id * group_size;
250     // get_local_size = (r < group_size) ? r : group_size;
251     //
252     // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
253     // the grid_size is required to be a multiple of group_size). In this case:
254     //
255     // grid_size - (group_id * group_size) < group_size
256     // ->
257     // grid_size < group_size + (group_id * group_size)
258     //
259     // (grid_size / group_size) < 1 + group_id
260     //
261     // grid_size / group_size is at least 1, so we can conclude the select
262     // condition is false (except for group_id == 0, where the select result is
263     // the same).
264     for (int I = 0; I < 3; ++I) {
265       Value *GroupSize = GroupSizes[I];
266       Value *GridSize = GridSizes[I];
267       if (!GroupSize || !GridSize)
268         continue;
269 
270       using namespace llvm::PatternMatch;
271       auto GroupIDIntrin =
272           I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
273                  : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
274                            : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());
275 
276       for (User *U : GroupSize->users()) {
277         auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
278         if (!ZextGroupSize)
279           continue;
280 
281         for (User *UMin : ZextGroupSize->users()) {
282           if (match(UMin,
283                     m_UMin(m_Sub(m_Specific(GridSize),
284                                  m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize))),
285                            m_Specific(ZextGroupSize)))) {
286             if (HasReqdWorkGroupSize) {
287               ConstantInt *KnownSize
288                 = mdconst::extract<ConstantInt>(MD->getOperand(I));
289               UMin->replaceAllUsesWith(ConstantExpr::getIntegerCast(
290                   KnownSize, UMin->getType(), false));
291             } else {
292               UMin->replaceAllUsesWith(ZextGroupSize);
293             }
294 
295             MadeChange = true;
296           }
297         }
298       }
299     }
300   }
301 
302   // If reqd_work_group_size is set, we can replace work group size with it.
303   if (!HasReqdWorkGroupSize)
304     return MadeChange;
305 
306   for (int I = 0; I < 3; I++) {
307     Value *GroupSize = GroupSizes[I];
308     if (!GroupSize)
309       continue;
310 
311     ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
312     GroupSize->replaceAllUsesWith(
313         ConstantExpr::getIntegerCast(KnownSize, GroupSize->getType(), false));
314     MadeChange = true;
315   }
316 
317   return MadeChange;
318 }
319 
320 
321 // TODO: Move makeLIDRangeMetadata usage into here. Seem to not get
322 // TargetPassConfig for subtarget.
323 bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
324   bool MadeChange = false;
325   bool IsV5OrAbove = AMDGPU::getAmdhsaCodeObjectVersion() >= 5;
326   Function *BasePtr = getBasePtrIntrinsic(M, IsV5OrAbove);
327 
328   if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
329     return false;
330 
331   SmallPtrSet<Instruction *, 4> HandledUses;
332   for (auto *U : BasePtr->users()) {
333     CallInst *CI = cast<CallInst>(U);
334     if (HandledUses.insert(CI).second) {
335       if (processUse(CI, IsV5OrAbove))
336         MadeChange = true;
337     }
338   }
339 
340   return MadeChange;
341 }
342 
343 
344 INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
345                       "AMDGPU Kernel Attributes", false, false)
346 INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE,
347                     "AMDGPU Kernel Attributes", false, false)
348 
349 char AMDGPULowerKernelAttributes::ID = 0;
350 
351 ModulePass *llvm::createAMDGPULowerKernelAttributesPass() {
352   return new AMDGPULowerKernelAttributes();
353 }
354 
355 PreservedAnalyses
356 AMDGPULowerKernelAttributesPass::run(Function &F, FunctionAnalysisManager &AM) {
357   bool IsV5OrAbove = AMDGPU::getAmdhsaCodeObjectVersion() >= 5;
358   Function *BasePtr = getBasePtrIntrinsic(*F.getParent(), IsV5OrAbove);
359 
360   if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
361     return PreservedAnalyses::all();
362 
363   for (Instruction &I : instructions(F)) {
364     if (CallInst *CI = dyn_cast<CallInst>(&I)) {
365       if (CI->getCalledFunction() == BasePtr)
366         processUse(CI, IsV5OrAbove);
367     }
368   }
369 
370   return PreservedAnalyses::all();
371 }
372