//===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass resolves calls to OpenCL image attribute, image resource ID and
/// sampler resource ID getter functions.
///
/// Image attributes (size and format) are expected to be passed to the kernel
/// as kernel arguments immediately following the image argument itself,
/// therefore this pass adds image size and format arguments to the kernel
/// functions in the module. The kernel functions with image arguments are
/// re-created using the new signature. The new arguments are added to the
/// kernel metadata with both kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format", respectively.
/// Note: because kernels are re-created, this pass may invalidate pointers to
/// functions.
///
/// Resource IDs of read-only images, write-only images and samplers are
/// defined to be their index among the kernel arguments of the same
/// type and access qualifier.
24 // 25 //===----------------------------------------------------------------------===// 26 27 #include "AMDGPU.h" 28 #include "llvm/ADT/SmallVector.h" 29 #include "llvm/ADT/StringRef.h" 30 #include "llvm/ADT/Twine.h" 31 #include "llvm/IR/Argument.h" 32 #include "llvm/IR/DerivedTypes.h" 33 #include "llvm/IR/Constants.h" 34 #include "llvm/IR/Function.h" 35 #include "llvm/IR/Instruction.h" 36 #include "llvm/IR/Instructions.h" 37 #include "llvm/IR/Metadata.h" 38 #include "llvm/IR/Module.h" 39 #include "llvm/IR/Type.h" 40 #include "llvm/IR/Use.h" 41 #include "llvm/IR/User.h" 42 #include "llvm/Pass.h" 43 #include "llvm/Support/Casting.h" 44 #include "llvm/Support/ErrorHandling.h" 45 #include "llvm/Transforms/Utils/Cloning.h" 46 #include "llvm/Transforms/Utils/ValueMapper.h" 47 #include <cassert> 48 #include <cstddef> 49 #include <cstdint> 50 #include <tuple> 51 52 using namespace llvm; 53 54 static StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size"; 55 static StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format"; 56 static StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id"; 57 static StringRef GetSamplerResourceIDFunc = 58 "llvm.OpenCL.sampler.get.resource.id"; 59 60 static StringRef ImageSizeArgMDType = "__llvm_image_size"; 61 static StringRef ImageFormatArgMDType = "__llvm_image_format"; 62 63 static StringRef KernelsMDNodeName = "opencl.kernels"; 64 static StringRef KernelArgMDNodeNames[] = { 65 "kernel_arg_addr_space", 66 "kernel_arg_access_qual", 67 "kernel_arg_type", 68 "kernel_arg_base_type", 69 "kernel_arg_type_qual"}; 70 static const unsigned NumKernelArgMDNodes = 5; 71 72 namespace { 73 74 using MDVector = SmallVector<Metadata *, 8>; 75 struct KernelArgMD { 76 MDVector ArgVector[NumKernelArgMDNodes]; 77 }; 78 79 } // end anonymous namespace 80 81 static inline bool 82 IsImageType(StringRef TypeString) { 83 return TypeString == "image2d_t" || TypeString == "image3d_t"; 84 } 85 86 static inline bool 87 IsSamplerType(StringRef 
TypeString) { 88 return TypeString == "sampler_t"; 89 } 90 91 static Function * 92 GetFunctionFromMDNode(MDNode *Node) { 93 if (!Node) 94 return nullptr; 95 96 size_t NumOps = Node->getNumOperands(); 97 if (NumOps != NumKernelArgMDNodes + 1) 98 return nullptr; 99 100 auto F = mdconst::dyn_extract<Function>(Node->getOperand(0)); 101 if (!F) 102 return nullptr; 103 104 // Sanity checks. 105 size_t ExpectNumArgNodeOps = F->arg_size() + 1; 106 for (size_t i = 0; i < NumKernelArgMDNodes; ++i) { 107 MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1)); 108 if (ArgNode->getNumOperands() != ExpectNumArgNodeOps) 109 return nullptr; 110 if (!ArgNode->getOperand(0)) 111 return nullptr; 112 113 // FIXME: It should be possible to do image lowering when some metadata 114 // args missing or not in the expected order. 115 MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0)); 116 if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i]) 117 return nullptr; 118 } 119 120 return F; 121 } 122 123 static StringRef 124 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) { 125 MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2)); 126 return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString(); 127 } 128 129 static StringRef 130 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) { 131 MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3)); 132 return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString(); 133 } 134 135 static MDVector 136 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) { 137 MDVector Res; 138 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) { 139 MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1)); 140 Res.push_back(Node->getOperand(OpIdx)); 141 } 142 return Res; 143 } 144 145 static void 146 PushArgMD(KernelArgMD &MD, const MDVector &V) { 147 assert(V.size() == NumKernelArgMDNodes); 148 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) { 149 MD.ArgVector[i].push_back(V[i]); 
150 } 151 } 152 153 namespace { 154 155 class R600OpenCLImageTypeLoweringPass : public ModulePass { 156 static char ID; 157 158 LLVMContext *Context; 159 Type *Int32Type; 160 Type *ImageSizeType; 161 Type *ImageFormatType; 162 SmallVector<Instruction *, 4> InstsToErase; 163 164 bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID, 165 Argument &ImageSizeArg, 166 Argument &ImageFormatArg) { 167 bool Modified = false; 168 169 for (auto &Use : ImageArg.uses()) { 170 auto Inst = dyn_cast<CallInst>(Use.getUser()); 171 if (!Inst) { 172 continue; 173 } 174 175 Function *F = Inst->getCalledFunction(); 176 if (!F) 177 continue; 178 179 Value *Replacement = nullptr; 180 StringRef Name = F->getName(); 181 if (Name.startswith(GetImageResourceIDFunc)) { 182 Replacement = ConstantInt::get(Int32Type, ResourceID); 183 } else if (Name.startswith(GetImageSizeFunc)) { 184 Replacement = &ImageSizeArg; 185 } else if (Name.startswith(GetImageFormatFunc)) { 186 Replacement = &ImageFormatArg; 187 } else { 188 continue; 189 } 190 191 Inst->replaceAllUsesWith(Replacement); 192 InstsToErase.push_back(Inst); 193 Modified = true; 194 } 195 196 return Modified; 197 } 198 199 bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) { 200 bool Modified = false; 201 202 for (const auto &Use : SamplerArg.uses()) { 203 auto Inst = dyn_cast<CallInst>(Use.getUser()); 204 if (!Inst) { 205 continue; 206 } 207 208 Function *F = Inst->getCalledFunction(); 209 if (!F) 210 continue; 211 212 Value *Replacement = nullptr; 213 StringRef Name = F->getName(); 214 if (Name == GetSamplerResourceIDFunc) { 215 Replacement = ConstantInt::get(Int32Type, ResourceID); 216 } else { 217 continue; 218 } 219 220 Inst->replaceAllUsesWith(Replacement); 221 InstsToErase.push_back(Inst); 222 Modified = true; 223 } 224 225 return Modified; 226 } 227 228 bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) { 229 uint32_t NumReadOnlyImageArgs = 0; 230 uint32_t NumWriteOnlyImageArgs = 0; 231 
uint32_t NumSamplerArgs = 0; 232 233 bool Modified = false; 234 InstsToErase.clear(); 235 for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) { 236 Argument &Arg = *ArgI; 237 StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo()); 238 239 // Handle image types. 240 if (IsImageType(Type)) { 241 StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo()); 242 uint32_t ResourceID; 243 if (AccessQual == "read_only") { 244 ResourceID = NumReadOnlyImageArgs++; 245 } else if (AccessQual == "write_only") { 246 ResourceID = NumWriteOnlyImageArgs++; 247 } else { 248 llvm_unreachable("Wrong image access qualifier."); 249 } 250 251 Argument &SizeArg = *(++ArgI); 252 Argument &FormatArg = *(++ArgI); 253 Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg); 254 255 // Handle sampler type. 256 } else if (IsSamplerType(Type)) { 257 uint32_t ResourceID = NumSamplerArgs++; 258 Modified |= replaceSamplerUses(Arg, ResourceID); 259 } 260 } 261 for (unsigned i = 0; i < InstsToErase.size(); ++i) { 262 InstsToErase[i]->eraseFromParent(); 263 } 264 265 return Modified; 266 } 267 268 std::tuple<Function *, MDNode *> 269 addImplicitArgs(Function *F, MDNode *KernelMDNode) { 270 bool Modified = false; 271 272 FunctionType *FT = F->getFunctionType(); 273 SmallVector<Type *, 8> ArgTypes; 274 275 // Metadata operands for new MDNode. 276 KernelArgMD NewArgMDs; 277 PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0)); 278 279 // Add implicit arguments to the signature. 280 for (unsigned i = 0; i < FT->getNumParams(); ++i) { 281 ArgTypes.push_back(FT->getParamType(i)); 282 MDVector ArgMD = GetArgMD(KernelMDNode, i + 1); 283 PushArgMD(NewArgMDs, ArgMD); 284 285 if (!IsImageType(ArgTypeFromMD(KernelMDNode, i))) 286 continue; 287 288 // Add size implicit argument. 289 ArgTypes.push_back(ImageSizeType); 290 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType); 291 PushArgMD(NewArgMDs, ArgMD); 292 293 // Add format implicit argument. 
294 ArgTypes.push_back(ImageFormatType); 295 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType); 296 PushArgMD(NewArgMDs, ArgMD); 297 298 Modified = true; 299 } 300 if (!Modified) { 301 return std::make_tuple(nullptr, nullptr); 302 } 303 304 // Create function with new signature and clone the old body into it. 305 auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false); 306 auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName()); 307 ValueToValueMapTy VMap; 308 auto NewFArgIt = NewF->arg_begin(); 309 for (auto &Arg: F->args()) { 310 auto ArgName = Arg.getName(); 311 NewFArgIt->setName(ArgName); 312 VMap[&Arg] = &(*NewFArgIt++); 313 if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) { 314 (NewFArgIt++)->setName(Twine("__size_") + ArgName); 315 (NewFArgIt++)->setName(Twine("__format_") + ArgName); 316 } 317 } 318 SmallVector<ReturnInst*, 8> Returns; 319 CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns); 320 321 // Build new MDNode. 322 SmallVector<Metadata *, 6> KernelMDArgs; 323 KernelMDArgs.push_back(ConstantAsMetadata::get(NewF)); 324 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) 325 KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i])); 326 MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs); 327 328 return std::make_tuple(NewF, NewMDNode); 329 } 330 331 bool transformKernels(Module &M) { 332 NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName); 333 if (!KernelsMDNode) 334 return false; 335 336 bool Modified = false; 337 for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) { 338 MDNode *KernelMDNode = KernelsMDNode->getOperand(i); 339 Function *F = GetFunctionFromMDNode(KernelMDNode); 340 if (!F) 341 continue; 342 343 Function *NewF; 344 MDNode *NewMDNode; 345 std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode); 346 if (NewF) { 347 // Replace old function and metadata with new ones. 
348 F->eraseFromParent(); 349 M.getFunctionList().push_back(NewF); 350 M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(), 351 NewF->getAttributes()); 352 KernelsMDNode->setOperand(i, NewMDNode); 353 354 F = NewF; 355 KernelMDNode = NewMDNode; 356 Modified = true; 357 } 358 359 Modified |= replaceImageAndSamplerUses(F, KernelMDNode); 360 } 361 362 return Modified; 363 } 364 365 public: 366 R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {} 367 368 bool runOnModule(Module &M) override { 369 Context = &M.getContext(); 370 Int32Type = Type::getInt32Ty(M.getContext()); 371 ImageSizeType = ArrayType::get(Int32Type, 3); 372 ImageFormatType = ArrayType::get(Int32Type, 2); 373 374 return transformKernels(M); 375 } 376 377 StringRef getPassName() const override { 378 return "R600 OpenCL Image Type Pass"; 379 } 380 }; 381 382 } // end anonymous namespace 383 384 char R600OpenCLImageTypeLoweringPass::ID = 0; 385 386 ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() { 387 return new R600OpenCLImageTypeLoweringPass(); 388 } 389