1 //===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// This pass resolves calls to OpenCL image attribute, image resource ID and
11 /// sampler resource ID getter functions.
12 ///
/// Image attributes (size and format) are expected to be passed to the kernel
/// as kernel arguments immediately following the image argument itself;
/// therefore, this pass adds image size and format arguments to the kernel
/// functions in the module. The kernel functions with image arguments are
/// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format".
19 /// Note: this pass may invalidate pointers to functions.
20 ///
21 /// Resource IDs of read-only images, write-only images and samplers are
22 /// defined to be their index among the kernel arguments of the same
23 /// type and access qualifier.
24 //
25 //===----------------------------------------------------------------------===//
26 
27 #include "R600.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/IR/Constants.h"
31 #include "llvm/IR/Function.h"
32 #include "llvm/IR/Instructions.h"
33 #include "llvm/IR/Metadata.h"
34 #include "llvm/Pass.h"
35 #include "llvm/Transforms/Utils/Cloning.h"
36 
37 using namespace llvm;
38 
39 static StringRef GetImageSizeFunc =         "llvm.OpenCL.image.get.size";
40 static StringRef GetImageFormatFunc =       "llvm.OpenCL.image.get.format";
41 static StringRef GetImageResourceIDFunc =   "llvm.OpenCL.image.get.resource.id";
42 static StringRef GetSamplerResourceIDFunc =
43     "llvm.OpenCL.sampler.get.resource.id";
44 
45 static StringRef ImageSizeArgMDType =   "__llvm_image_size";
46 static StringRef ImageFormatArgMDType = "__llvm_image_format";
47 
48 static StringRef KernelsMDNodeName = "opencl.kernels";
49 static StringRef KernelArgMDNodeNames[] = {
50   "kernel_arg_addr_space",
51   "kernel_arg_access_qual",
52   "kernel_arg_type",
53   "kernel_arg_base_type",
54   "kernel_arg_type_qual"};
55 static const unsigned NumKernelArgMDNodes = 5;
56 
57 namespace {
58 
59 using MDVector = SmallVector<Metadata *, 8>;
60 struct KernelArgMD {
61   MDVector ArgVector[NumKernelArgMDNodes];
62 };
63 
64 } // end anonymous namespace
65 
66 static inline bool
67 IsImageType(StringRef TypeString) {
68   return TypeString == "image2d_t" || TypeString == "image3d_t";
69 }
70 
71 static inline bool
72 IsSamplerType(StringRef TypeString) {
73   return TypeString == "sampler_t";
74 }
75 
76 static Function *
77 GetFunctionFromMDNode(MDNode *Node) {
78   if (!Node)
79     return nullptr;
80 
81   size_t NumOps = Node->getNumOperands();
82   if (NumOps != NumKernelArgMDNodes + 1)
83     return nullptr;
84 
85   auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
86   if (!F)
87     return nullptr;
88 
89   // Validation checks.
90   size_t ExpectNumArgNodeOps = F->arg_size() + 1;
91   for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
92     MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
93     if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
94       return nullptr;
95     if (!ArgNode->getOperand(0))
96       return nullptr;
97 
98     // FIXME: It should be possible to do image lowering when some metadata
99     // args missing or not in the expected order.
100     MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
101     if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
102       return nullptr;
103   }
104 
105   return F;
106 }
107 
108 static StringRef
109 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
110   MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
111   return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
112 }
113 
114 static StringRef
115 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
116   MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
117   return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
118 }
119 
120 static MDVector
121 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
122   MDVector Res;
123   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
124     MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
125     Res.push_back(Node->getOperand(OpIdx));
126   }
127   return Res;
128 }
129 
130 static void
131 PushArgMD(KernelArgMD &MD, const MDVector &V) {
132   assert(V.size() == NumKernelArgMDNodes);
133   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
134     MD.ArgVector[i].push_back(V[i]);
135   }
136 }
137 
138 namespace {
139 
140 class R600OpenCLImageTypeLoweringPass : public ModulePass {
141   static char ID;
142 
143   LLVMContext *Context;
144   Type *Int32Type;
145   Type *ImageSizeType;
146   Type *ImageFormatType;
147   SmallVector<Instruction *, 4> InstsToErase;
148 
149   bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
150                         Argument &ImageSizeArg,
151                         Argument &ImageFormatArg) {
152     bool Modified = false;
153 
154     for (auto &Use : ImageArg.uses()) {
155       auto Inst = dyn_cast<CallInst>(Use.getUser());
156       if (!Inst) {
157         continue;
158       }
159 
160       Function *F = Inst->getCalledFunction();
161       if (!F)
162         continue;
163 
164       Value *Replacement = nullptr;
165       StringRef Name = F->getName();
166       if (Name.startswith(GetImageResourceIDFunc)) {
167         Replacement = ConstantInt::get(Int32Type, ResourceID);
168       } else if (Name.startswith(GetImageSizeFunc)) {
169         Replacement = &ImageSizeArg;
170       } else if (Name.startswith(GetImageFormatFunc)) {
171         Replacement = &ImageFormatArg;
172       } else {
173         continue;
174       }
175 
176       Inst->replaceAllUsesWith(Replacement);
177       InstsToErase.push_back(Inst);
178       Modified = true;
179     }
180 
181     return Modified;
182   }
183 
184   bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
185     bool Modified = false;
186 
187     for (const auto &Use : SamplerArg.uses()) {
188       auto Inst = dyn_cast<CallInst>(Use.getUser());
189       if (!Inst) {
190         continue;
191       }
192 
193       Function *F = Inst->getCalledFunction();
194       if (!F)
195         continue;
196 
197       Value *Replacement = nullptr;
198       StringRef Name = F->getName();
199       if (Name == GetSamplerResourceIDFunc) {
200         Replacement = ConstantInt::get(Int32Type, ResourceID);
201       } else {
202         continue;
203       }
204 
205       Inst->replaceAllUsesWith(Replacement);
206       InstsToErase.push_back(Inst);
207       Modified = true;
208     }
209 
210     return Modified;
211   }
212 
213   bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
214     uint32_t NumReadOnlyImageArgs = 0;
215     uint32_t NumWriteOnlyImageArgs = 0;
216     uint32_t NumSamplerArgs = 0;
217 
218     bool Modified = false;
219     InstsToErase.clear();
220     for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
221       Argument &Arg = *ArgI;
222       StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
223 
224       // Handle image types.
225       if (IsImageType(Type)) {
226         StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
227         uint32_t ResourceID;
228         if (AccessQual == "read_only") {
229           ResourceID = NumReadOnlyImageArgs++;
230         } else if (AccessQual == "write_only") {
231           ResourceID = NumWriteOnlyImageArgs++;
232         } else {
233           llvm_unreachable("Wrong image access qualifier.");
234         }
235 
236         Argument &SizeArg = *(++ArgI);
237         Argument &FormatArg = *(++ArgI);
238         Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
239 
240       // Handle sampler type.
241       } else if (IsSamplerType(Type)) {
242         uint32_t ResourceID = NumSamplerArgs++;
243         Modified |= replaceSamplerUses(Arg, ResourceID);
244       }
245     }
246     for (unsigned i = 0; i < InstsToErase.size(); ++i) {
247       InstsToErase[i]->eraseFromParent();
248     }
249 
250     return Modified;
251   }
252 
253   std::tuple<Function *, MDNode *>
254   addImplicitArgs(Function *F, MDNode *KernelMDNode) {
255     bool Modified = false;
256 
257     FunctionType *FT = F->getFunctionType();
258     SmallVector<Type *, 8> ArgTypes;
259 
260     // Metadata operands for new MDNode.
261     KernelArgMD NewArgMDs;
262     PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
263 
264     // Add implicit arguments to the signature.
265     for (unsigned i = 0; i < FT->getNumParams(); ++i) {
266       ArgTypes.push_back(FT->getParamType(i));
267       MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
268       PushArgMD(NewArgMDs, ArgMD);
269 
270       if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
271         continue;
272 
273       // Add size implicit argument.
274       ArgTypes.push_back(ImageSizeType);
275       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
276       PushArgMD(NewArgMDs, ArgMD);
277 
278       // Add format implicit argument.
279       ArgTypes.push_back(ImageFormatType);
280       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
281       PushArgMD(NewArgMDs, ArgMD);
282 
283       Modified = true;
284     }
285     if (!Modified) {
286       return std::make_tuple(nullptr, nullptr);
287     }
288 
289     // Create function with new signature and clone the old body into it.
290     auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
291     auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
292     ValueToValueMapTy VMap;
293     auto NewFArgIt = NewF->arg_begin();
294     for (auto &Arg: F->args()) {
295       auto ArgName = Arg.getName();
296       NewFArgIt->setName(ArgName);
297       VMap[&Arg] = &(*NewFArgIt++);
298       if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
299         (NewFArgIt++)->setName(Twine("__size_") + ArgName);
300         (NewFArgIt++)->setName(Twine("__format_") + ArgName);
301       }
302     }
303     SmallVector<ReturnInst*, 8> Returns;
304     CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
305                       Returns);
306 
307     // Build new MDNode.
308     SmallVector<Metadata *, 6> KernelMDArgs;
309     KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
310     for (const MDVector &MDV : NewArgMDs.ArgVector)
311       KernelMDArgs.push_back(MDNode::get(*Context, MDV));
312     MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
313 
314     return std::make_tuple(NewF, NewMDNode);
315   }
316 
317   bool transformKernels(Module &M) {
318     NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
319     if (!KernelsMDNode)
320       return false;
321 
322     bool Modified = false;
323     for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
324       MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
325       Function *F = GetFunctionFromMDNode(KernelMDNode);
326       if (!F)
327         continue;
328 
329       Function *NewF;
330       MDNode *NewMDNode;
331       std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
332       if (NewF) {
333         // Replace old function and metadata with new ones.
334         F->eraseFromParent();
335         M.getFunctionList().push_back(NewF);
336         M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
337                               NewF->getAttributes());
338         KernelsMDNode->setOperand(i, NewMDNode);
339 
340         F = NewF;
341         KernelMDNode = NewMDNode;
342         Modified = true;
343       }
344 
345       Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
346     }
347 
348     return Modified;
349   }
350 
351 public:
352   R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
353 
354   bool runOnModule(Module &M) override {
355     Context = &M.getContext();
356     Int32Type = Type::getInt32Ty(M.getContext());
357     ImageSizeType = ArrayType::get(Int32Type, 3);
358     ImageFormatType = ArrayType::get(Int32Type, 2);
359 
360     return transformKernels(M);
361   }
362 
363   StringRef getPassName() const override {
364     return "R600 OpenCL Image Type Pass";
365   }
366 };
367 
368 } // end anonymous namespace
369 
370 char R600OpenCLImageTypeLoweringPass::ID = 0;
371 
372 ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() {
373   return new R600OpenCLImageTypeLoweringPass();
374 }
375