1 /* pocl-ptx-gen.cc - PTX code generation functions
2 
3    Copyright (c) 2016-2017 James Price / University of Bristol
4 
5    Permission is hereby granted, free of charge, to any person obtaining a copy
6    of this software and associated documentation files (the "Software"), to deal
7    in the Software without restriction, including without limitation the rights
8    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9    copies of the Software, and to permit persons to whom the Software is
10    furnished to do so, subject to the following conditions:
11 
12    The above copyright notice and this permission notice shall be included in
13    all copies or substantial portions of the Software.
14 
15    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21    THE SOFTWARE.
22 */
23 
24 #include "config.h"
25 
26 #include "LLVMUtils.h"
27 #include "common.h"
28 #include "pocl-ptx-gen.h"
29 #include "pocl.h"
30 #include "pocl_file_util.h"
31 #include "pocl_runtime_config.h"
32 
33 #include "llvm/Bitcode/BitcodeReader.h"
34 #include "llvm/IR/Constants.h"
35 #include "llvm/IR/InstIterator.h"
36 #include "llvm/IR/Instructions.h"
37 #include "llvm/IR/LegacyPassManager.h"
38 #include "llvm/IR/Module.h"
39 #include "llvm/IR/Verifier.h"
40 #include "llvm/Linker/Linker.h"
41 #include "llvm/Support/FileSystem.h"
42 #include "llvm/Support/TargetRegistry.h"
43 #include "llvm/Support/TargetSelect.h"
44 #ifndef LLVM_OLDER_THAN_11_0
45 #include "llvm/Support/Alignment.h"
46 #endif
47 #include "llvm/Target/TargetMachine.h"
48 #include "llvm/Target/TargetOptions.h"
49 #include "llvm/Transforms/IPO.h"
50 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
51 #include "llvm/Transforms/Utils/Cloning.h"
52 
53 #include <set>
54 
55 namespace llvm {
56 extern ModulePass *createNVVMReflectPass(const StringMap<int> &Mapping);
57 }
58 
59 static void addKernelAnnotations(llvm::Module *Module, const char *KernelName);
60 static void fixConstantMemArgs(llvm::Module *Module, const char *KernelName);
61 static void fixLocalMemArgs(llvm::Module *Module, const char *KernelName);
62 static void fixPrintF(llvm::Module *Module);
63 static void handleGetWorkDim(llvm::Module *Module, const char *KernelName);
64 static void linkLibDevice(llvm::Module *Module, const char *KernelName,
65                           const char *LibDevicePath);
66 static void mapLibDeviceCalls(llvm::Module *Module);
67 
pocl_ptx_gen(const char * BitcodeFilename,const char * PTXFilename,const char * KernelName,const char * Arch,const char * LibDevicePath,int HasOffsets)68 int pocl_ptx_gen(const char *BitcodeFilename, const char *PTXFilename,
69                  const char *KernelName, const char *Arch,
70                  const char *LibDevicePath, int HasOffsets) {
71   llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> Buffer =
72       llvm::MemoryBuffer::getFile(BitcodeFilename);
73   if (!Buffer) {
74     POCL_MSG_ERR("[CUDA] ptx-gen: failed to open bitcode file\n");
75     return 1;
76   }
77 
78   // Load the LLVM bitcode module.
79   llvm::LLVMContext Context;
80   llvm::Expected<std::unique_ptr<llvm::Module>> Module =
81       parseBitcodeFile(Buffer->get()->getMemBufferRef(), Context);
82   if (!Module) {
83     POCL_MSG_ERR("[CUDA] ptx-gen: failed to load bitcode\n");
84     return 1;
85   }
86 
87   // Apply transforms to prepare for lowering to PTX.
88   fixPrintF(Module->get());
89   fixConstantMemArgs(Module->get(), KernelName);
90   fixLocalMemArgs(Module->get(), KernelName);
91   handleGetWorkDim(Module->get(), KernelName);
92   addKernelAnnotations(Module->get(), KernelName);
93   mapLibDeviceCalls(Module->get());
94   linkLibDevice(Module->get(), KernelName, LibDevicePath);
95   if (pocl_get_bool_option("POCL_CUDA_DUMP_NVVM", 0)) {
96     std::string ModuleString;
97     llvm::raw_string_ostream ModuleStringStream(ModuleString);
98     (*Module)->print(ModuleStringStream, NULL);
99     POCL_MSG_PRINT_INFO("NVVM module:\n%s\n", ModuleString.c_str());
100   }
101 
102   std::string Error;
103 
104   // Verify module.
105   if (pocl_get_bool_option("POCL_CUDA_VERIFY_MODULE", 1)) {
106     llvm::raw_string_ostream Errs(Error);
107     if (llvm::verifyModule(*Module->get(), &Errs)) {
108       POCL_MSG_ERR("\n%s\n", Error.c_str());
109       POCL_ABORT("[CUDA] ptx-gen: module verification failed\n");
110     }
111   }
112 
113 #ifdef LLVM_OLDER_THAN_11_0
114   llvm::StringRef Triple =
115       (sizeof(void *) == 8) ? "nvptx64-nvidia-cuda" : "nvptx-nvidia-cuda";
116 #else
117   std::string Triple =
118       (sizeof(void *) == 8) ? "nvptx64-nvidia-cuda" : "nvptx-nvidia-cuda";
119 #endif
120 
121   // Get NVPTX target.
122   const llvm::Target *Target =
123       llvm::TargetRegistry::lookupTarget(Triple, Error);
124 
125   if (!Target) {
126     POCL_MSG_ERR("[CUDA] ptx-gen: failed to get target\n");
127     POCL_MSG_ERR("%s\n", Error.c_str());
128     return 1;
129   }
130 
131   // TODO: Set options?
132   llvm::TargetOptions Options;
133 
134   // TODO: CPU and features?
135   std::unique_ptr<llvm::TargetMachine> Machine(
136       Target->createTargetMachine(Triple, Arch, "+ptx40", Options, llvm::None));
137 
138   llvm::legacy::PassManager Passes;
139 
140   // Add pass to emit PTX.
141   llvm::SmallVector<char, 4096> Data;
142   llvm::raw_svector_ostream PTXStream(Data);
143   if (Machine->addPassesToEmitFile(Passes, PTXStream,
144 #ifndef LLVM_OLDER_THAN_7_0
145                                    nullptr,
146 #endif
147 #ifdef LLVM_OLDER_THAN_10_0
148                                    llvm::TargetMachine::CGFT_AssemblyFile)) {
149 #else
150                                    llvm::CGFT_AssemblyFile)) {
151 #endif
152     POCL_MSG_ERR("[CUDA] ptx-gen: failed to add passes\n");
153     return 1;
154   }
155 
156   // Run passes.
157   Passes.run(**Module);
158 
159 #ifdef LLVM_OLDER_THAN_11_0
160   std::string PTX = PTXStream.str();
161   return pocl_write_file(PTXFilename, PTX.c_str(), PTX.size(), 0, 0);
162 #else
163   llvm::StringRef PTX = PTXStream.str();
164   return pocl_write_file(PTXFilename, PTX.data(), PTX.size(), 0, 0);
165 #endif
166 }
167 
168 // Add the metadata needed to mark a function as a kernel in PTX.
169 void addKernelAnnotations(llvm::Module *Module, const char *KernelName) {
170   llvm::LLVMContext &Context = Module->getContext();
171 
172   // Remove existing nvvm.annotations metadata since it is sometimes corrupt.
173   auto *Annotations = Module->getNamedMetadata("nvvm.annotations");
174   if (Annotations)
175     Annotations->eraseFromParent();
176 
177   // Add nvvm.annotations metadata to mark kernel entry point.
178   Annotations = Module->getOrInsertNamedMetadata("nvvm.annotations");
179 
180   // Get handle to function.
181   auto *Function = Module->getFunction(KernelName);
182   if (!Function)
183     POCL_ABORT("[CUDA] ptx-gen: kernel function not found in module\n");
184 
185   // Create metadata.
186   llvm::Constant *One =
187       llvm::ConstantInt::getSigned(llvm::Type::getInt32Ty(Context), 1);
188   llvm::Metadata *FuncMD = llvm::ValueAsMetadata::get(Function);
189   llvm::Metadata *NameMD = llvm::MDString::get(Context, "kernel");
190   llvm::Metadata *OneMD = llvm::ConstantAsMetadata::get(One);
191 
192   llvm::MDNode *Node = llvm::MDNode::get(Context, {FuncMD, NameMD, OneMD});
193   Annotations->addOperand(Node);
194 }
195 
196 // PTX doesn't support variadic functions, so we need to modify the IR to
197 // support printf. The vprintf system call that is provided is described here:
198 // http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
199 // Essentially, the variadic list of arguments is replaced with a single array
200 // instead.
201 //
202 // This function changes the prototype of printf to take an array instead
203 // of a variadic argument list. It updates the function body to read from
204 // this array to retrieve each argument instead of using the dummy __cl_va_arg
205 // function. We then visit each printf callsite and generate the argument
206 // array to pass instead of the variadic list.
207 void fixPrintF(llvm::Module *Module) {
208   llvm::Function *OldPrintF = Module->getFunction("printf");
209   if (!OldPrintF)
210     return;
211 
212   llvm::LLVMContext &Context = Module->getContext();
213   llvm::Type *I32 = llvm::Type::getInt32Ty(Context);
214   llvm::Type *I64 = llvm::Type::getInt64Ty(Context);
215   llvm::Type *I64Ptr = llvm::PointerType::get(I64, 0);
216   llvm::Type *FormatType = OldPrintF->getFunctionType()->getParamType(0);
217 
218   // Remove calls to va_start and va_end.
219   pocl::eraseFunctionAndCallers(Module->getFunction("llvm.va_start"));
220   pocl::eraseFunctionAndCallers(Module->getFunction("llvm.va_end"));
221 
222   // Create new non-variadic printf function.
223   llvm::Type *ReturnType = OldPrintF->getReturnType();
224   llvm::FunctionType *NewPrintfType =
225       llvm::FunctionType::get(ReturnType, {FormatType, I64Ptr}, false);
226   llvm::Function *NewPrintF = llvm::Function::Create(
227       NewPrintfType, OldPrintF->getLinkage(), "", Module);
228   NewPrintF->takeName(OldPrintF);
229 
230   // Take function body from old function.
231   NewPrintF->getBasicBlockList().splice(NewPrintF->begin(),
232                                         OldPrintF->getBasicBlockList());
233 
234   // Create i32 to hold current argument index.
235 #ifdef LLVM_OLDER_THAN_11_0
236   llvm::AllocaInst *ArgIndexPtr =
237       new llvm::AllocaInst(I32, 0, llvm::ConstantInt::get(I32, 1));
238   ArgIndexPtr->insertBefore(&*NewPrintF->begin()->begin());
239 #else
240   llvm::AllocaInst *ArgIndexPtr =
241       new llvm::AllocaInst(I32, 0, llvm::ConstantInt::get(I32, 1), llvm::Align(4));
242   ArgIndexPtr->insertBefore(&*NewPrintF->begin()->begin());
243 #endif
244 
245 #ifdef LLVM_OLDER_THAN_11_0
246   llvm::StoreInst *ArgIndexInit =
247       new llvm::StoreInst(llvm::ConstantInt::get(I32, 0), ArgIndexPtr);
248   ArgIndexInit->insertAfter(ArgIndexPtr);
249 #else
250   llvm::StoreInst *ArgIndexInit =
251       new llvm::StoreInst(llvm::ConstantInt::get(I32, 0), ArgIndexPtr, false, llvm::Align(4));
252   ArgIndexInit->insertAfter(ArgIndexPtr);
253 #endif
254 
255   // Replace calls to _cl_va_arg with reads from new i64 array argument.
256   llvm::Function *VaArgFunc = Module->getFunction("__cl_va_arg");
257   if (VaArgFunc) {
258     auto args = NewPrintF->arg_begin();
259     args++;
260     llvm::Argument *ArgsIn = args;
261     std::vector<llvm::Value *> VaArgCalls(VaArgFunc->user_begin(),
262                                           VaArgFunc->user_end());
263     for (auto &U : VaArgCalls) {
264       llvm::CallInst *Call = llvm::dyn_cast<llvm::CallInst>(U);
265       if (!Call)
266         continue;
267 
268       // Get current argument index.
269 #ifdef LLVM_OLDER_THAN_11_0
270       llvm::LoadInst *ArgIndex = new llvm::LoadInst(ArgIndexPtr);
271       ArgIndex->insertBefore(Call);
272 #else
273       llvm::LoadInst *ArgIndex = new llvm::LoadInst(I32, ArgIndexPtr, "poclCudaLoad", false, llvm::Align(4));
274       ArgIndex->insertBefore(Call);
275 #endif
276       // Get pointer to argument data.
277       llvm::Value *ArgOut = Call->getArgOperand(1);
278       llvm::GetElementPtrInst *ArgIn =
279           llvm::GetElementPtrInst::Create(I64, ArgsIn, {ArgIndex});
280       ArgIn->insertAfter(ArgIndex);
281 
282       // Cast ArgOut pointer to i64*.
283       llvm::BitCastInst *ArgOutBC = new llvm::BitCastInst(ArgOut, I64Ptr);
284       ArgOutBC->insertAfter(ArgIn);
285       ArgOut = ArgOutBC;
286 
287       // Load argument.
288 #ifdef LLVM_OLDER_THAN_11_0
289       llvm::LoadInst *ArgValue = new llvm::LoadInst(ArgIn);
290       ArgValue->insertAfter(ArgIn);
291       llvm::StoreInst *ArgStore = new llvm::StoreInst(ArgValue, ArgOut);
292       ArgStore->insertAfter(ArgOutBC);
293 #else
294       llvm::LoadInst *ArgValue = new llvm::LoadInst(I64, ArgIn, "poclCudaArgLoad", false, llvm::Align(8));
295       ArgValue->insertAfter(ArgIn);
296       llvm::StoreInst *ArgStore = new llvm::StoreInst(ArgValue, ArgOut, false, llvm::Align(8));
297       ArgStore->insertAfter(ArgOutBC);
298 #endif
299       // Increment argument index.
300       llvm::BinaryOperator *Inc = llvm::BinaryOperator::Create(
301           llvm::BinaryOperator::Add, ArgIndex, llvm::ConstantInt::get(I32, 1));
302       Inc->insertAfter(ArgIndex);
303 
304 #ifdef LLVM_OLDER_THAN_11_0
305       llvm::StoreInst *StoreInc = new llvm::StoreInst(Inc, ArgIndexPtr);
306       StoreInc->insertAfter(Inc);
307 #else
308       llvm::StoreInst *StoreInc = new llvm::StoreInst(Inc, ArgIndexPtr, false, llvm::Align(4));
309       StoreInc->insertAfter(Inc);
310 #endif
311       // Remove call to _cl_va_arg.
312       Call->eraseFromParent();
313     }
314 
315     // Remove function from module.
316     VaArgFunc->eraseFromParent();
317   }
318 
319   // Loop over function callers.
320   // Generate array of i64 arguments to replace variadic arguments/
321   std::vector<llvm::Value *> Callers(OldPrintF->user_begin(),
322                                      OldPrintF->user_end());
323   for (auto &U : Callers) {
324     llvm::CallInst *Call = llvm::dyn_cast<llvm::CallInst>(U);
325     if (!Call)
326       continue;
327 
328     unsigned NumArgs = Call->getNumArgOperands() - 1;
329     llvm::Value *Format = Call->getArgOperand(0);
330 
331     // Allocate array for arguments.
332     // TODO: Deal with vector arguments.
333 #ifdef LLVM_OLDER_THAN_11_0
334     llvm::AllocaInst *Args =
335         new llvm::AllocaInst(I64, 0, llvm::ConstantInt::get(I32, NumArgs));
336     Args->insertBefore(Call);
337 #else
338     llvm::AllocaInst *Args =
339         new llvm::AllocaInst(I64, 0, llvm::ConstantInt::get(I32, NumArgs), llvm::Align(8));
340     Args->insertBefore(Call);
341 #endif
342 
343     // Loop over arguments (skipping format).
344     for (unsigned A = 0; A < NumArgs; A++) {
345       llvm::Value *Arg = Call->getArgOperand(A + 1);
346       llvm::Type *ArgType = Arg->getType();
347 
348       // Cast pointers to the generic address space.
349       if (ArgType->isPointerTy() && ArgType->getPointerAddressSpace() != 0) {
350         llvm::CastInst *AddrSpaceCast =llvm::CastInst::CreatePointerCast(
351           Arg, ArgType->getPointerElementType()->getPointerTo());
352         AddrSpaceCast->insertBefore(Call);
353         Arg = AddrSpaceCast;
354         ArgType = Arg->getType();
355       }
356 
357       // Get pointer to argument in i64 array.
358       // TODO: promote arguments that are shorter than 32 bits.
359       llvm::Constant *ArgIndex = llvm::ConstantInt::get(I32, A);
360       llvm::Instruction *ArgPtr =
361           llvm::GetElementPtrInst::Create(I64, Args, {ArgIndex});
362       ArgPtr->insertBefore(Call);
363 
364       // Cast pointer to correct type if necessary.
365       if (ArgPtr->getType()->getPointerElementType() != ArgType) {
366         llvm::BitCastInst *ArgPtrBC =
367             new llvm::BitCastInst(ArgPtr, ArgType->getPointerTo(0));
368         ArgPtrBC->insertAfter(ArgPtr);
369         ArgPtr = ArgPtrBC;
370       }
371 
372       // Store argument to i64 array.
373 #ifdef LLVM_OLDER_THAN_11_0
374       llvm::StoreInst *Store = new llvm::StoreInst(Arg, ArgPtr);
375       Store->insertBefore(Call);
376 #else
377       llvm::StoreInst *Store = new llvm::StoreInst(Arg, ArgPtr, false, llvm::Align(8));
378       Store->insertBefore(Call);
379 #endif
380     }
381 
382     // Fix address space of undef format values.
383     if (Format->getValueID() == llvm::Value::UndefValueVal) {
384       Format = llvm::UndefValue::get(FormatType);
385     }
386 
387     // Replace call with new non-variadic function.
388     llvm::CallInst *NewCall = llvm::CallInst::Create(NewPrintF, {Format, Args});
389     NewCall->insertBefore(Call);
390     Call->replaceAllUsesWith(NewCall);
391     Call->eraseFromParent();
392   }
393 
394   // Update arguments.
395   llvm::Function::arg_iterator OldArg = OldPrintF->arg_begin();
396   llvm::Function::arg_iterator NewArg = NewPrintF->arg_begin();
397   NewArg->takeName(&*OldArg);
398   OldArg->replaceAllUsesWith(&*NewArg);
399 
400   // Remove old function.
401   OldPrintF->eraseFromParent();
402 
403   // Get handle to vprintf function.
404   llvm::Function *VPrintF = Module->getFunction("vprintf");
405   if (!VPrintF)
406     return;
407 
408   // If vprintf format address space is already generic, then we're done.
409   auto *VPrintFFormatType = VPrintF->getFunctionType()->getParamType(0);
410   if (VPrintFFormatType->getPointerAddressSpace() == 0)
411     return;
412 
413   // Change address space of vprintf format argument to generic.
414   auto *I8Ptr = llvm::PointerType::get(llvm::Type::getInt8Ty(Context), 0);
415   auto *NewVPrintFType =
416       llvm::FunctionType::get(VPrintF->getReturnType(), {I8Ptr, I8Ptr}, false);
417   auto *NewVPrintF =
418       llvm::Function::Create(NewVPrintFType, VPrintF->getLinkage(), "", Module);
419   NewVPrintF->takeName(VPrintF);
420 
421   // Update vprintf callers to pass format arguments in generic address space.
422   Callers.assign(VPrintF->user_begin(), VPrintF->user_end());
423   for (auto &U : Callers) {
424     llvm::CallInst *Call = llvm::dyn_cast<llvm::CallInst>(U);
425     if (!Call)
426       continue;
427 
428     llvm::Value *Format = Call->getArgOperand(0);
429     llvm::Type *FormatType = Format->getType();
430     if (FormatType->getPointerAddressSpace() != 0) {
431       // Cast address space to generic.
432       llvm::Type *NewFormatType =
433           FormatType->getPointerElementType()->getPointerTo(0);
434       llvm::AddrSpaceCastInst *FormatASC =
435           new llvm::AddrSpaceCastInst(Format, NewFormatType);
436       FormatASC->insertBefore(Call);
437       Call->setArgOperand(0, FormatASC);
438     }
439     Call->setCalledFunction(NewVPrintF);
440   }
441 
442   VPrintF->eraseFromParent();
443 }
444 
445 // Replace all load users of a scalar global variable with new value.
446 static void replaceScalarGlobalVar(llvm::Module *Module, const char *Name,
447                                    llvm::Value *NewValue) {
448   auto GlobalVar = Module->getGlobalVariable(Name);
449   if (!GlobalVar)
450     return;
451 
452   std::vector<llvm::Value *> Users(GlobalVar->user_begin(),
453                                    GlobalVar->user_end());
454   for (auto *U : Users) {
455     auto Load = llvm::dyn_cast<llvm::LoadInst>(U);
456     assert(Load && "Use of a scalar global variable is not a load");
457     Load->replaceAllUsesWith(NewValue);
458     Load->eraseFromParent();
459   }
460   GlobalVar->eraseFromParent();
461 }
462 
463 // Add an extra kernel argument for the dimensionality.
464 void handleGetWorkDim(llvm::Module *Module, const char *KernelName) {
465   llvm::Function *Function = Module->getFunction(KernelName);
466   if (!Function)
467     POCL_ABORT("[CUDA] ptx-gen: kernel function not found in module\n");
468 
469   // Add additional argument for the work item dimensionality.
470   llvm::FunctionType *FunctionType = Function->getFunctionType();
471   std::vector<llvm::Type *> ArgumentTypes(FunctionType->param_begin(),
472                                           FunctionType->param_end());
473   ArgumentTypes.push_back(llvm::Type::getInt32Ty(Module->getContext()));
474 
475   // Create new function.
476   llvm::FunctionType *NewFunctionType =
477       llvm::FunctionType::get(Function->getReturnType(), ArgumentTypes, false);
478   llvm::Function *NewFunction = llvm::Function::Create(
479       NewFunctionType, Function->getLinkage(), Function->getName(), Module);
480   NewFunction->takeName(Function);
481 
482   // Map function arguments.
483   llvm::ValueToValueMapTy VV;
484   llvm::Function::arg_iterator OldArg;
485   llvm::Function::arg_iterator NewArg;
486   for (OldArg = Function->arg_begin(), NewArg = NewFunction->arg_begin();
487        OldArg != Function->arg_end(); NewArg++, OldArg++) {
488     NewArg->takeName(&*OldArg);
489     VV[&*OldArg] = &*NewArg;
490   }
491 
492   // Clone function.
493   llvm::SmallVector<llvm::ReturnInst *, 1> RI;
494   CloneFunctionIntoAbs(NewFunction, Function, VV, RI);
495 
496   Function->eraseFromParent();
497 
498   auto WorkDimVar = Module->getGlobalVariable("_work_dim");
499   if (!WorkDimVar)
500     return;
501 
502   // Replace uses of the global offset variables with the new arguments.
503   NewArg->setName("work_dim");
504   replaceScalarGlobalVar(Module, "_work_dim", (&*NewArg++));
505 
506   // TODO: What if get_work_dim() is called from a non-kernel function?
507 }
508 
509 int findLibDevice(char LibDevicePath[PATH_MAX], const char *Arch) {
510   // Extract numeric portion of SM version.
511   char *End;
512   unsigned long SM = strtoul(Arch + 3, &End, 10);
513   if (!SM || strlen(End)) {
514     POCL_MSG_ERR("[CUDA] invalid GPU architecture %s\n", Arch);
515     return 1;
516   }
517 
518   // This mapping from SM version to libdevice library version is given here:
519   // http://docs.nvidia.com/cuda/libdevice-users-guide/basic-usage.html#version-selection
520   // This is no longer needed as of CUDA 9.
521   int LibDeviceSM = 0;
522   if (SM < 30)
523     LibDeviceSM = 20;
524   else if (SM == 30)
525     LibDeviceSM = 30;
526   else if (SM < 35)
527     LibDeviceSM = 20;
528   else if (SM <= 37)
529     LibDeviceSM = 35;
530   else if (SM < 50)
531     LibDeviceSM = 30;
532   else if (SM <= 53)
533     LibDeviceSM = 50;
534   else
535     LibDeviceSM = 30;
536 
537   const char *BasePath[] = {
538     pocl_get_string_option("POCL_CUDA_TOOLKIT_PATH", CUDA_TOOLKIT_ROOT_DIR),
539     "/usr/local/lib/cuda",
540     "/usr/local/lib",
541     "/usr/lib",
542   };
543 
544   static const char *NVVMPath[] = {
545     "/nvvm",
546     "/nvidia-cuda-toolkit",
547     "",
548   };
549 
550   static const char *PathFormat = "%s%s/libdevice/libdevice.10.bc";
551   static const char *OldPathFormat =
552       "%s%s/libdevice/libdevice.compute_%d.10.bc";
553 
554   // Search combinations of paths for the libdevice library.
555   for (auto bp : BasePath) {
556     for (auto np : NVVMPath) {
557       // Check for CUDA 9+ libdevice library.
558       size_t ps = snprintf(LibDevicePath, PATH_MAX - 1, PathFormat, bp, np);
559       LibDevicePath[ps] = '\0';
560       POCL_MSG_PRINT2(CUDA, __FUNCTION__, __LINE__,
561                       "looking for libdevice at '%s'\n", LibDevicePath);
562       if (pocl_exists(LibDevicePath)) {
563         POCL_MSG_PRINT2(CUDA, __FUNCTION__, __LINE__,
564                         "found libdevice at '%s'\n", LibDevicePath);
565         return 0;
566       }
567 
568       // Check for pre CUDA 9 libdevice library.
569       ps = snprintf(LibDevicePath, PATH_MAX - 1, OldPathFormat, bp, np,
570                     LibDeviceSM);
571       LibDevicePath[ps] = '\0';
572       POCL_MSG_PRINT2(CUDA, __FUNCTION__, __LINE__,
573                       "looking for libdevice at '%s'\n", LibDevicePath);
574       if (pocl_exists(LibDevicePath)) {
575         POCL_MSG_PRINT2(CUDA, __FUNCTION__, __LINE__,
576                         "found libdevice at '%s'\n", LibDevicePath);
577         return 0;
578       }
579     }
580   }
581 
582   return 1;
583 }
584 
585 // Link CUDA's libdevice bitcode library to provide implementations for most of
586 // the OpenCL math functions.
587 // TODO: Can we link libdevice into the kernel library at pocl build time?
588 // This would remove this runtime dependency on the CUDA toolkit.
589 // Had some issues with the earlier pocl LLVM passes crashing on the libdevice
590 // code - needs more investigation.
591 void linkLibDevice(llvm::Module *Module, const char *KernelName,
592                    const char *LibDevicePath) {
593   auto Buffer = llvm::MemoryBuffer::getFile(LibDevicePath);
594   if (!Buffer)
595     POCL_ABORT("[CUDA] failed to open libdevice library file\n");
596 
597   POCL_MSG_PRINT_INFO("loading libdevice from '%s'\n", LibDevicePath);
598 
599   // Load libdevice bitcode library.
600   llvm::Expected<std::unique_ptr<llvm::Module>> LibDeviceModule =
601       parseBitcodeFile(Buffer->get()->getMemBufferRef(), Module->getContext());
602   if (!LibDeviceModule)
603     POCL_ABORT("[CUDA] failed to load libdevice bitcode\n");
604 
605   // Fix triple and data-layout of libdevice module.
606   (*LibDeviceModule)->setTargetTriple(Module->getTargetTriple());
607   (*LibDeviceModule)->setDataLayout(Module->getDataLayout());
608 
609   // Link libdevice into module.
610   llvm::Linker Linker(*Module);
611   if (Linker.linkInModule(std::move(LibDeviceModule.get()))) {
612     POCL_ABORT("[CUDA] failed to link to libdevice\n");
613   }
614 
615   llvm::legacy::PassManager Passes;
616 
617   // Run internalize to mark all non-kernel functions as internal.
618   auto PreserveKernel = [=](const llvm::GlobalValue &GV) {
619     return GV.getName() == KernelName;
620   };
621   Passes.add(llvm::createInternalizePass(PreserveKernel));
622 
623   // Add NVVM reflect module flags to set math options.
624   // TODO: Determine correct FTZ value from frontend compiler options.
625   llvm::LLVMContext &Context = Module->getContext();
626   llvm::Type *I32 = llvm::Type::getInt32Ty(Context);
627   llvm::Metadata *FourMD =
628       llvm::ValueAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
629   llvm::Metadata *NameMD = llvm::MDString::get(Context, "nvvm-reflect-ftz");
630   llvm::Metadata *OneMD =
631       llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
632   llvm::MDNode *ReflectFlag =
633       llvm::MDNode::get(Context, {FourMD, NameMD, OneMD});
634   Module->addModuleFlag(ReflectFlag);
635 
636   // Run optimization passes to clean up unused functions etc.
637   llvm::PassManagerBuilder Builder;
638   Builder.OptLevel = 3;
639   Builder.SizeLevel = 0;
640   Builder.populateModulePassManager(Passes);
641 
642   Passes.run(*Module);
643 }
644 
645 // This transformation replaces each pointer argument in the specific address
646 // space with an integer offset, and then inserts the necessary GEP+BitCast
647 // instructions to calculate the new pointers from the provided base global
648 // variable.
649 void convertPtrArgsToOffsets(llvm::Module *Module, const char *KernelName,
650                              unsigned AddrSpace, llvm::GlobalVariable *Base) {
651 
652   llvm::LLVMContext &Context = Module->getContext();
653 
654   llvm::Function *Function = Module->getFunction(KernelName);
655   if (!Function)
656     POCL_ABORT("[CUDA] ptx-gen: kernel function not found in module\n");
657 
658   // Argument info for creating new function.
659   std::vector<llvm::Argument *> Arguments;
660   std::vector<llvm::Type *> ArgumentTypes;
661 
662   llvm::ValueToValueMapTy VV;
663   std::vector<std::pair<llvm::Instruction *, llvm::Instruction *>> ToInsert;
664 
665   bool NeedsArgOffsets = false;
666   for (auto &Arg : Function->args()) {
667     // Check for local memory pointer.
668     llvm::Type *ArgType = Arg.getType();
669     if (ArgType->isPointerTy() &&
670         ArgType->getPointerAddressSpace() == AddrSpace) {
671       NeedsArgOffsets = true;
672 
673       // Create new argument for offset into shared memory allocation.
674       llvm::Type *I32ty = llvm::Type::getInt32Ty(Context);
675       llvm::Argument *Offset =
676           new llvm::Argument(I32ty, Arg.getName() + "_offset");
677       Arguments.push_back(Offset);
678       ArgumentTypes.push_back(I32ty);
679 
680       // Insert GEP to add offset.
681       llvm::Value *Zero = llvm::ConstantInt::getSigned(I32ty, 0);
682       llvm::GetElementPtrInst *GEP =
683           llvm::GetElementPtrInst::Create(Base->getType()->getPointerElementType(), Base, {Zero, Offset});
684 
685       // Cast pointer to correct type.
686       llvm::BitCastInst *Cast = new llvm::BitCastInst(GEP, ArgType);
687 
688       // Save these instructions to insert into new function later.
689       ToInsert.push_back({GEP, Cast});
690 
691       // Map the old local memory argument to the result of this cast.
692       VV[&Arg] = Cast;
693     } else {
694       // No change to other arguments.
695       Arguments.push_back(&Arg);
696       ArgumentTypes.push_back(ArgType);
697     }
698   }
699 
700   if (!NeedsArgOffsets)
701     return;
702 
703   // Create new function with offsets instead of local memory pointers.
704   llvm::FunctionType *NewFunctionType =
705       llvm::FunctionType::get(Function->getReturnType(), ArgumentTypes, false);
706   llvm::Function *NewFunction = llvm::Function::Create(
707       NewFunctionType, Function->getLinkage(), Function->getName(), Module);
708   NewFunction->takeName(Function);
709 
710   // Map function arguments.
711   std::vector<llvm::Argument *>::iterator OldArg;
712   llvm::Function::arg_iterator NewArg;
713   for (OldArg = Arguments.begin(), NewArg = NewFunction->arg_begin();
714        NewArg != NewFunction->arg_end(); NewArg++, OldArg++) {
715     NewArg->takeName(*OldArg);
716     if ((*OldArg)->getParent())
717       VV[*OldArg] = &*NewArg;
718     else {
719       // Manually replace new offset arguments.
720       (*OldArg)->replaceAllUsesWith(&*NewArg);
721       delete *OldArg;
722     }
723   }
724 
725   // Clone function.
726   llvm::SmallVector<llvm::ReturnInst *, 1> RI;
727   CloneFunctionIntoAbs(NewFunction, Function, VV, RI);
728 
729   // Insert offset instructions into new function.
730   for (auto Pair : ToInsert) {
731     Pair.first->insertBefore(&*NewFunction->begin()->begin());
732     Pair.second->insertAfter(Pair.first);
733   }
734 
735   Function->eraseFromParent();
736 }
737 
738 // CUDA doesn't allow constant pointer arguments, so we have to convert them to
739 // offsets and manually add them to a global variable base pointer.
740 void fixConstantMemArgs(llvm::Module *Module, const char *KernelName) {
741 
742   // Calculate total size of automatic constant allocations.
743   size_t TotalAutoConstantSize = 0;
744   for (auto &GlobalVar : Module->globals()) {
745     if (GlobalVar.getType()->getPointerAddressSpace() == 4)
746       TotalAutoConstantSize += Module->getDataLayout().getTypeAllocSize(
747           GlobalVar.getInitializer()->getType());
748   }
749 
750   // Create global variable for constant memory allocations.
751   // TODO: Does allocating the maximum amount have a penalty?
752   llvm::Type *ByteArrayType =
753       llvm::ArrayType::get(llvm::Type::getInt8Ty(Module->getContext()),
754                            65536 - TotalAutoConstantSize);
755   llvm::GlobalVariable *ConstantMemBase = new llvm::GlobalVariable(
756       *Module, ByteArrayType, false, llvm::GlobalValue::InternalLinkage,
757       llvm::Constant::getNullValue(ByteArrayType), "_constant_memory_region_",
758       NULL, llvm::GlobalValue::NotThreadLocal, 4, false);
759 
760   convertPtrArgsToOffsets(Module, KernelName, 4, ConstantMemBase);
761 }
762 
763 // CUDA doesn't allow multiple local memory arguments or automatic variables, so
764 // we have to create a single global variable for local memory allocations, and
765 // then manually add offsets to it to get each individual local memory
766 // allocation.
767 void fixLocalMemArgs(llvm::Module *Module, const char *KernelName) {
768 
769   // Create global variable for local memory allocations.
770   llvm::Type *ByteArrayType =
771       llvm::ArrayType::get(llvm::Type::getInt8Ty(Module->getContext()), 0);
772   llvm::GlobalVariable *SharedMemBase = new llvm::GlobalVariable(
773       *Module, ByteArrayType, false, llvm::GlobalValue::ExternalLinkage, NULL,
774       "_shared_memory_region_", NULL, llvm::GlobalValue::NotThreadLocal, 3,
775       false);
776 
777   convertPtrArgsToOffsets(Module, KernelName, 3, SharedMemBase);
778 }
779 
780 // Map kernel math functions onto the corresponding CUDA libdevice functions.
781 void mapLibDeviceCalls(llvm::Module *Module) {
782   struct FunctionMapEntry {
783     const char *OCLFunctionName;
784     const char *LibDeviceFunctionName;
785   };
786   struct FunctionMapEntry FunctionMap[] = {
787 
788 // clang-format off
789 #define LDMAP(name) \
790   {        name "f",    "__nv_" name "f"}, \
791   {        name,        "__nv_" name}, \
792   {"llvm." name ".f32", "__nv_" name "f"}, \
793   {"llvm." name ".f64", "__nv_" name},
794 
795     LDMAP("acos")
796     LDMAP("acosh")
797     LDMAP("asin")
798     LDMAP("asinh")
799     LDMAP("atan")
800     LDMAP("atanh")
801     LDMAP("atan2")
802     LDMAP("cbrt")
803     LDMAP("ceil")
804     LDMAP("copysign")
805     LDMAP("cos")
806     LDMAP("cosh")
807     LDMAP("exp")
808     LDMAP("exp2")
809     LDMAP("expm1")
810     LDMAP("fdim")
811     LDMAP("floor")
812     LDMAP("fmax")
813     LDMAP("fmin")
814     LDMAP("hypot")
815     LDMAP("ilogb")
816     LDMAP("lgamma")
817     LDMAP("log")
818     LDMAP("log2")
819     LDMAP("log10")
820     LDMAP("log1p")
821     LDMAP("logb")
822     LDMAP("nextafter")
823     LDMAP("remainder")
824     LDMAP("rint")
825     LDMAP("round")
826     LDMAP("sin")
827     LDMAP("sinh")
828     LDMAP("sqrt")
829     LDMAP("tan")
830     LDMAP("tanh")
831     LDMAP("trunc")
832 #undef LDMAP
833 
834     {"llvm.copysign.f32", "__nv_copysignf"},
835     {"llvm.copysign.f64", "__nv_copysign"},
836 
837     {"llvm.pow.f32", "__nv_powf"},
838     {"llvm.pow.f64", "__nv_pow"},
839 
840     {"llvm.powi.f32", "__nv_powif"},
841     {"llvm.powi.f64", "__nv_powi"},
842 
843     {"frexp", "__nv_frexp"},
844     {"frexpf", "__nv_frexpf"},
845 
846     {"tgamma", "__nv_tgamma"},
847     {"tgammaf", "__nv_tgammaf"},
848 
849     {"ldexp", "__nv_ldexp"},
850     {"ldexpf", "__nv_ldexpf"},
851 
852     {"modf", "__nv_modf"},
853     {"modff", "__nv_modff"},
854 
855     {"remquo", "__nv_remquo"},
856     {"remquof", "__nv_remquof"},
857     // TODO: lgamma_r
858     // TODO: rootn
859   };
860   // clang-format on
861 
862   for (auto &Entry : FunctionMap) {
863     llvm::Function *Function = Module->getFunction(Entry.OCLFunctionName);
864     if (!Function)
865       continue;
866 
867     std::vector<llvm::Value *> Users(Function->user_begin(),
868                                      Function->user_end());
869     for (auto &U : Users) {
870       // Look for calls to function.
871       llvm::CallInst *Call = llvm::dyn_cast<llvm::CallInst>(U);
872       if (Call) {
873         // Create function declaration for libdevice version.
874         llvm::FunctionType *FunctionType = Function->getFunctionType();
875 #ifdef LLVM_OLDER_THAN_9_0
876         llvm::Constant *LibDeviceFunction = Module->getOrInsertFunction(
877             Entry.LibDeviceFunctionName, FunctionType);
878 #else
879         llvm::FunctionCallee FC = Module->getOrInsertFunction(
880             Entry.LibDeviceFunctionName, FunctionType);
881         llvm::Function *LibDeviceFunction = llvm::cast<llvm::Function>(FC.getCallee());
882 #endif
883         // Replace function with libdevice version.
884         std::vector<llvm::Value *> Args(Call->arg_begin(), Call->arg_end());
885         llvm::CallInst *NewCall =
886             llvm::CallInst::Create(LibDeviceFunction, Args, "", Call);
887         NewCall->takeName(Call);
888         Call->replaceAllUsesWith(NewCall);
889         Call->eraseFromParent();
890       }
891     }
892 
893     Function->eraseFromParent();
894   }
895 }
896 
897 int pocl_cuda_get_ptr_arg_alignment(const char *BitcodeFilename,
898                                     const char *KernelName,
899                                     size_t *Alignments) {
900   // Create buffer for bitcode file.
901   llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> Buffer =
902       llvm::MemoryBuffer::getFile(BitcodeFilename);
903   if (!Buffer) {
904     POCL_MSG_ERR("[CUDA] ptx-gen: failed to open bitcode file\n");
905     return 1;
906   }
907 
908   // Load the LLVM bitcode module.
909   llvm::LLVMContext Context;
910   llvm::Expected<std::unique_ptr<llvm::Module>> Module =
911       parseBitcodeFile(Buffer->get()->getMemBufferRef(), Context);
912   if (!Module) {
913     POCL_MSG_ERR("[CUDA] ptx-gen: failed to load bitcode\n");
914     return 1;
915   }
916 
917   // Get kernel function.
918   llvm::Function *Kernel = (*Module)->getFunction(KernelName);
919   if (!Kernel)
920     POCL_ABORT("[CUDA] kernel function not found in module\n");
921 
922   // Calculate alignment for each argument.
923   const llvm::DataLayout &DL = (*Module)->getDataLayout();
924   for (auto &Arg : Kernel->args()) {
925     unsigned i = Arg.getArgNo();
926     llvm::Type *Type = Arg.getType();
927     if (!Type->isPointerTy())
928       Alignments[i] = 0;
929     else {
930       llvm::Type *ElemType = Type->getPointerElementType();
931       Alignments[i] = DL.getTypeAllocSize(ElemType);
932     }
933   }
934 
935   return 0;
936 }
937