//===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass resolves calls to OpenCL image attribute, image resource ID and
/// sampler resource ID getter functions.
///
/// Image attributes (size and format) are expected to be passed to the kernel
/// as kernel arguments immediately following the image argument itself;
/// therefore this pass adds image size and format arguments to the kernel
/// functions in the module. The kernel functions with image arguments are
/// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format".
/// Note: this pass may invalidate pointers to functions.
///
/// Resource IDs of read-only images, write-only images and samplers are
/// defined to be their index among the kernel arguments of the same
/// type and access qualifier.
24 // 25 //===----------------------------------------------------------------------===// 26 27 #include "R600.h" 28 #include "llvm/ADT/SmallVector.h" 29 #include "llvm/ADT/StringRef.h" 30 #include "llvm/IR/Constants.h" 31 #include "llvm/IR/Function.h" 32 #include "llvm/IR/Instructions.h" 33 #include "llvm/IR/Metadata.h" 34 #include "llvm/Pass.h" 35 #include "llvm/Transforms/Utils/Cloning.h" 36 37 using namespace llvm; 38 39 static StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size"; 40 static StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format"; 41 static StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id"; 42 static StringRef GetSamplerResourceIDFunc = 43 "llvm.OpenCL.sampler.get.resource.id"; 44 45 static StringRef ImageSizeArgMDType = "__llvm_image_size"; 46 static StringRef ImageFormatArgMDType = "__llvm_image_format"; 47 48 static StringRef KernelsMDNodeName = "opencl.kernels"; 49 static StringRef KernelArgMDNodeNames[] = { 50 "kernel_arg_addr_space", 51 "kernel_arg_access_qual", 52 "kernel_arg_type", 53 "kernel_arg_base_type", 54 "kernel_arg_type_qual"}; 55 static const unsigned NumKernelArgMDNodes = 5; 56 57 namespace { 58 59 using MDVector = SmallVector<Metadata *, 8>; 60 struct KernelArgMD { 61 MDVector ArgVector[NumKernelArgMDNodes]; 62 }; 63 64 } // end anonymous namespace 65 66 static inline bool 67 IsImageType(StringRef TypeString) { 68 return TypeString == "image2d_t" || TypeString == "image3d_t"; 69 } 70 71 static inline bool 72 IsSamplerType(StringRef TypeString) { 73 return TypeString == "sampler_t"; 74 } 75 76 static Function * 77 GetFunctionFromMDNode(MDNode *Node) { 78 if (!Node) 79 return nullptr; 80 81 size_t NumOps = Node->getNumOperands(); 82 if (NumOps != NumKernelArgMDNodes + 1) 83 return nullptr; 84 85 auto F = mdconst::dyn_extract<Function>(Node->getOperand(0)); 86 if (!F) 87 return nullptr; 88 89 // Validation checks. 
90 size_t ExpectNumArgNodeOps = F->arg_size() + 1; 91 for (size_t i = 0; i < NumKernelArgMDNodes; ++i) { 92 MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1)); 93 if (ArgNode->getNumOperands() != ExpectNumArgNodeOps) 94 return nullptr; 95 if (!ArgNode->getOperand(0)) 96 return nullptr; 97 98 // FIXME: It should be possible to do image lowering when some metadata 99 // args missing or not in the expected order. 100 MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0)); 101 if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i]) 102 return nullptr; 103 } 104 105 return F; 106 } 107 108 static StringRef 109 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) { 110 MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2)); 111 return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString(); 112 } 113 114 static StringRef 115 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) { 116 MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3)); 117 return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString(); 118 } 119 120 static MDVector 121 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) { 122 MDVector Res; 123 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) { 124 MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1)); 125 Res.push_back(Node->getOperand(OpIdx)); 126 } 127 return Res; 128 } 129 130 static void 131 PushArgMD(KernelArgMD &MD, const MDVector &V) { 132 assert(V.size() == NumKernelArgMDNodes); 133 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) { 134 MD.ArgVector[i].push_back(V[i]); 135 } 136 } 137 138 namespace { 139 140 class R600OpenCLImageTypeLoweringPass : public ModulePass { 141 static char ID; 142 143 LLVMContext *Context; 144 Type *Int32Type; 145 Type *ImageSizeType; 146 Type *ImageFormatType; 147 SmallVector<Instruction *, 4> InstsToErase; 148 149 bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID, 150 Argument &ImageSizeArg, 151 Argument &ImageFormatArg) 
{ 152 bool Modified = false; 153 154 for (auto &Use : ImageArg.uses()) { 155 auto Inst = dyn_cast<CallInst>(Use.getUser()); 156 if (!Inst) { 157 continue; 158 } 159 160 Function *F = Inst->getCalledFunction(); 161 if (!F) 162 continue; 163 164 Value *Replacement = nullptr; 165 StringRef Name = F->getName(); 166 if (Name.startswith(GetImageResourceIDFunc)) { 167 Replacement = ConstantInt::get(Int32Type, ResourceID); 168 } else if (Name.startswith(GetImageSizeFunc)) { 169 Replacement = &ImageSizeArg; 170 } else if (Name.startswith(GetImageFormatFunc)) { 171 Replacement = &ImageFormatArg; 172 } else { 173 continue; 174 } 175 176 Inst->replaceAllUsesWith(Replacement); 177 InstsToErase.push_back(Inst); 178 Modified = true; 179 } 180 181 return Modified; 182 } 183 184 bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) { 185 bool Modified = false; 186 187 for (const auto &Use : SamplerArg.uses()) { 188 auto Inst = dyn_cast<CallInst>(Use.getUser()); 189 if (!Inst) { 190 continue; 191 } 192 193 Function *F = Inst->getCalledFunction(); 194 if (!F) 195 continue; 196 197 Value *Replacement = nullptr; 198 StringRef Name = F->getName(); 199 if (Name == GetSamplerResourceIDFunc) { 200 Replacement = ConstantInt::get(Int32Type, ResourceID); 201 } else { 202 continue; 203 } 204 205 Inst->replaceAllUsesWith(Replacement); 206 InstsToErase.push_back(Inst); 207 Modified = true; 208 } 209 210 return Modified; 211 } 212 213 bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) { 214 uint32_t NumReadOnlyImageArgs = 0; 215 uint32_t NumWriteOnlyImageArgs = 0; 216 uint32_t NumSamplerArgs = 0; 217 218 bool Modified = false; 219 InstsToErase.clear(); 220 for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) { 221 Argument &Arg = *ArgI; 222 StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo()); 223 224 // Handle image types. 
225 if (IsImageType(Type)) { 226 StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo()); 227 uint32_t ResourceID; 228 if (AccessQual == "read_only") { 229 ResourceID = NumReadOnlyImageArgs++; 230 } else if (AccessQual == "write_only") { 231 ResourceID = NumWriteOnlyImageArgs++; 232 } else { 233 llvm_unreachable("Wrong image access qualifier."); 234 } 235 236 Argument &SizeArg = *(++ArgI); 237 Argument &FormatArg = *(++ArgI); 238 Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg); 239 240 // Handle sampler type. 241 } else if (IsSamplerType(Type)) { 242 uint32_t ResourceID = NumSamplerArgs++; 243 Modified |= replaceSamplerUses(Arg, ResourceID); 244 } 245 } 246 for (unsigned i = 0; i < InstsToErase.size(); ++i) { 247 InstsToErase[i]->eraseFromParent(); 248 } 249 250 return Modified; 251 } 252 253 std::tuple<Function *, MDNode *> 254 addImplicitArgs(Function *F, MDNode *KernelMDNode) { 255 bool Modified = false; 256 257 FunctionType *FT = F->getFunctionType(); 258 SmallVector<Type *, 8> ArgTypes; 259 260 // Metadata operands for new MDNode. 261 KernelArgMD NewArgMDs; 262 PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0)); 263 264 // Add implicit arguments to the signature. 265 for (unsigned i = 0; i < FT->getNumParams(); ++i) { 266 ArgTypes.push_back(FT->getParamType(i)); 267 MDVector ArgMD = GetArgMD(KernelMDNode, i + 1); 268 PushArgMD(NewArgMDs, ArgMD); 269 270 if (!IsImageType(ArgTypeFromMD(KernelMDNode, i))) 271 continue; 272 273 // Add size implicit argument. 274 ArgTypes.push_back(ImageSizeType); 275 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType); 276 PushArgMD(NewArgMDs, ArgMD); 277 278 // Add format implicit argument. 
279 ArgTypes.push_back(ImageFormatType); 280 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType); 281 PushArgMD(NewArgMDs, ArgMD); 282 283 Modified = true; 284 } 285 if (!Modified) { 286 return std::make_tuple(nullptr, nullptr); 287 } 288 289 // Create function with new signature and clone the old body into it. 290 auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false); 291 auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName()); 292 ValueToValueMapTy VMap; 293 auto NewFArgIt = NewF->arg_begin(); 294 for (auto &Arg: F->args()) { 295 auto ArgName = Arg.getName(); 296 NewFArgIt->setName(ArgName); 297 VMap[&Arg] = &(*NewFArgIt++); 298 if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) { 299 (NewFArgIt++)->setName(Twine("__size_") + ArgName); 300 (NewFArgIt++)->setName(Twine("__format_") + ArgName); 301 } 302 } 303 SmallVector<ReturnInst*, 8> Returns; 304 CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, 305 Returns); 306 307 // Build new MDNode. 308 SmallVector<Metadata *, 6> KernelMDArgs; 309 KernelMDArgs.push_back(ConstantAsMetadata::get(NewF)); 310 for (const MDVector &MDV : NewArgMDs.ArgVector) 311 KernelMDArgs.push_back(MDNode::get(*Context, MDV)); 312 MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs); 313 314 return std::make_tuple(NewF, NewMDNode); 315 } 316 317 bool transformKernels(Module &M) { 318 NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName); 319 if (!KernelsMDNode) 320 return false; 321 322 bool Modified = false; 323 for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) { 324 MDNode *KernelMDNode = KernelsMDNode->getOperand(i); 325 Function *F = GetFunctionFromMDNode(KernelMDNode); 326 if (!F) 327 continue; 328 329 Function *NewF; 330 MDNode *NewMDNode; 331 std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode); 332 if (NewF) { 333 // Replace old function and metadata with new ones. 
334 F->eraseFromParent(); 335 M.getFunctionList().push_back(NewF); 336 M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(), 337 NewF->getAttributes()); 338 KernelsMDNode->setOperand(i, NewMDNode); 339 340 F = NewF; 341 KernelMDNode = NewMDNode; 342 Modified = true; 343 } 344 345 Modified |= replaceImageAndSamplerUses(F, KernelMDNode); 346 } 347 348 return Modified; 349 } 350 351 public: 352 R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {} 353 354 bool runOnModule(Module &M) override { 355 Context = &M.getContext(); 356 Int32Type = Type::getInt32Ty(M.getContext()); 357 ImageSizeType = ArrayType::get(Int32Type, 3); 358 ImageFormatType = ArrayType::get(Int32Type, 2); 359 360 return transformKernels(M); 361 } 362 363 StringRef getPassName() const override { 364 return "R600 OpenCL Image Type Pass"; 365 } 366 }; 367 368 } // end anonymous namespace 369 370 char R600OpenCLImageTypeLoweringPass::ID = 0; 371 372 ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() { 373 return new R600OpenCLImageTypeLoweringPass(); 374 } 375