1 //===- SPIRVUtil.cpp - SPIR-V Utilities -------------------------*- C++ -*-===//
2 //
3 //                     The LLVM/SPIRV Translator
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 // Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved.
9 //
10 // Permission is hereby granted, free of charge, to any person obtaining a
11 // copy of this software and associated documentation files (the "Software"),
12 // to deal with the Software without restriction, including without limitation
13 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
14 // and/or sell copies of the Software, and to permit persons to whom the
15 // Software is furnished to do so, subject to the following conditions:
16 //
17 // Redistributions of source code must retain the above copyright notice,
18 // this list of conditions and the following disclaimers.
19 // Redistributions in binary form must reproduce the above copyright notice,
20 // this list of conditions and the following disclaimers in the documentation
21 // and/or other materials provided with the distribution.
22 // Neither the names of Advanced Micro Devices, Inc., nor the names of its
23 // contributors may be used to endorse or promote products derived from this
24 // Software without specific prior written permission.
25 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28 // CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH
31 // THE SOFTWARE.
32 //
33 //===----------------------------------------------------------------------===//
34 /// \file
35 ///
36 /// This file defines utility classes and functions shared by SPIR-V
37 /// reader/writer.
38 ///
39 //===----------------------------------------------------------------------===//
40 
41 #include "FunctionDescriptor.h"
42 #include "ManglingUtils.h"
43 #include "NameMangleAPI.h"
44 #include "OCLUtil.h"
45 #include "ParameterType.h"
46 #include "SPIRVInternal.h"
47 #include "SPIRVMDWalker.h"
48 #include "libSPIRV/SPIRVDecorate.h"
49 #include "libSPIRV/SPIRVValue.h"
50 
51 #include "llvm/ADT/StringSwitch.h"
52 #include "llvm/Bitcode/BitcodeWriter.h"
53 #include "llvm/IR/IRBuilder.h"
54 #include "llvm/IR/IntrinsicInst.h"
55 #include "llvm/Support/CommandLine.h"
56 #include "llvm/Support/Debug.h"
57 #include "llvm/Support/ErrorHandling.h"
58 #include "llvm/Support/FileSystem.h"
59 #include "llvm/Support/ToolOutputFile.h"
60 
61 #include <functional>
62 #include <sstream>
63 
64 #define DEBUG_TYPE "spirv"
65 
66 namespace SPIRV {
67 
#ifdef _SPIRV_SUPPORT_TEXT_FMT
// Command-line flag "spirv-text", only compiled in when
// _SPIRV_SUPPORT_TEXT_FMT is defined. The option uses external storage: the
// parsed value is written to SPIRVUseTextFormat through cl::location.
cl::opt<bool, true>
    UseTextFormat("spirv-text",
                  cl::desc("Use text format for SPIR-V for debugging purpose"),
                  cl::location(SPIRVUseTextFormat));
#endif
74 
#ifdef _SPIRVDBG
// Command-line flag "spirv-debug", only compiled in when _SPIRVDBG is
// defined. The option uses external storage: the parsed value is written to
// SPIRVDbgEnable through cl::location.
cl::opt<bool, true> EnableDbgOutput("spirv-debug",
                                    cl::desc("Enable SPIR-V debug output"),
                                    cl::location(SPIRVDbgEnable));
#endif
80 
isSupportedTriple(Triple T)81 bool isSupportedTriple(Triple T) { return T.isSPIR(); }
82 
addFnAttr(CallInst * Call,Attribute::AttrKind Attr)83 void addFnAttr(CallInst *Call, Attribute::AttrKind Attr) {
84   Call->addAttribute(AttributeList::FunctionIndex, Attr);
85 }
86 
removeFnAttr(CallInst * Call,Attribute::AttrKind Attr)87 void removeFnAttr(CallInst *Call, Attribute::AttrKind Attr) {
88   Call->removeAttribute(AttributeList::FunctionIndex, Attr);
89 }
90 
removeCast(Value * V)91 Value *removeCast(Value *V) {
92   auto Cast = dyn_cast<ConstantExpr>(V);
93   if (Cast && Cast->isCast()) {
94     return removeCast(Cast->getOperand(0));
95   }
96   if (auto Cast = dyn_cast<CastInst>(V))
97     return removeCast(Cast->getOperand(0));
98   return V;
99 }
100 
saveLLVMModule(Module * M,const std::string & OutputFile)101 void saveLLVMModule(Module *M, const std::string &OutputFile) {
102   std::error_code EC;
103   ToolOutputFile Out(OutputFile.c_str(), EC, sys::fs::OF_None);
104   if (EC) {
105     SPIRVDBG(errs() << "Fails to open output file: " << EC.message();)
106     return;
107   }
108 
109   WriteBitcodeToFile(*M, Out.os());
110   Out.keep();
111 }
112 
mapLLVMTypeToOCLType(const Type * Ty,bool Signed)113 std::string mapLLVMTypeToOCLType(const Type *Ty, bool Signed) {
114   if (Ty->isHalfTy())
115     return "half";
116   if (Ty->isFloatTy())
117     return "float";
118   if (Ty->isDoubleTy())
119     return "double";
120   if (auto IntTy = dyn_cast<IntegerType>(Ty)) {
121     std::string SignPrefix;
122     std::string Stem;
123     if (!Signed)
124       SignPrefix = "u";
125     switch (IntTy->getIntegerBitWidth()) {
126     case 8:
127       Stem = "char";
128       break;
129     case 16:
130       Stem = "short";
131       break;
132     case 32:
133       Stem = "int";
134       break;
135     case 64:
136       Stem = "long";
137       break;
138     default:
139       Stem = "invalid_type";
140       break;
141     }
142     return SignPrefix + Stem;
143   }
144   if (auto VecTy = dyn_cast<FixedVectorType>(Ty)) {
145     Type *EleTy = VecTy->getElementType();
146     unsigned Size = VecTy->getNumElements();
147     std::stringstream Ss;
148     Ss << mapLLVMTypeToOCLType(EleTy, Signed) << Size;
149     return Ss.str();
150   }
151   return "invalid_type";
152 }
153 
mapSPIRVTypeToOCLType(SPIRVType * Ty,bool Signed)154 std::string mapSPIRVTypeToOCLType(SPIRVType *Ty, bool Signed) {
155   if (Ty->isTypeFloat()) {
156     auto W = Ty->getBitWidth();
157     switch (W) {
158     case 16:
159       return "half";
160     case 32:
161       return "float";
162     case 64:
163       return "double";
164     default:
165       assert(0 && "Invalid floating pointer type");
166       return std::string("float") + W + "_t";
167     }
168   }
169   if (Ty->isTypeInt()) {
170     std::string SignPrefix;
171     std::string Stem;
172     if (!Signed)
173       SignPrefix = "u";
174     auto W = Ty->getBitWidth();
175     switch (W) {
176     case 8:
177       Stem = "char";
178       break;
179     case 16:
180       Stem = "short";
181       break;
182     case 32:
183       Stem = "int";
184       break;
185     case 64:
186       Stem = "long";
187       break;
188     default:
189       llvm_unreachable("Invalid integer type");
190       Stem = std::string("int") + W + "_t";
191       break;
192     }
193     return SignPrefix + Stem;
194   }
195   if (Ty->isTypeVector()) {
196     auto EleTy = Ty->getVectorComponentType();
197     auto Size = Ty->getVectorComponentCount();
198     std::stringstream Ss;
199     Ss << mapSPIRVTypeToOCLType(EleTy, Signed) << Size;
200     return Ss.str();
201   }
202   llvm_unreachable("Invalid type");
203   return "unknown_type";
204 }
205 
getOrCreateOpaquePtrType(Module * M,const std::string & Name,unsigned AddrSpace)206 PointerType *getOrCreateOpaquePtrType(Module *M, const std::string &Name,
207                                       unsigned AddrSpace) {
208   auto OpaqueType = StructType::getTypeByName(M->getContext(), Name);
209   if (!OpaqueType)
210     OpaqueType = StructType::create(M->getContext(), Name);
211   return PointerType::get(OpaqueType, AddrSpace);
212 }
213 
getSamplerType(Module * M)214 PointerType *getSamplerType(Module *M) {
215   return getOrCreateOpaquePtrType(M, getSPIRVTypeName(kSPIRVTypeName::Sampler),
216                                   SPIRAS_Constant);
217 }
218 
getPipeStorageType(Module * M)219 PointerType *getPipeStorageType(Module *M) {
220   return getOrCreateOpaquePtrType(
221       M, getSPIRVTypeName(kSPIRVTypeName::PipeStorage), SPIRAS_Constant);
222 }
223 
getSPIRVOpaquePtrType(Module * M,Op OC)224 PointerType *getSPIRVOpaquePtrType(Module *M, Op OC) {
225   std::string Name = getSPIRVTypeName(SPIRVOpaqueTypeOpCodeMap::rmap(OC));
226   return getOrCreateOpaquePtrType(M, Name, getOCLOpaqueTypeAddrSpace(OC));
227 }
228 
getFunctionTypeParameterTypes(llvm::FunctionType * FT,std::vector<Type * > & ArgTys)229 void getFunctionTypeParameterTypes(llvm::FunctionType *FT,
230                                    std::vector<Type *> &ArgTys) {
231   for (auto I = FT->param_begin(), E = FT->param_end(); I != E; ++I) {
232     ArgTys.push_back(*I);
233   }
234 }
235 
isVoidFuncTy(FunctionType * FT)236 bool isVoidFuncTy(FunctionType *FT) {
237   return FT->getReturnType()->isVoidTy() && FT->getNumParams() == 0;
238 }
239 
isPointerToOpaqueStructType(llvm::Type * Ty)240 bool isPointerToOpaqueStructType(llvm::Type *Ty) {
241   if (auto PT = dyn_cast<PointerType>(Ty))
242     if (auto ST = dyn_cast<StructType>(PT->getElementType()))
243       if (ST->isOpaque())
244         return true;
245   return false;
246 }
247 
isPointerToOpaqueStructType(llvm::Type * Ty,const std::string & Name)248 bool isPointerToOpaqueStructType(llvm::Type *Ty, const std::string &Name) {
249   if (auto PT = dyn_cast<PointerType>(Ty))
250     if (auto ST = dyn_cast<StructType>(PT->getElementType()))
251       if (ST->isOpaque() && ST->getName() == Name)
252         return true;
253   return false;
254 }
255 
isOCLImageType(llvm::Type * Ty,StringRef * Name)256 bool isOCLImageType(llvm::Type *Ty, StringRef *Name) {
257   if (auto PT = dyn_cast<PointerType>(Ty))
258     if (auto ST = dyn_cast<StructType>(PT->getElementType()))
259       if (ST->isOpaque()) {
260         auto FullName = ST->getName();
261         if (FullName.find(kSPR2TypeName::ImagePrefix) == 0) {
262           if (Name)
263             *Name = FullName.drop_front(strlen(kSPR2TypeName::OCLPrefix));
264           return true;
265         }
266       }
267   return false;
268 }
269 
270 /// \param BaseTyName is the type Name as in spirv.BaseTyName.Postfixes
271 /// \param Postfix contains postfixes extracted from the SPIR-V image
272 ///   type Name as spirv.BaseTyName.Postfixes.
isSPIRVType(llvm::Type * Ty,StringRef BaseTyName,StringRef * Postfix)273 bool isSPIRVType(llvm::Type *Ty, StringRef BaseTyName, StringRef *Postfix) {
274   if (auto PT = dyn_cast<PointerType>(Ty))
275     if (auto ST = dyn_cast<StructType>(PT->getElementType()))
276       if (ST->isOpaque()) {
277         auto FullName = ST->getName();
278         std::string N =
279             std::string(kSPIRVTypeName::PrefixAndDelim) + BaseTyName.str();
280         if (FullName != N)
281           N = N + kSPIRVTypeName::Delimiter;
282         if (FullName.startswith(N)) {
283           if (Postfix)
284             *Postfix = FullName.drop_front(N.size());
285           return true;
286         }
287       }
288   return false;
289 }
290 
getOrCreateFunction(Module * M,Type * RetTy,ArrayRef<Type * > ArgTypes,StringRef Name,BuiltinFuncMangleInfo * Mangle,AttributeList * Attrs,bool TakeName)291 Function *getOrCreateFunction(Module *M, Type *RetTy, ArrayRef<Type *> ArgTypes,
292                               StringRef Name, BuiltinFuncMangleInfo *Mangle,
293                               AttributeList *Attrs, bool TakeName) {
294   std::string MangledName{Name};
295   bool IsVarArg = false;
296   if (Mangle) {
297     MangledName = mangleBuiltin(Name, ArgTypes, Mangle);
298     IsVarArg = 0 <= Mangle->getVarArg();
299     if (IsVarArg)
300       ArgTypes = ArgTypes.slice(0, Mangle->getVarArg());
301   }
302   FunctionType *FT = FunctionType::get(RetTy, ArgTypes, IsVarArg);
303   Function *F = M->getFunction(MangledName);
304   if (!TakeName && F && F->getFunctionType() != FT && Mangle != nullptr) {
305     std::string S;
306     raw_string_ostream SS(S);
307     SS << "Error: Attempt to redefine function: " << *F << " => " << *FT
308        << '\n';
309     report_fatal_error(SS.str(), false);
310   }
311   if (!F || F->getFunctionType() != FT) {
312     auto NewF =
313         Function::Create(FT, GlobalValue::ExternalLinkage, MangledName, M);
314     if (F && TakeName) {
315       NewF->takeName(F);
316       LLVM_DEBUG(
317           dbgs() << "[getOrCreateFunction] Warning: taking function Name\n");
318     }
319     if (NewF->getName() != MangledName) {
320       LLVM_DEBUG(
321           dbgs() << "[getOrCreateFunction] Warning: function Name changed\n");
322     }
323     LLVM_DEBUG(dbgs() << "[getOrCreateFunction] ";
324                if (F) dbgs() << *F << " => "; dbgs() << *NewF << '\n';);
325     F = NewF;
326     F->setCallingConv(CallingConv::SPIR_FUNC);
327     if (Attrs)
328       F->setAttributes(*Attrs);
329   }
330   return F;
331 }
332 
getArguments(CallInst * CI,unsigned Start,unsigned End)333 std::vector<Value *> getArguments(CallInst *CI, unsigned Start, unsigned End) {
334   std::vector<Value *> Args;
335   if (End == 0)
336     End = CI->getNumArgOperands();
337   for (; Start != End; ++Start) {
338     Args.push_back(CI->getArgOperand(Start));
339   }
340   return Args;
341 }
342 
getArgAsInt(CallInst * CI,unsigned I)343 uint64_t getArgAsInt(CallInst *CI, unsigned I) {
344   return cast<ConstantInt>(CI->getArgOperand(I))->getZExtValue();
345 }
346 
getArgAsScope(CallInst * CI,unsigned I)347 Scope getArgAsScope(CallInst *CI, unsigned I) {
348   return static_cast<Scope>(getArgAsInt(CI, I));
349 }
350 
getArgAsDecoration(CallInst * CI,unsigned I)351 Decoration getArgAsDecoration(CallInst *CI, unsigned I) {
352   return static_cast<Decoration>(getArgAsInt(CI, I));
353 }
354 
decorateSPIRVFunction(const std::string & S)355 std::string decorateSPIRVFunction(const std::string &S) {
356   return std::string(kSPIRVName::Prefix) + S + kSPIRVName::Postfix;
357 }
358 
undecorateSPIRVFunction(StringRef S)359 StringRef undecorateSPIRVFunction(StringRef S) {
360   assert(S.find(kSPIRVName::Prefix) == 0);
361   const size_t Start = strlen(kSPIRVName::Prefix);
362   auto End = S.rfind(kSPIRVName::Postfix);
363   return S.substr(Start, End - Start);
364 }
365 
prefixSPIRVName(const std::string & S)366 std::string prefixSPIRVName(const std::string &S) {
367   return std::string(kSPIRVName::Prefix) + S;
368 }
369 
dePrefixSPIRVName(StringRef R,SmallVectorImpl<StringRef> & Postfix)370 StringRef dePrefixSPIRVName(StringRef R, SmallVectorImpl<StringRef> &Postfix) {
371   const size_t Start = strlen(kSPIRVName::Prefix);
372   if (!R.startswith(kSPIRVName::Prefix))
373     return R;
374   R = R.drop_front(Start);
375   R.split(Postfix, "_", -1, false);
376   auto Name = Postfix.front();
377   Postfix.erase(Postfix.begin());
378   return Name;
379 }
380 
getSPIRVFuncName(Op OC,StringRef PostFix)381 std::string getSPIRVFuncName(Op OC, StringRef PostFix) {
382   return prefixSPIRVName(getName(OC) + PostFix.str());
383 }
384 
getSPIRVFuncName(Op OC,const Type * PRetTy,bool IsSigned)385 std::string getSPIRVFuncName(Op OC, const Type *PRetTy, bool IsSigned) {
386   return prefixSPIRVName(getName(OC) + kSPIRVPostfix::Divider +
387                          getPostfixForReturnType(PRetTy, IsSigned));
388 }
389 
getSPIRVFuncName(SPIRVBuiltinVariableKind BVKind)390 std::string getSPIRVFuncName(SPIRVBuiltinVariableKind BVKind) {
391   return prefixSPIRVName(getName(BVKind));
392 }
393 
getSPIRVExtFuncName(SPIRVExtInstSetKind Set,unsigned ExtOp,StringRef PostFix)394 std::string getSPIRVExtFuncName(SPIRVExtInstSetKind Set, unsigned ExtOp,
395                                 StringRef PostFix) {
396   std::string ExtOpName;
397   switch (Set) {
398   default:
399     llvm_unreachable("invalid extended instruction set");
400     ExtOpName = "unknown";
401     break;
402   case SPIRVEIS_OpenCL:
403     ExtOpName = getName(static_cast<OCLExtOpKind>(ExtOp));
404     break;
405   }
406   return prefixSPIRVName(SPIRVExtSetShortNameMap::map(Set) + '_' + ExtOpName +
407                          PostFix.str());
408 }
409 
mapPostfixToDecorate(StringRef Postfix,SPIRVEntry * Target)410 SPIRVDecorate *mapPostfixToDecorate(StringRef Postfix, SPIRVEntry *Target) {
411   if (Postfix == kSPIRVPostfix::Sat)
412     return new SPIRVDecorate(spv::DecorationSaturatedConversion, Target);
413 
414   if (Postfix.startswith(kSPIRVPostfix::Rt))
415     return new SPIRVDecorate(spv::DecorationFPRoundingMode, Target,
416                              map<SPIRVFPRoundingModeKind>(Postfix.str()));
417 
418   return nullptr;
419 }
420 
addDecorations(SPIRVValue * Target,const SmallVectorImpl<std::string> & Decs)421 SPIRVValue *addDecorations(SPIRVValue *Target,
422                            const SmallVectorImpl<std::string> &Decs) {
423   for (auto &I : Decs)
424     if (auto Dec = mapPostfixToDecorate(I, Target))
425       Target->addDecorate(Dec);
426   return Target;
427 }
428 
getPostfix(Decoration Dec,unsigned Value)429 std::string getPostfix(Decoration Dec, unsigned Value) {
430   switch (Dec) {
431   default:
432     llvm_unreachable("not implemented");
433     return "unknown";
434   case spv::DecorationSaturatedConversion:
435     return kSPIRVPostfix::Sat;
436   case spv::DecorationFPRoundingMode:
437     return rmap<std::string>(static_cast<SPIRVFPRoundingModeKind>(Value));
438   }
439 }
440 
getPostfixForReturnType(CallInst * CI,bool IsSigned)441 std::string getPostfixForReturnType(CallInst *CI, bool IsSigned) {
442   return getPostfixForReturnType(CI->getType(), IsSigned);
443 }
444 
getPostfixForReturnType(const Type * PRetTy,bool IsSigned)445 std::string getPostfixForReturnType(const Type *PRetTy, bool IsSigned) {
446   return std::string(kSPIRVPostfix::Return) +
447          mapLLVMTypeToOCLType(PRetTy, IsSigned);
448 }
449 
450 // Enqueue kernel, kernel query, pipe and address space cast built-ins
451 // are not mangled.
isNonMangledOCLBuiltin(StringRef Name)452 bool isNonMangledOCLBuiltin(StringRef Name) {
453   if (!Name.startswith("__"))
454     return false;
455 
456   return isEnqueueKernelBI(Name) || isKernelQueryBI(Name) ||
457          isPipeOrAddressSpaceCastBI(Name.drop_front(2));
458 }
459 
getSPIRVFuncOC(StringRef S,SmallVectorImpl<std::string> * Dec)460 Op getSPIRVFuncOC(StringRef S, SmallVectorImpl<std::string> *Dec) {
461   Op OC;
462   SmallVector<StringRef, 2> Postfix;
463   StringRef Name;
464   if (!oclIsBuiltin(S, Name))
465     Name = S;
466   StringRef R(Name);
467   if ((!Name.startswith(kSPIRVName::Prefix) && !isNonMangledOCLBuiltin(S)) ||
468       !getByName(dePrefixSPIRVName(R, Postfix).str(), OC)) {
469     return OpNop;
470   }
471   if (Dec)
472     for (auto &I : Postfix)
473       Dec->push_back(I.str());
474   return OC;
475 }
476 
getSPIRVBuiltin(const std::string & OrigName,spv::BuiltIn & B)477 bool getSPIRVBuiltin(const std::string &OrigName, spv::BuiltIn &B) {
478   SmallVector<StringRef, 2> Postfix;
479   StringRef R(OrigName);
480   R = dePrefixSPIRVName(R, Postfix);
481   if (!Postfix.empty())
482     return false;
483   return getByName(R.str(), B);
484 }
485 
486 // Demangled name is a substring of the name. The DemangledName is updated only
487 // if true is returned
oclIsBuiltin(StringRef Name,StringRef & DemangledName,bool IsCpp)488 bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp) {
489   if (Name == "printf") {
490     DemangledName = Name;
491     return true;
492   }
493   if (isNonMangledOCLBuiltin(Name)) {
494     DemangledName = Name.drop_front(2);
495     return true;
496   }
497   if (!Name.startswith("_Z"))
498     return false;
499   // OpenCL C++ built-ins are declared in cl namespace.
500   // TODO: consider using 'St' abbriviation for cl namespace mangling.
501   // Similar to ::std:: in C++.
502   if (IsCpp) {
503     if (!Name.startswith("_ZN"))
504       return false;
505     // Skip CV and ref qualifiers.
506     size_t NameSpaceStart = Name.find_first_not_of("rVKRO", 3);
507     // All built-ins are in the ::cl:: namespace.
508     if (Name.substr(NameSpaceStart, 11) != "2cl7__spirv")
509       return false;
510     size_t DemangledNameLenStart = NameSpaceStart + 11;
511     size_t Start = Name.find_first_not_of("0123456789", DemangledNameLenStart);
512     size_t Len = 0;
513     Name.substr(DemangledNameLenStart, Start - DemangledNameLenStart)
514         .getAsInteger(10, Len);
515     DemangledName = Name.substr(Start, Len);
516   } else {
517     size_t Start = Name.find_first_not_of("0123456789", 2);
518     size_t Len = 0;
519     Name.substr(2, Start - 2).getAsInteger(10, Len);
520     DemangledName = Name.substr(Start, Len);
521   }
522   return true;
523 }
524 
// Check if a mangled type Name is unsigned
bool isMangledTypeUnsigned(char Mangled) {
  switch (Mangled) {
  case 'h': /* uchar */
  case 't': /* ushort */
  case 'j': /* uint */
  case 'm': /* ulong */
    return true;
  default:
    return false;
  }
}
532 
// Check if a mangled type Name is signed
bool isMangledTypeSigned(char Mangled) {
  switch (Mangled) {
  case 'c': /* char */
  case 'a': /* signed char */
  case 's': /* short */
  case 'i': /* int */
  case 'l': /* long */
    return true;
  default:
    return false;
  }
}
541 
// Check if a mangled type Name is floating point (excludes half)
bool isMangledTypeFP(char Mangled) {
  switch (Mangled) {
  case 'f': /* float */
  case 'd': /* double */
    return true;
  default:
    return false;
  }
}
547 
// Check if a mangled type Name is half
bool isMangledTypeHalf(std::string Mangled) {
  const bool IsHalf = Mangled.compare("Dh") == 0; /* half */
  return IsHalf;
}
552 
/// Removes every trailing "S_" substitution marker from \p MangledName in
/// place. Markers that are not at the end of the name are left untouched.
void eraseSubstitutionFromMangledName(std::string &MangledName) {
  for (;;) {
    const auto Size = MangledName.size();
    if (Size < 2)
      break;
    if (MangledName.compare(Size - 2, 2, "S_") != 0)
      break;
    MangledName.resize(Size - 2);
  }
}
560 
lastFuncParamType(StringRef MangledName)561 ParamType lastFuncParamType(StringRef MangledName) {
562   std::string Copy(MangledName);
563   eraseSubstitutionFromMangledName(Copy);
564   char Mangled = Copy.back();
565   std::string Mangled2 = Copy.substr(Copy.size() - 2);
566 
567   if (isMangledTypeFP(Mangled) || isMangledTypeHalf(Mangled2)) {
568     return ParamType::FLOAT;
569   } else if (isMangledTypeUnsigned(Mangled)) {
570     return ParamType::UNSIGNED;
571   } else if (isMangledTypeSigned(Mangled)) {
572     return ParamType::SIGNED;
573   }
574 
575   return ParamType::UNKNOWN;
576 }
577 
578 // Check if the last argument is signed
isLastFuncParamSigned(StringRef MangledName)579 bool isLastFuncParamSigned(StringRef MangledName) {
580   return lastFuncParamType(MangledName) == ParamType::SIGNED;
581 }
582 
583 // Check if a mangled function Name contains unsigned atomic type
containsUnsignedAtomicType(StringRef Name)584 bool containsUnsignedAtomicType(StringRef Name) {
585   auto Loc = Name.find(kMangledName::AtomicPrefixIncoming);
586   if (Loc == StringRef::npos)
587     return false;
588   return isMangledTypeUnsigned(
589       Name[Loc + strlen(kMangledName::AtomicPrefixIncoming)]);
590 }
591 
isFunctionPointerType(Type * T)592 bool isFunctionPointerType(Type *T) {
593   if (isa<PointerType>(T) && isa<FunctionType>(T->getPointerElementType())) {
594     return true;
595   }
596   return false;
597 }
598 
hasFunctionPointerArg(Function * F,Function::arg_iterator & AI)599 bool hasFunctionPointerArg(Function *F, Function::arg_iterator &AI) {
600   AI = F->arg_begin();
601   for (auto AE = F->arg_end(); AI != AE; ++AI) {
602     LLVM_DEBUG(dbgs() << "[hasFuncPointerArg] " << *AI << '\n');
603     if (isFunctionPointerType(AI->getType())) {
604       return true;
605     }
606   }
607   return false;
608 }
609 
castToVoidFuncPtr(Function * F)610 Constant *castToVoidFuncPtr(Function *F) {
611   auto T = getVoidFuncPtrType(F->getParent());
612   return ConstantExpr::getBitCast(F, T);
613 }
614 
hasArrayArg(Function * F)615 bool hasArrayArg(Function *F) {
616   for (auto I = F->arg_begin(), E = F->arg_end(); I != E; ++I) {
617     LLVM_DEBUG(dbgs() << "[hasArrayArg] " << *I << '\n');
618     if (I->getType()->isArrayTy()) {
619       return true;
620     }
621   }
622   return false;
623 }
624 
mutateCallInst(Module * M,CallInst * CI,std::function<std::string (CallInst *,std::vector<Value * > &)> ArgMutate,BuiltinFuncMangleInfo * Mangle,AttributeList * Attrs,bool TakeFuncName)625 CallInst *mutateCallInst(
626     Module *M, CallInst *CI,
627     std::function<std::string(CallInst *, std::vector<Value *> &)> ArgMutate,
628     BuiltinFuncMangleInfo *Mangle, AttributeList *Attrs, bool TakeFuncName) {
629   LLVM_DEBUG(dbgs() << "[mutateCallInst] " << *CI);
630 
631   auto Args = getArguments(CI);
632   auto NewName = ArgMutate(CI, Args);
633   std::string InstName;
634   if (!CI->getType()->isVoidTy() && CI->hasName()) {
635     InstName = CI->getName().str();
636     CI->setName(InstName + ".old");
637   }
638   auto NewCI = addCallInst(M, NewName, CI->getType(), Args, Attrs, CI, Mangle,
639                            InstName, TakeFuncName);
640   NewCI->setDebugLoc(CI->getDebugLoc());
641   LLVM_DEBUG(dbgs() << " => " << *NewCI << '\n');
642   CI->replaceAllUsesWith(NewCI);
643   CI->eraseFromParent();
644   return NewCI;
645 }
646 
mutateCallInst(Module * M,CallInst * CI,std::function<std::string (CallInst *,std::vector<Value * > &,Type * & RetTy)> ArgMutate,std::function<Instruction * (CallInst *)> RetMutate,BuiltinFuncMangleInfo * Mangle,AttributeList * Attrs,bool TakeFuncName)647 Instruction *mutateCallInst(
648     Module *M, CallInst *CI,
649     std::function<std::string(CallInst *, std::vector<Value *> &, Type *&RetTy)>
650         ArgMutate,
651     std::function<Instruction *(CallInst *)> RetMutate,
652     BuiltinFuncMangleInfo *Mangle, AttributeList *Attrs, bool TakeFuncName) {
653   LLVM_DEBUG(dbgs() << "[mutateCallInst] " << *CI);
654 
655   auto Args = getArguments(CI);
656   Type *RetTy = CI->getType();
657   auto NewName = ArgMutate(CI, Args, RetTy);
658   StringRef InstName = CI->getName();
659   auto NewCI = addCallInst(M, NewName, RetTy, Args, Attrs, CI, Mangle, InstName,
660                            TakeFuncName);
661   auto NewI = RetMutate(NewCI);
662   NewI->takeName(CI);
663   NewI->setDebugLoc(CI->getDebugLoc());
664   LLVM_DEBUG(dbgs() << " => " << *NewI << '\n');
665   if (!CI->getType()->isVoidTy())
666     CI->replaceAllUsesWith(NewI);
667   CI->eraseFromParent();
668   return NewI;
669 }
670 
mutateFunction(Function * F,std::function<std::string (CallInst *,std::vector<Value * > &)> ArgMutate,BuiltinFuncMangleInfo * Mangle,AttributeList * Attrs,bool TakeFuncName)671 void mutateFunction(
672     Function *F,
673     std::function<std::string(CallInst *, std::vector<Value *> &)> ArgMutate,
674     BuiltinFuncMangleInfo *Mangle, AttributeList *Attrs, bool TakeFuncName) {
675   auto M = F->getParent();
676   for (auto I = F->user_begin(), E = F->user_end(); I != E;) {
677     if (auto CI = dyn_cast<CallInst>(*I++))
678       mutateCallInst(M, CI, ArgMutate, Mangle, Attrs, TakeFuncName);
679   }
680   if (F->use_empty())
681     F->eraseFromParent();
682 }
683 
mutateCallInstSPIRV(Module * M,CallInst * CI,std::function<std::string (CallInst *,std::vector<Value * > &)> ArgMutate,AttributeList * Attrs)684 CallInst *mutateCallInstSPIRV(
685     Module *M, CallInst *CI,
686     std::function<std::string(CallInst *, std::vector<Value *> &)> ArgMutate,
687     AttributeList *Attrs) {
688   BuiltinFuncMangleInfo BtnInfo;
689   return mutateCallInst(M, CI, ArgMutate, &BtnInfo, Attrs);
690 }
691 
mutateCallInstSPIRV(Module * M,CallInst * CI,std::function<std::string (CallInst *,std::vector<Value * > &,Type * & RetTy)> ArgMutate,std::function<Instruction * (CallInst *)> RetMutate,AttributeList * Attrs)692 Instruction *mutateCallInstSPIRV(
693     Module *M, CallInst *CI,
694     std::function<std::string(CallInst *, std::vector<Value *> &, Type *&RetTy)>
695         ArgMutate,
696     std::function<Instruction *(CallInst *)> RetMutate, AttributeList *Attrs) {
697   BuiltinFuncMangleInfo BtnInfo;
698   return mutateCallInst(M, CI, ArgMutate, RetMutate, &BtnInfo, Attrs);
699 }
700 
addCallInst(Module * M,StringRef FuncName,Type * RetTy,ArrayRef<Value * > Args,AttributeList * Attrs,Instruction * Pos,BuiltinFuncMangleInfo * Mangle,StringRef InstName,bool TakeFuncName)701 CallInst *addCallInst(Module *M, StringRef FuncName, Type *RetTy,
702                       ArrayRef<Value *> Args, AttributeList *Attrs,
703                       Instruction *Pos, BuiltinFuncMangleInfo *Mangle,
704                       StringRef InstName, bool TakeFuncName) {
705 
706   auto F = getOrCreateFunction(M, RetTy, getTypes(Args), FuncName, Mangle,
707                                Attrs, TakeFuncName);
708   // Cannot assign a Name to void typed values
709   auto CI = CallInst::Create(F, Args, RetTy->isVoidTy() ? "" : InstName, Pos);
710   CI->setCallingConv(F->getCallingConv());
711   CI->setAttributes(F->getAttributes());
712   return CI;
713 }
714 
addCallInstSPIRV(Module * M,StringRef FuncName,Type * RetTy,ArrayRef<Value * > Args,AttributeList * Attrs,Instruction * Pos,StringRef InstName)715 CallInst *addCallInstSPIRV(Module *M, StringRef FuncName, Type *RetTy,
716                            ArrayRef<Value *> Args, AttributeList *Attrs,
717                            Instruction *Pos, StringRef InstName) {
718   BuiltinFuncMangleInfo BtnInfo;
719   return addCallInst(M, FuncName, RetTy, Args, Attrs, Pos, &BtnInfo, InstName);
720 }
721 
/// Returns true when \p I is one of the accepted vector sizes: 2, 3, 4, 8
/// or 16.
bool isValidVectorSize(unsigned I) {
  switch (I) {
  case 2:
  case 3:
  case 4:
  case 8:
  case 16:
    return true;
  default:
    return false;
  }
}
725 
addVector(Instruction * InsPos,ValueVecRange Range)726 Value *addVector(Instruction *InsPos, ValueVecRange Range) {
727   size_t VecSize = Range.second - Range.first;
728   if (VecSize == 1)
729     return *Range.first;
730   assert(isValidVectorSize(VecSize) && "Invalid vector size");
731   IRBuilder<> Builder(InsPos);
732   auto Vec = Builder.CreateVectorSplat(VecSize, *Range.first);
733   unsigned Index = 1;
734   for (++Range.first; Range.first != Range.second; ++Range.first, ++Index)
735     Vec = Builder.CreateInsertElement(
736         Vec, *Range.first,
737         ConstantInt::get(Type::getInt32Ty(InsPos->getContext()), Index, false));
738   return Vec;
739 }
740 
makeVector(Instruction * InsPos,std::vector<Value * > & Ops,ValueVecRange Range)741 void makeVector(Instruction *InsPos, std::vector<Value *> &Ops,
742                 ValueVecRange Range) {
743   auto Vec = addVector(InsPos, Range);
744   Ops.erase(Range.first, Range.second);
745   Ops.push_back(Vec);
746 }
747 
expandVector(Instruction * InsPos,std::vector<Value * > & Ops,size_t VecPos)748 void expandVector(Instruction *InsPos, std::vector<Value *> &Ops,
749                   size_t VecPos) {
750   auto Vec = Ops[VecPos];
751   auto *VT = dyn_cast<FixedVectorType>(Vec->getType());
752   if (!VT)
753     return;
754   size_t N = VT->getNumElements();
755   IRBuilder<> Builder(InsPos);
756   for (size_t I = 0; I != N; ++I)
757     Ops.insert(Ops.begin() + VecPos + I,
758                Builder.CreateExtractElement(
759                    Vec, ConstantInt::get(Type::getInt32Ty(InsPos->getContext()),
760                                          I, false)));
761   Ops.erase(Ops.begin() + VecPos + N);
762 }
763 
castToInt8Ptr(Constant * V,unsigned Addr=0)764 Constant *castToInt8Ptr(Constant *V, unsigned Addr = 0) {
765   return ConstantExpr::getBitCast(V, Type::getInt8PtrTy(V->getContext(), Addr));
766 }
767 
getInt8PtrTy(PointerType * T)768 PointerType *getInt8PtrTy(PointerType *T) {
769   return Type::getInt8PtrTy(T->getContext(), T->getAddressSpace());
770 }
771 
castToInt8Ptr(Value * V,Instruction * Pos)772 Value *castToInt8Ptr(Value *V, Instruction *Pos) {
773   return CastInst::CreatePointerCast(
774       V, getInt8PtrTy(cast<PointerType>(V->getType())), "", Pos);
775 }
776 
addBlockBind(Module * M,Function * InvokeFunc,Value * BlkCtx,Value * CtxLen,Value * CtxAlign,Instruction * InsPos,StringRef InstName)777 CallInst *addBlockBind(Module *M, Function *InvokeFunc, Value *BlkCtx,
778                        Value *CtxLen, Value *CtxAlign, Instruction *InsPos,
779                        StringRef InstName) {
780   auto BlkTy =
781       getOrCreateOpaquePtrType(M, SPIR_TYPE_NAME_BLOCK_T, SPIRAS_Private);
782   auto &Ctx = M->getContext();
783   Value *BlkArgs[] = {
784       castToInt8Ptr(InvokeFunc),
785       CtxLen ? CtxLen : UndefValue::get(Type::getInt32Ty(Ctx)),
786       CtxAlign ? CtxAlign : UndefValue::get(Type::getInt32Ty(Ctx)),
787       BlkCtx ? BlkCtx : UndefValue::get(Type::getInt8PtrTy(Ctx))};
788   return addCallInst(M, SPIR_INTRINSIC_BLOCK_BIND, BlkTy, BlkArgs, nullptr,
789                      InsPos, nullptr, InstName);
790 }
791 
getSizetType(Module * M)792 IntegerType *getSizetType(Module *M) {
793   return IntegerType::getIntNTy(M->getContext(),
794                                 M->getDataLayout().getPointerSizeInBits(0));
795 }
796 
getVoidFuncType(Module * M)797 Type *getVoidFuncType(Module *M) {
798   return FunctionType::get(Type::getVoidTy(M->getContext()), false);
799 }
800 
getVoidFuncPtrType(Module * M,unsigned AddrSpace)801 Type *getVoidFuncPtrType(Module *M, unsigned AddrSpace) {
802   return PointerType::get(getVoidFuncType(M), AddrSpace);
803 }
804 
getInt64(Module * M,int64_t Value)805 ConstantInt *getInt64(Module *M, int64_t Value) {
806   return ConstantInt::getSigned(Type::getInt64Ty(M->getContext()), Value);
807 }
808 
getUInt64(Module * M,uint64_t Value)809 ConstantInt *getUInt64(Module *M, uint64_t Value) {
810   return ConstantInt::get(Type::getInt64Ty(M->getContext()), Value, false);
811 }
812 
getFloat32(Module * M,float Value)813 Constant *getFloat32(Module *M, float Value) {
814   return ConstantFP::get(Type::getFloatTy(M->getContext()), Value);
815 }
816 
getInt32(Module * M,int Value)817 ConstantInt *getInt32(Module *M, int Value) {
818   return ConstantInt::get(Type::getInt32Ty(M->getContext()), Value, true);
819 }
820 
getUInt32(Module * M,unsigned Value)821 ConstantInt *getUInt32(Module *M, unsigned Value) {
822   return ConstantInt::get(Type::getInt32Ty(M->getContext()), Value, false);
823 }
824 
getInt(Module * M,int64_t Value)825 ConstantInt *getInt(Module *M, int64_t Value) {
826   return Value >> 32 ? getInt64(M, Value)
827                      : getInt32(M, static_cast<int32_t>(Value));
828 }
829 
getUInt(Module * M,uint64_t Value)830 ConstantInt *getUInt(Module *M, uint64_t Value) {
831   return Value >> 32 ? getUInt64(M, Value)
832                      : getUInt32(M, static_cast<uint32_t>(Value));
833 }
834 
getUInt16(Module * M,unsigned short Value)835 ConstantInt *getUInt16(Module *M, unsigned short Value) {
836   return ConstantInt::get(Type::getInt16Ty(M->getContext()), Value, false);
837 }
838 
getInt32(Module * M,const std::vector<int> & Values)839 std::vector<Value *> getInt32(Module *M, const std::vector<int> &Values) {
840   std::vector<Value *> V;
841   for (auto &I : Values)
842     V.push_back(getInt32(M, I));
843   return V;
844 }
845 
getSizet(Module * M,uint64_t Value)846 ConstantInt *getSizet(Module *M, uint64_t Value) {
847   return ConstantInt::get(getSizetType(M), Value, false);
848 }
849 
850 ///////////////////////////////////////////////////////////////////////////////
851 //
852 // Functions for getting metadata
853 //
854 ///////////////////////////////////////////////////////////////////////////////
getMDOperandAsInt(MDNode * N,unsigned I)855 int64_t getMDOperandAsInt(MDNode *N, unsigned I) {
856   return mdconst::dyn_extract<ConstantInt>(N->getOperand(I))->getZExtValue();
857 }
858 
859 // Additional helper function to be reused by getMDOperandAs* helpers
getMDOperandOrNull(MDNode * N,unsigned I)860 Metadata *getMDOperandOrNull(MDNode *N, unsigned I) {
861   if (!N)
862     return nullptr;
863   return N->getOperand(I);
864 }
865 
getMDOperandAsString(MDNode * N,unsigned I)866 std::string getMDOperandAsString(MDNode *N, unsigned I) {
867   if (auto *Str = dyn_cast_or_null<MDString>(getMDOperandOrNull(N, I)))
868     return Str->getString().str();
869   return "";
870 }
871 
getMDOperandAsMDNode(MDNode * N,unsigned I)872 MDNode *getMDOperandAsMDNode(MDNode *N, unsigned I) {
873   return dyn_cast_or_null<MDNode>(getMDOperandOrNull(N, I));
874 }
875 
getMDOperandAsType(MDNode * N,unsigned I)876 Type *getMDOperandAsType(MDNode *N, unsigned I) {
877   return cast<ValueAsMetadata>(N->getOperand(I))->getType();
878 }
879 
getNamedMDAsStringSet(Module * M,const std::string & MDName)880 std::set<std::string> getNamedMDAsStringSet(Module *M,
881                                             const std::string &MDName) {
882   NamedMDNode *NamedMD = M->getNamedMetadata(MDName);
883   std::set<std::string> StrSet;
884   if (!NamedMD)
885     return StrSet;
886 
887   assert(NamedMD->getNumOperands() > 0 && "Invalid SPIR");
888 
889   for (unsigned I = 0, E = NamedMD->getNumOperands(); I != E; ++I) {
890     MDNode *MD = NamedMD->getOperand(I);
891     if (!MD || MD->getNumOperands() == 0)
892       continue;
893     for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
894       StrSet.insert(getMDOperandAsString(MD, J));
895   }
896 
897   return StrSet;
898 }
899 
getSPIRVSource(Module * M)900 std::tuple<unsigned, unsigned, std::string> getSPIRVSource(Module *M) {
901   std::tuple<unsigned, unsigned, std::string> Tup;
902   if (auto N = SPIRVMDWalker(*M).getNamedMD(kSPIRVMD::Source).nextOp())
903     N.get(std::get<0>(Tup))
904         .get(std::get<1>(Tup))
905         .setQuiet(true)
906         .get(std::get<2>(Tup));
907   return Tup;
908 }
909 
mapUInt(Module * M,ConstantInt * I,std::function<unsigned (unsigned)> F)910 ConstantInt *mapUInt(Module *M, ConstantInt *I,
911                      std::function<unsigned(unsigned)> F) {
912   return ConstantInt::get(I->getType(), F(I->getZExtValue()), false);
913 }
914 
mapSInt(Module * M,ConstantInt * I,std::function<int (int)> F)915 ConstantInt *mapSInt(Module *M, ConstantInt *I, std::function<int(int)> F) {
916   return ConstantInt::get(I->getType(), F(I->getSExtValue()), true);
917 }
918 
isDecoratedSPIRVFunc(const Function * F,StringRef & UndecoratedName)919 bool isDecoratedSPIRVFunc(const Function *F, StringRef &UndecoratedName) {
920   if (!F->hasName() || !F->getName().startswith(kSPIRVName::Prefix))
921     return false;
922   UndecoratedName = F->getName();
923   return true;
924 }
925 
/// Get TypePrimitiveEnum for special OpenCL type except opencl.block.
///
/// \param TyName is compared for exact equality against the type names
/// listed below: the `opencl.`-prefixed image (read-only, write-only and
/// read-write variants), event, pipe, reserve_id, queue, clk_event, sampler
/// and Intel sub-group AVC names, plus `struct.ndrange_t`.
/// \returns the SPIR::PRIMITIVE_* enumerator paired with the matching name,
/// or SPIR::PRIMITIVE_NONE when \p TyName is not in the table.
SPIR::TypePrimitiveEnum getOCLTypePrimitiveEnum(StringRef TyName) {
  return StringSwitch<SPIR::TypePrimitiveEnum>(TyName)
      // Read-only images.
      .Case("opencl.image1d_ro_t", SPIR::PRIMITIVE_IMAGE1D_RO_T)
      .Case("opencl.image1d_array_ro_t", SPIR::PRIMITIVE_IMAGE1D_ARRAY_RO_T)
      .Case("opencl.image1d_buffer_ro_t", SPIR::PRIMITIVE_IMAGE1D_BUFFER_RO_T)
      .Case("opencl.image2d_ro_t", SPIR::PRIMITIVE_IMAGE2D_RO_T)
      .Case("opencl.image2d_array_ro_t", SPIR::PRIMITIVE_IMAGE2D_ARRAY_RO_T)
      .Case("opencl.image2d_depth_ro_t", SPIR::PRIMITIVE_IMAGE2D_DEPTH_RO_T)
      .Case("opencl.image2d_array_depth_ro_t",
            SPIR::PRIMITIVE_IMAGE2D_ARRAY_DEPTH_RO_T)
      .Case("opencl.image2d_msaa_ro_t", SPIR::PRIMITIVE_IMAGE2D_MSAA_RO_T)
      .Case("opencl.image2d_array_msaa_ro_t",
            SPIR::PRIMITIVE_IMAGE2D_ARRAY_MSAA_RO_T)
      .Case("opencl.image2d_msaa_depth_ro_t",
            SPIR::PRIMITIVE_IMAGE2D_MSAA_DEPTH_RO_T)
      .Case("opencl.image2d_array_msaa_depth_ro_t",
            SPIR::PRIMITIVE_IMAGE2D_ARRAY_MSAA_DEPTH_RO_T)
      .Case("opencl.image3d_ro_t", SPIR::PRIMITIVE_IMAGE3D_RO_T)
      // Write-only images.
      .Case("opencl.image1d_wo_t", SPIR::PRIMITIVE_IMAGE1D_WO_T)
      .Case("opencl.image1d_array_wo_t", SPIR::PRIMITIVE_IMAGE1D_ARRAY_WO_T)
      .Case("opencl.image1d_buffer_wo_t", SPIR::PRIMITIVE_IMAGE1D_BUFFER_WO_T)
      .Case("opencl.image2d_wo_t", SPIR::PRIMITIVE_IMAGE2D_WO_T)
      .Case("opencl.image2d_array_wo_t", SPIR::PRIMITIVE_IMAGE2D_ARRAY_WO_T)
      .Case("opencl.image2d_depth_wo_t", SPIR::PRIMITIVE_IMAGE2D_DEPTH_WO_T)
      .Case("opencl.image2d_array_depth_wo_t",
            SPIR::PRIMITIVE_IMAGE2D_ARRAY_DEPTH_WO_T)
      .Case("opencl.image2d_msaa_wo_t", SPIR::PRIMITIVE_IMAGE2D_MSAA_WO_T)
      .Case("opencl.image2d_array_msaa_wo_t",
            SPIR::PRIMITIVE_IMAGE2D_ARRAY_MSAA_WO_T)
      .Case("opencl.image2d_msaa_depth_wo_t",
            SPIR::PRIMITIVE_IMAGE2D_MSAA_DEPTH_WO_T)
      .Case("opencl.image2d_array_msaa_depth_wo_t",
            SPIR::PRIMITIVE_IMAGE2D_ARRAY_MSAA_DEPTH_WO_T)
      .Case("opencl.image3d_wo_t", SPIR::PRIMITIVE_IMAGE3D_WO_T)
      // Read-write images.
      .Case("opencl.image1d_rw_t", SPIR::PRIMITIVE_IMAGE1D_RW_T)
      .Case("opencl.image1d_array_rw_t", SPIR::PRIMITIVE_IMAGE1D_ARRAY_RW_T)
      .Case("opencl.image1d_buffer_rw_t", SPIR::PRIMITIVE_IMAGE1D_BUFFER_RW_T)
      .Case("opencl.image2d_rw_t", SPIR::PRIMITIVE_IMAGE2D_RW_T)
      .Case("opencl.image2d_array_rw_t", SPIR::PRIMITIVE_IMAGE2D_ARRAY_RW_T)
      .Case("opencl.image2d_depth_rw_t", SPIR::PRIMITIVE_IMAGE2D_DEPTH_RW_T)
      .Case("opencl.image2d_array_depth_rw_t",
            SPIR::PRIMITIVE_IMAGE2D_ARRAY_DEPTH_RW_T)
      .Case("opencl.image2d_msaa_rw_t", SPIR::PRIMITIVE_IMAGE2D_MSAA_RW_T)
      .Case("opencl.image2d_array_msaa_rw_t",
            SPIR::PRIMITIVE_IMAGE2D_ARRAY_MSAA_RW_T)
      .Case("opencl.image2d_msaa_depth_rw_t",
            SPIR::PRIMITIVE_IMAGE2D_MSAA_DEPTH_RW_T)
      .Case("opencl.image2d_array_msaa_depth_rw_t",
            SPIR::PRIMITIVE_IMAGE2D_ARRAY_MSAA_DEPTH_RW_T)
      .Case("opencl.image3d_rw_t", SPIR::PRIMITIVE_IMAGE3D_RW_T)
      // Events, pipes, queues, samplers and ndrange.
      .Case("opencl.event_t", SPIR::PRIMITIVE_EVENT_T)
      .Case("opencl.pipe_ro_t", SPIR::PRIMITIVE_PIPE_RO_T)
      .Case("opencl.pipe_wo_t", SPIR::PRIMITIVE_PIPE_WO_T)
      .Case("opencl.reserve_id_t", SPIR::PRIMITIVE_RESERVE_ID_T)
      .Case("opencl.queue_t", SPIR::PRIMITIVE_QUEUE_T)
      .Case("opencl.clk_event_t", SPIR::PRIMITIVE_CLK_EVENT_T)
      .Case("opencl.sampler_t", SPIR::PRIMITIVE_SAMPLER_T)
      .Case("struct.ndrange_t", SPIR::PRIMITIVE_NDRANGE_T)
      // Intel sub-group AVC types.
      .Case("opencl.intel_sub_group_avc_mce_payload_t",
            SPIR::PRIMITIVE_SUB_GROUP_AVC_MCE_PAYLOAD_T)
      .Case("opencl.intel_sub_group_avc_ime_payload_t",
            SPIR::PRIMITIVE_SUB_GROUP_AVC_IME_PAYLOAD_T)
      .Case("opencl.intel_sub_group_avc_ref_payload_t",
            SPIR::PRIMITIVE_SUB_GROUP_AVC_REF_PAYLOAD_T)
      .Case("opencl.intel_sub_group_avc_sic_payload_t",
            SPIR::PRIMITIVE_SUB_GROUP_AVC_SIC_PAYLOAD_T)
      .Case("opencl.intel_sub_group_avc_mce_result_t",
            SPIR::PRIMITIVE_SUB_GROUP_AVC_MCE_RESULT_T)
      .Case("opencl.intel_sub_group_avc_ime_result_t",
            SPIR::PRIMITIVE_SUB_GROUP_AVC_IME_RESULT_T)
      .Case("opencl.intel_sub_group_avc_ref_result_t",
            SPIR::PRIMITIVE_SUB_GROUP_AVC_REF_RESULT_T)
      .Case("opencl.intel_sub_group_avc_sic_result_t",
            SPIR::PRIMITIVE_SUB_GROUP_AVC_SIC_RESULT_T)
      .Case(
          "opencl.intel_sub_group_avc_ime_result_single_reference_streamout_t",
          SPIR::PRIMITIVE_SUB_GROUP_AVC_IME_SINGLE_REF_STREAMOUT_T)
      .Case("opencl.intel_sub_group_avc_ime_result_dual_reference_streamout_t",
            SPIR::PRIMITIVE_SUB_GROUP_AVC_IME_DUAL_REF_STREAMOUT_T)
      .Case("opencl.intel_sub_group_avc_ime_single_reference_streamin_t",
            SPIR::PRIMITIVE_SUB_GROUP_AVC_IME_SINGLE_REF_STREAMIN_T)
      .Case("opencl.intel_sub_group_avc_ime_dual_reference_streamin_t",
            SPIR::PRIMITIVE_SUB_GROUP_AVC_IME_DUAL_REF_STREAMIN_T)
      // Anything else is not a special OpenCL type.
      .Default(SPIR::PRIMITIVE_NONE);
}
/// Translates LLVM type to descriptor for mangler.
///
/// \param Ty is the LLVM type to translate.
/// \param Info carries the mangling hints consulted below:
///   - IsSigned: integer types are translated as signed.
///   - IsVoidPtr: i8* is translated as void*.
///   - Attr: mask tested when setting the pointer qualifiers.
///   - IsEnum / Enum: the descriptor is the primitive type Info.Enum.
///   - IsSampler: the descriptor is the sampler primitive type.
///   - IsAtomic: the descriptor is wrapped in an atomic type.
///   - IsLocalArgBlock: an opencl.block gets `__local void *, ...` parameters.
/// \returns the descriptor for \p Ty. For an unsupported type the function
/// asserts and, when assertions are disabled, falls back to the int
/// primitive type.
static SPIR::RefParamType transTypeDesc(Type *Ty,
                                        const BuiltinArgTypeMangleInfo &Info) {
  bool Signed = Info.IsSigned;
  unsigned Attr = Info.Attr;
  bool VoidPtr = Info.IsVoidPtr;
  // Explicit hints take precedence over the structure of Ty.
  if (Info.IsEnum)
    return SPIR::RefParamType(new SPIR::PrimitiveType(Info.Enum));
  if (Info.IsSampler)
    return SPIR::RefParamType(
        new SPIR::PrimitiveType(SPIR::PRIMITIVE_SAMPLER_T));
  // Non-pointer atomics: translate Ty with the atomic hint cleared, then wrap
  // the result. Pointers keep the hint, which is applied to their pointee by
  // the recursive call in the pointer branch below.
  if (Info.IsAtomic && !Ty->isPointerTy()) {
    BuiltinArgTypeMangleInfo DTInfo = Info;
    DTInfo.IsAtomic = false;
    return SPIR::RefParamType(new SPIR::AtomicType(transTypeDesc(Ty, DTInfo)));
  }
  // Integers: i1 is bool; the other supported widths select the signed or
  // unsigned primitive according to the Signed hint.
  if (auto *IntTy = dyn_cast<IntegerType>(Ty)) {
    switch (IntTy->getBitWidth()) {
    case 1:
      return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_BOOL));
    case 8:
      return SPIR::RefParamType(new SPIR::PrimitiveType(
          Signed ? SPIR::PRIMITIVE_CHAR : SPIR::PRIMITIVE_UCHAR));
    case 16:
      return SPIR::RefParamType(new SPIR::PrimitiveType(
          Signed ? SPIR::PRIMITIVE_SHORT : SPIR::PRIMITIVE_USHORT));
    case 32:
      return SPIR::RefParamType(new SPIR::PrimitiveType(
          Signed ? SPIR::PRIMITIVE_INT : SPIR::PRIMITIVE_UINT));
    case 64:
      return SPIR::RefParamType(new SPIR::PrimitiveType(
          Signed ? SPIR::PRIMITIVE_LONG : SPIR::PRIMITIVE_ULONG));
    default:
      llvm_unreachable("invliad int size");
    }
  }
  if (Ty->isVoidTy())
    return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_VOID));
  if (Ty->isHalfTy())
    return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_HALF));
  if (Ty->isFloatTy())
    return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_FLOAT));
  if (Ty->isDoubleTy())
    return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_DOUBLE));
  // Fixed vectors: translate the element type with the same hints.
  if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) {
    return SPIR::RefParamType(new SPIR::VectorType(
        transTypeDesc(VecTy->getElementType(), Info), VecTy->getNumElements()));
  }
  // Arrays are translated as a pointer to their element type in address
  // space 0.
  if (Ty->isArrayTy()) {
    return transTypeDesc(PointerType::get(Ty->getArrayElementType(), 0), Info);
  }
  if (Ty->isStructTy()) {
    auto Name = Ty->getStructName();
    // Backing storage for Name whenever the name has to be rewritten.
    std::string Tmp;

    if (Name.startswith(kLLVMTypeName::StructPrefix))
      Name = Name.drop_front(strlen(kLLVMTypeName::StructPrefix));
    // SPIR-V type names: drop the leading prefix-and-delimiter, turn every
    // remaining delimiter into '_' and prepend kSPIRVName::Prefix.
    if (Name.startswith(kSPIRVTypeName::PrefixAndDelim)) {
      Name = Name.substr(sizeof(kSPIRVTypeName::PrefixAndDelim) - 1);
      Tmp = Name.str();
      auto Pos = Tmp.find(kSPIRVTypeName::Delimiter); // first dot
      while (Pos != std::string::npos) {
        Tmp[Pos] = '_';
        Pos = Tmp.find(kSPIRVTypeName::Delimiter, Pos);
      }
      Name = Tmp = kSPIRVName::Prefix + Tmp;
    }
    // ToDo: Create a better unique Name for struct without Name
    // Until then the address of the type object is used as the name suffix.
    if (Name.empty()) {
      std::ostringstream OS;
      OS << reinterpret_cast<size_t>(Ty);
      Name = Tmp = std::string("struct_") + OS.str();
    }
    return SPIR::RefParamType(new SPIR::UserDefinedType(Name.str()));
  }

  if (Ty->isPointerTy()) {
    auto ET = Ty->getPointerElementType();
    // EPT is set only when the pointee maps to a special descriptor (block,
    // OpenCL builtin type or ndrange_t); it is then returned as is, without
    // wrapping it in a pointer descriptor.
    SPIR::ParamType *EPT = nullptr;
    if (isa<FunctionType>(ET)) {
      assert(isVoidFuncTy(cast<FunctionType>(ET)) && "Not supported");
      EPT = new SPIR::BlockType;
    } else if (auto StructTy = dyn_cast<StructType>(ET)) {
      LLVM_DEBUG(dbgs() << "ptr to struct: " << *Ty << '\n');
      auto TyName = StructTy->getStructName();
      // For names with the OpenCL prefix, cut everything from the first
      // delimiter after the prefix so that only the base name is looked up.
      if (TyName.startswith(kSPR2TypeName::OCLPrefix)) {
        auto DelimPos = TyName.find_first_of(kSPR2TypeName::Delimiter,
                                             strlen(kSPR2TypeName::OCLPrefix));
        if (DelimPos != StringRef::npos)
          TyName = TyName.substr(0, DelimPos);
      }
      LLVM_DEBUG(dbgs() << "  type Name: " << TyName << '\n');

      auto Prim = getOCLTypePrimitiveEnum(TyName);
      if (StructTy->isOpaque()) {
        if (TyName == "opencl.block") {
          auto BlockTy = new SPIR::BlockType;
          // Handle block with local memory arguments according to OpenCL 2.0
          // spec.
          if (Info.IsLocalArgBlock) {
            SPIR::RefParamType VoidTyRef(
                new SPIR::PrimitiveType(SPIR::PRIMITIVE_VOID));
            auto VoidPtrTy = new SPIR::PointerType(VoidTyRef);
            VoidPtrTy->setAddressSpace(SPIR::ATTR_LOCAL);
            // "__local void *"
            BlockTy->setParam(0, SPIR::RefParamType(VoidPtrTy));
            // "..."
            BlockTy->setParam(1, SPIR::RefParamType(new SPIR::PrimitiveType(
                                     SPIR::PRIMITIVE_VAR_ARG)));
          }
          EPT = BlockTy;
        } else if (Prim != SPIR::PRIMITIVE_NONE) {
          // Pipes are described as a pointer to the pipe primitive, in the
          // address space reported by getOCLOpaqueTypeAddrSpace(); every
          // other builtin type is described by the primitive itself.
          if (Prim == SPIR::PRIMITIVE_PIPE_RO_T ||
              Prim == SPIR::PRIMITIVE_PIPE_WO_T) {
            SPIR::RefParamType OpaqueTyRef(new SPIR::PrimitiveType(Prim));
            auto OpaquePtrTy = new SPIR::PointerType(OpaqueTyRef);
            OpaquePtrTy->setAddressSpace(getOCLOpaqueTypeAddrSpace(Prim));
            EPT = OpaquePtrTy;
          } else {
            EPT = new SPIR::PrimitiveType(Prim);
          }
        }
      } else if (Prim == SPIR::PRIMITIVE_NDRANGE_T)
        // ndrange_t is not opaque type
        EPT = new SPIR::PrimitiveType(SPIR::PRIMITIVE_NDRANGE_T);
    }
    if (EPT)
      return SPIR::RefParamType(EPT);

    // Ordinary pointer: honour the void* hint for i8 pointees, then
    // translate the pointee with the unchanged hints.
    if (VoidPtr && ET->isIntegerTy(8))
      ET = Type::getVoidTy(ET->getContext());
    auto PT = new SPIR::PointerType(transTypeDesc(ET, Info));
    // The LLVM address space number is used as an offset from
    // SPIR::ATTR_ADDR_SPACE_FIRST.
    PT->setAddressSpace(static_cast<SPIR::TypeAttributeEnum>(
        Ty->getPointerAddressSpace() + (unsigned)SPIR::ATTR_ADDR_SPACE_FIRST));
    // NOTE(review): the flag passed for qualifier I is `I & Attr`, i.e. the
    // enumerator value itself ANDed with Attr - confirm that this matches the
    // encoding callers use for Attr.
    for (unsigned I = SPIR::ATTR_QUALIFIER_FIRST, E = SPIR::ATTR_QUALIFIER_LAST;
         I <= E; ++I)
      PT->setQualifier(static_cast<SPIR::TypeAttributeEnum>(I), I & Attr);
    return SPIR::RefParamType(PT);
  }
  // No rule matched: report the type, assert, and fall back to int when
  // assertions are compiled out.
  LLVM_DEBUG(dbgs() << "[transTypeDesc] " << *Ty << '\n');
  assert(0 && "not implemented");
  return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_INT));
}
1157 
getScalarOrArray(Value * V,unsigned Size,Instruction * Pos)1158 Value *getScalarOrArray(Value *V, unsigned Size, Instruction *Pos) {
1159   if (!V->getType()->isPointerTy())
1160     return V;
1161   auto GEP = cast<GEPOperator>(V);
1162   assert(GEP->getNumOperands() == 3 && "must be a GEP from an array");
1163   assert(GEP->getSourceElementType()->getArrayNumElements() == Size);
1164   assert(dyn_cast<ConstantInt>(GEP->getOperand(1))->getZExtValue() == 0);
1165   assert(dyn_cast<ConstantInt>(GEP->getOperand(2))->getZExtValue() == 0);
1166   return new LoadInst(GEP->getSourceElementType(), GEP->getOperand(0), "", Pos);
1167 }
1168 
getScalarOrVectorConstantInt(Type * T,uint64_t V,bool IsSigned)1169 Constant *getScalarOrVectorConstantInt(Type *T, uint64_t V, bool IsSigned) {
1170   if (auto IT = dyn_cast<IntegerType>(T))
1171     return ConstantInt::get(IT, V);
1172   if (auto VT = dyn_cast<FixedVectorType>(T)) {
1173     std::vector<Constant *> EV(
1174         VT->getNumElements(),
1175         getScalarOrVectorConstantInt(VT->getElementType(), V, IsSigned));
1176     return ConstantVector::get(EV);
1177   }
1178   llvm_unreachable("Invalid type");
1179   return nullptr;
1180 }
1181 
getScalarOrArrayConstantInt(Instruction * Pos,Type * T,unsigned Len,uint64_t V,bool IsSigned)1182 Value *getScalarOrArrayConstantInt(Instruction *Pos, Type *T, unsigned Len,
1183                                    uint64_t V, bool IsSigned) {
1184   if (auto IT = dyn_cast<IntegerType>(T)) {
1185     assert(Len == 1 && "Invalid length");
1186     return ConstantInt::get(IT, V, IsSigned);
1187   }
1188   if (auto PT = dyn_cast<PointerType>(T)) {
1189     auto ET = PT->getPointerElementType();
1190     auto AT = ArrayType::get(ET, Len);
1191     std::vector<Constant *> EV(Len, ConstantInt::get(ET, V, IsSigned));
1192     auto CA = ConstantArray::get(AT, EV);
1193     auto Alloca = new AllocaInst(AT, 0, "", Pos);
1194     new StoreInst(CA, Alloca, Pos);
1195     auto Zero = ConstantInt::getNullValue(Type::getInt32Ty(T->getContext()));
1196     Value *Index[] = {Zero, Zero};
1197     auto *Ret = GetElementPtrInst::CreateInBounds(AT, Alloca, Index, "", Pos);
1198     LLVM_DEBUG(dbgs() << "[getScalarOrArrayConstantInt] Alloca: " << *Alloca
1199                       << ", Return: " << *Ret << '\n');
1200     return Ret;
1201   }
1202   if (auto AT = dyn_cast<ArrayType>(T)) {
1203     auto ET = AT->getArrayElementType();
1204     assert(AT->getArrayNumElements() == Len);
1205     std::vector<Constant *> EV(Len, ConstantInt::get(ET, V, IsSigned));
1206     auto Ret = ConstantArray::get(AT, EV);
1207     LLVM_DEBUG(dbgs() << "[getScalarOrArrayConstantInt] Array type: " << *AT
1208                       << ", Return: " << *Ret << '\n');
1209     return Ret;
1210   }
1211   llvm_unreachable("Invalid type");
1212   return nullptr;
1213 }
1214 
dumpUsers(Value * V,StringRef Prompt)1215 void dumpUsers(Value *V, StringRef Prompt) {
1216   if (!V)
1217     return;
1218   LLVM_DEBUG(dbgs() << Prompt << " Users of " << *V << " :\n");
1219   for (auto UI = V->user_begin(), UE = V->user_end(); UI != UE; ++UI)
1220     LLVM_DEBUG(dbgs() << "  " << **UI << '\n');
1221 }
1222 
getSPIRVTypeName(StringRef BaseName,StringRef Postfixes)1223 std::string getSPIRVTypeName(StringRef BaseName, StringRef Postfixes) {
1224   assert(!BaseName.empty() && "Invalid SPIR-V type Name");
1225   auto TN = std::string(kSPIRVTypeName::PrefixAndDelim) + BaseName.str();
1226   if (Postfixes.empty())
1227     return TN;
1228   return TN + kSPIRVTypeName::Delimiter + Postfixes.str();
1229 }
1230 
isSPIRVConstantName(StringRef TyName)1231 bool isSPIRVConstantName(StringRef TyName) {
1232   if (TyName == getSPIRVTypeName(kSPIRVTypeName::ConstantSampler) ||
1233       TyName == getSPIRVTypeName(kSPIRVTypeName::ConstantPipeStorage))
1234     return true;
1235 
1236   return false;
1237 }
1238 
getSPIRVTypeByChangeBaseTypeName(Module * M,Type * T,StringRef OldName,StringRef NewName)1239 Type *getSPIRVTypeByChangeBaseTypeName(Module *M, Type *T, StringRef OldName,
1240                                        StringRef NewName) {
1241   StringRef Postfixes;
1242   if (isSPIRVType(T, OldName, &Postfixes))
1243     return getOrCreateOpaquePtrType(M, getSPIRVTypeName(NewName, Postfixes));
1244   LLVM_DEBUG(dbgs() << " Invalid SPIR-V type " << *T << '\n');
1245   llvm_unreachable("Invalid SPIR-V type");
1246   return nullptr;
1247 }
1248 
getSPIRVImageTypePostfixes(StringRef SampledType,SPIRVTypeImageDescriptor Desc,SPIRVAccessQualifierKind Acc)1249 std::string getSPIRVImageTypePostfixes(StringRef SampledType,
1250                                        SPIRVTypeImageDescriptor Desc,
1251                                        SPIRVAccessQualifierKind Acc) {
1252   std::string S;
1253   raw_string_ostream OS(S);
1254   OS << kSPIRVTypeName::PostfixDelim << SampledType
1255      << kSPIRVTypeName::PostfixDelim << Desc.Dim << kSPIRVTypeName::PostfixDelim
1256      << Desc.Depth << kSPIRVTypeName::PostfixDelim << Desc.Arrayed
1257      << kSPIRVTypeName::PostfixDelim << Desc.MS << kSPIRVTypeName::PostfixDelim
1258      << Desc.Sampled << kSPIRVTypeName::PostfixDelim << Desc.Format
1259      << kSPIRVTypeName::PostfixDelim << Acc;
1260   return OS.str();
1261 }
1262 
getSPIRVImageSampledTypeName(SPIRVType * Ty)1263 std::string getSPIRVImageSampledTypeName(SPIRVType *Ty) {
1264   switch (Ty->getOpCode()) {
1265   case OpTypeVoid:
1266     return kSPIRVImageSampledTypeName::Void;
1267   case OpTypeInt:
1268     if (Ty->getIntegerBitWidth() == 32) {
1269       if (static_cast<SPIRVTypeInt *>(Ty)->isSigned())
1270         return kSPIRVImageSampledTypeName::Int;
1271       else
1272         return kSPIRVImageSampledTypeName::UInt;
1273     }
1274     break;
1275   case OpTypeFloat:
1276     switch (Ty->getFloatBitWidth()) {
1277     case 16:
1278       return kSPIRVImageSampledTypeName::Half;
1279     case 32:
1280       return kSPIRVImageSampledTypeName::Float;
1281     default:
1282       break;
1283     }
1284     break;
1285   default:
1286     break;
1287   }
1288   llvm_unreachable("Invalid sampled type for image");
1289   return std::string();
1290 }
1291 
1292 // ToDo: Find a way to represent uint sampled type in LLVM, maybe an
1293 //      opaque type.
getLLVMTypeForSPIRVImageSampledTypePostfix(StringRef Postfix,LLVMContext & Ctx)1294 Type *getLLVMTypeForSPIRVImageSampledTypePostfix(StringRef Postfix,
1295                                                  LLVMContext &Ctx) {
1296   if (Postfix == kSPIRVImageSampledTypeName::Void)
1297     return Type::getVoidTy(Ctx);
1298   if (Postfix == kSPIRVImageSampledTypeName::Float)
1299     return Type::getFloatTy(Ctx);
1300   if (Postfix == kSPIRVImageSampledTypeName::Half)
1301     return Type::getHalfTy(Ctx);
1302   if (Postfix == kSPIRVImageSampledTypeName::Int ||
1303       Postfix == kSPIRVImageSampledTypeName::UInt)
1304     return Type::getInt32Ty(Ctx);
1305   llvm_unreachable("Invalid sampled type postfix");
1306   return nullptr;
1307 }
1308 
getImageBaseTypeName(StringRef Name)1309 std::string getImageBaseTypeName(StringRef Name) {
1310 
1311   SmallVector<StringRef, 4> SubStrs;
1312   const char Delims[] = {kSPR2TypeName::Delimiter, 0};
1313   Name.split(SubStrs, Delims);
1314   if (Name.startswith(kSPR2TypeName::OCLPrefix)) {
1315     Name = SubStrs[1];
1316   } else {
1317     Name = SubStrs[0];
1318   }
1319 
1320   std::string ImageTyName{Name};
1321   if (hasAccessQualifiedName(Name))
1322     ImageTyName.erase(ImageTyName.size() - 5, 3);
1323 
1324   return ImageTyName;
1325 }
1326 
mapOCLTypeNameToSPIRV(StringRef Name,StringRef Acc)1327 std::string mapOCLTypeNameToSPIRV(StringRef Name, StringRef Acc) {
1328   std::string BaseTy;
1329   std::string Postfixes;
1330   raw_string_ostream OS(Postfixes);
1331   if (Name.startswith(kSPR2TypeName::ImagePrefix)) {
1332     std::string ImageTyName = getImageBaseTypeName(Name);
1333     auto Desc = map<SPIRVTypeImageDescriptor>(ImageTyName);
1334     LLVM_DEBUG(dbgs() << "[trans image type] " << Name << " => "
1335                       << "(" << (unsigned)Desc.Dim << ", " << Desc.Depth << ", "
1336                       << Desc.Arrayed << ", " << Desc.MS << ", " << Desc.Sampled
1337                       << ", " << Desc.Format << ")\n");
1338 
1339     BaseTy = kSPIRVTypeName::Image;
1340     OS << getSPIRVImageTypePostfixes(
1341         kSPIRVImageSampledTypeName::Void, Desc,
1342         SPIRSPIRVAccessQualifierMap::map(Acc.str()));
1343   } else {
1344     LLVM_DEBUG(dbgs() << "Mapping of " << Name << " is not implemented\n");
1345     llvm_unreachable("Not implemented");
1346   }
1347   return getSPIRVTypeName(BaseTy, OS.str());
1348 }
1349 
eraseIfNoUse(Function * F)1350 bool eraseIfNoUse(Function *F) {
1351   bool Changed = false;
1352   if (!F)
1353     return Changed;
1354   if (!GlobalValue::isInternalLinkage(F->getLinkage()) && !F->isDeclaration())
1355     return Changed;
1356 
1357   dumpUsers(F, "[eraseIfNoUse] ");
1358   for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
1359     auto U = *UI++;
1360     if (auto CE = dyn_cast<ConstantExpr>(U)) {
1361       if (CE->use_empty()) {
1362         CE->dropAllReferences();
1363         Changed = true;
1364       }
1365     }
1366   }
1367   if (F->use_empty()) {
1368     LLVM_DEBUG(dbgs() << "Erase "; F->printAsOperand(dbgs()); dbgs() << '\n');
1369     F->eraseFromParent();
1370     Changed = true;
1371   }
1372   return Changed;
1373 }
1374 
eraseIfNoUse(Value * V)1375 void eraseIfNoUse(Value *V) {
1376   if (!V->use_empty())
1377     return;
1378   if (Constant *C = dyn_cast<Constant>(V)) {
1379     C->destroyConstant();
1380     return;
1381   }
1382   if (Instruction *I = dyn_cast<Instruction>(V)) {
1383     if (!I->mayHaveSideEffects())
1384       I->eraseFromParent();
1385   }
1386   eraseIfNoUse(dyn_cast<Function>(V));
1387 }
1388 
eraseUselessFunctions(Module * M)1389 bool eraseUselessFunctions(Module *M) {
1390   bool Changed = false;
1391   for (auto I = M->begin(), E = M->end(); I != E;)
1392     Changed |= eraseIfNoUse(&(*I++));
1393   return Changed;
1394 }
1395 
1396 // The mangling algorithm follows OpenCL pipe built-ins clang 3.8 CodeGen rules.
1397 static SPIR::MangleError
manglePipeOrAddressSpaceCastBuiltin(const SPIR::FunctionDescriptor & Fd,std::string & MangledName)1398 manglePipeOrAddressSpaceCastBuiltin(const SPIR::FunctionDescriptor &Fd,
1399                                     std::string &MangledName) {
1400   assert(OCLUtil::isPipeOrAddressSpaceCastBI(Fd.Name) &&
1401          "Method is expected to be called only for pipe and address space cast "
1402          "builtins!");
1403   if (Fd.isNull()) {
1404     MangledName.assign(SPIR::FunctionDescriptor::nullString());
1405     return SPIR::MANGLE_NULL_FUNC_DESCRIPTOR;
1406   }
1407   MangledName.assign("__" + Fd.Name);
1408   return SPIR::MANGLE_SUCCESS;
1409 }
1410 
mangleBuiltin(StringRef UniqName,ArrayRef<Type * > ArgTypes,BuiltinFuncMangleInfo * BtnInfo)1411 std::string mangleBuiltin(StringRef UniqName, ArrayRef<Type *> ArgTypes,
1412                           BuiltinFuncMangleInfo *BtnInfo) {
1413   if (!BtnInfo)
1414     return std::string(UniqName);
1415   BtnInfo->init(UniqName);
1416   std::string MangledName;
1417   LLVM_DEBUG(dbgs() << "[mangle] " << UniqName << " => ");
1418   SPIR::FunctionDescriptor FD;
1419   FD.Name = BtnInfo->getUnmangledName();
1420   bool BIVarArgNegative = BtnInfo->getVarArg() < 0;
1421 
1422   if (ArgTypes.empty()) {
1423     // Function signature cannot be ()(void, ...) so if there is an ellipsis
1424     // it must be ()(...)
1425     if (BIVarArgNegative) {
1426       FD.Parameters.emplace_back(
1427           SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_VOID)));
1428     }
1429   } else {
1430     for (unsigned I = 0, E = BIVarArgNegative ? ArgTypes.size()
1431                                               : (unsigned)BtnInfo->getVarArg();
1432          I != E; ++I) {
1433       auto T = ArgTypes[I];
1434       FD.Parameters.emplace_back(
1435           transTypeDesc(T, BtnInfo->getTypeMangleInfo(I)));
1436     }
1437   }
1438   // Ellipsis must be the last argument of any function
1439   if (!BIVarArgNegative) {
1440     assert((unsigned)BtnInfo->getVarArg() <= ArgTypes.size() &&
1441            "invalid index of an ellipsis");
1442     FD.Parameters.emplace_back(
1443         SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_VAR_ARG)));
1444   }
1445 
1446 #if defined(SPIRV_SPIR20_MANGLING_REQUIREMENTS)
1447   SPIR::NameMangler Mangler(SPIR::SPIR20);
1448   Mangler.mangle(FD, MangledName);
1449 #else
1450   if (OCLUtil::isPipeOrAddressSpaceCastBI(BtnInfo->getUnmangledName())) {
1451     manglePipeOrAddressSpaceCastBuiltin(FD, MangledName);
1452   } else {
1453     SPIR::NameMangler Mangler(SPIR::SPIR20);
1454     Mangler.mangle(FD, MangledName);
1455   }
1456 #endif
1457 
1458   LLVM_DEBUG(dbgs() << MangledName << '\n');
1459   return MangledName;
1460 }
1461 
1462 /// Check if access qualifier is encoded in the type Name.
hasAccessQualifiedName(StringRef TyName)1463 bool hasAccessQualifiedName(StringRef TyName) {
1464   if (TyName.size() < 5)
1465     return false;
1466   auto Acc = TyName.substr(TyName.size() - 5, 3);
1467   return llvm::StringSwitch<bool>(Acc)
1468       .Case(kAccessQualPostfix::ReadOnly, true)
1469       .Case(kAccessQualPostfix::WriteOnly, true)
1470       .Case(kAccessQualPostfix::ReadWrite, true)
1471       .Default(false);
1472 }
1473 
getAccessQualifier(StringRef TyName)1474 SPIRVAccessQualifierKind getAccessQualifier(StringRef TyName) {
1475   return SPIRSPIRVAccessQualifierMap::map(
1476       getAccessQualifierFullName(TyName).str());
1477 }
1478 
getAccessQualifierPostfix(SPIRVAccessQualifierKind Access)1479 StringRef getAccessQualifierPostfix(SPIRVAccessQualifierKind Access) {
1480   switch (Access) {
1481   case AccessQualifierReadOnly:
1482     return kAccessQualPostfix::ReadOnly;
1483   case AccessQualifierWriteOnly:
1484     return kAccessQualPostfix::WriteOnly;
1485   case AccessQualifierReadWrite:
1486     return kAccessQualPostfix::ReadWrite;
1487   default:
1488     assert(false && "Unrecognized access qualifier!");
1489     return kAccessQualPostfix::ReadWrite;
1490   }
1491 }
1492 
1493 /// Get access qualifier from the type Name.
getAccessQualifierFullName(StringRef TyName)1494 StringRef getAccessQualifierFullName(StringRef TyName) {
1495   assert(hasAccessQualifiedName(TyName) &&
1496          "Type is not qualified with access.");
1497   auto Acc = TyName.substr(TyName.size() - 5, 3);
1498   return llvm::StringSwitch<StringRef>(Acc)
1499       .Case(kAccessQualPostfix::ReadOnly, kAccessQualName::ReadOnly)
1500       .Case(kAccessQualPostfix::WriteOnly, kAccessQualName::WriteOnly)
1501       .Case(kAccessQualPostfix::ReadWrite, kAccessQualName::ReadWrite);
1502 }
1503 
1504 /// Translates OpenCL image type names to SPIR-V.
getSPIRVImageTypeFromOCL(Module * M,Type * ImageTy)1505 Type *getSPIRVImageTypeFromOCL(Module *M, Type *ImageTy) {
1506   assert(isOCLImageType(ImageTy) && "Unsupported type");
1507   auto ImageTypeName = ImageTy->getPointerElementType()->getStructName();
1508   StringRef Acc = kAccessQualName::ReadOnly;
1509   if (hasAccessQualifiedName(ImageTypeName))
1510     Acc = getAccessQualifierFullName(ImageTypeName);
1511   return getOrCreateOpaquePtrType(M, mapOCLTypeNameToSPIRV(ImageTypeName, Acc));
1512 }
1513 
getOCLClkEventType(Module * M)1514 llvm::PointerType *getOCLClkEventType(Module *M) {
1515   return getOrCreateOpaquePtrType(M, SPIR_TYPE_NAME_CLK_EVENT_T,
1516                                   SPIRAS_Private);
1517 }
1518 
getOCLClkEventPtrType(Module * M)1519 llvm::PointerType *getOCLClkEventPtrType(Module *M) {
1520   return PointerType::get(getOCLClkEventType(M), SPIRAS_Generic);
1521 }
1522 
getOCLNullClkEventPtr(Module * M)1523 llvm::Constant *getOCLNullClkEventPtr(Module *M) {
1524   return Constant::getNullValue(getOCLClkEventPtrType(M));
1525 }
1526 
hasLoopMetadata(const Module * M)1527 bool hasLoopMetadata(const Module *M) {
1528   for (const Function &F : *M)
1529     for (const BasicBlock &BB : F) {
1530       const Instruction *Term = BB.getTerminator();
1531       if (Term && Term->getMetadata("llvm.loop"))
1532         return true;
1533     }
1534   return false;
1535 }
1536 
isSPIRVOCLExtInst(const CallInst * CI,OCLExtOpKind * ExtOp)1537 bool isSPIRVOCLExtInst(const CallInst *CI, OCLExtOpKind *ExtOp) {
1538   StringRef DemangledName;
1539   if (!oclIsBuiltin(CI->getCalledFunction()->getName(), DemangledName))
1540     return false;
1541   StringRef S = DemangledName;
1542   if (!S.startswith(kSPIRVName::Prefix))
1543     return false;
1544   S = S.drop_front(strlen(kSPIRVName::Prefix));
1545   auto Loc = S.find(kSPIRVPostfix::Divider);
1546   auto ExtSetName = S.substr(0, Loc);
1547   SPIRVExtInstSetKind Set = SPIRVEIS_Count;
1548   if (!SPIRVExtSetShortNameMap::rfind(ExtSetName.str(), &Set))
1549     return false;
1550 
1551   if (Set != SPIRVEIS_OpenCL)
1552     return false;
1553 
1554   auto ExtOpName = S.substr(Loc + 1);
1555   auto PostFixPos = ExtOpName.find("_R");
1556   ExtOpName = ExtOpName.substr(0, PostFixPos);
1557 
1558   OCLExtOpKind EOC;
1559   if (!OCLExtOpMap::rfind(ExtOpName.str(), &EOC))
1560     return false;
1561 
1562   *ExtOp = EOC;
1563   return true;
1564 }
1565 
decodeSPIRVTypeName(StringRef Name,SmallVectorImpl<std::string> & Strs)1566 std::string decodeSPIRVTypeName(StringRef Name,
1567                                 SmallVectorImpl<std::string> &Strs) {
1568   SmallVector<StringRef, 4> SubStrs;
1569   const char Delim[] = {kSPIRVTypeName::Delimiter, 0};
1570   Name.split(SubStrs, Delim, -1, true);
1571   assert(SubStrs.size() >= 2 && "Invalid SPIRV type name");
1572   assert(SubStrs[0] == kSPIRVTypeName::Prefix && "Invalid prefix");
1573   assert((SubStrs.size() == 2 || !SubStrs[2].empty()) && "Invalid postfix");
1574 
1575   if (SubStrs.size() > 2) {
1576     const char PostDelim[] = {kSPIRVTypeName::PostfixDelim, 0};
1577     SmallVector<StringRef, 4> Postfixes;
1578     SubStrs[2].split(Postfixes, PostDelim, -1, true);
1579     assert(Postfixes.size() > 1 && Postfixes[0].empty() && "Invalid postfix");
1580     for (unsigned I = 1, E = Postfixes.size(); I != E; ++I)
1581       Strs.push_back(std::string(Postfixes[I]).c_str());
1582   }
1583   return SubStrs[1].str();
1584 }
1585 
1586 // Returns true if type(s) and number of elements (if vector) is valid
checkTypeForSPIRVExtendedInstLowering(IntrinsicInst * II,SPIRVModule * BM)1587 bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) {
1588   switch (II->getIntrinsicID()) {
1589   case Intrinsic::ceil:
1590   case Intrinsic::copysign:
1591   case Intrinsic::cos:
1592   case Intrinsic::exp:
1593   case Intrinsic::exp2:
1594   case Intrinsic::fabs:
1595   case Intrinsic::floor:
1596   case Intrinsic::fma:
1597   case Intrinsic::log:
1598   case Intrinsic::log10:
1599   case Intrinsic::log2:
1600   case Intrinsic::maximum:
1601   case Intrinsic::maxnum:
1602   case Intrinsic::minimum:
1603   case Intrinsic::minnum:
1604   case Intrinsic::nearbyint:
1605   case Intrinsic::pow:
1606   case Intrinsic::powi:
1607   case Intrinsic::rint:
1608   case Intrinsic::round:
1609   case Intrinsic::roundeven:
1610   case Intrinsic::sin:
1611   case Intrinsic::sqrt:
1612   case Intrinsic::trunc: {
1613     // Although some of the intrinsics above take multiple arguments, it is
1614     // sufficient to check arg 0 because the LLVM Verifier will have checked
1615     // that all floating point operands have the same type and the second
1616     // argument of powi is i32.
1617     Type *Ty = II->getType();
1618     if (II->getArgOperand(0)->getType() != Ty)
1619       return false;
1620     int NumElems = 1;
1621     if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) {
1622       NumElems = VecTy->getNumElements();
1623       Ty = VecTy->getElementType();
1624     }
1625     if ((!Ty->isFloatTy() && !Ty->isDoubleTy() && !Ty->isHalfTy()) ||
1626         ((NumElems > 4) && (NumElems != 8) && (NumElems != 16))) {
1627       BM->SPIRVCK(
1628           false, InvalidFunctionCall, II->getCalledOperand()->getName().str());
1629       return false;
1630     }
1631     break;
1632   }
1633   case Intrinsic::abs: {
1634     Type *Ty = II->getType();
1635     int NumElems = 1;
1636     if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) {
1637       NumElems = VecTy->getNumElements();
1638       Ty = VecTy->getElementType();
1639     }
1640     if ((!Ty->isIntegerTy()) ||
1641         ((NumElems > 4) && (NumElems != 8) && (NumElems != 16))) {
1642       BM->SPIRVCK(
1643           false, InvalidFunctionCall, II->getCalledOperand()->getName().str());
1644     }
1645     break;
1646   }
1647   default:
1648     break;
1649   }
1650   return true;
1651 }
1652 
setAttrByCalledFunc(CallInst * Call)1653 void setAttrByCalledFunc(CallInst *Call) {
1654   Function *F = Call->getCalledFunction();
1655   assert(F);
1656   if (F->isIntrinsic()) {
1657     return;
1658   }
1659   Call->setCallingConv(F->getCallingConv());
1660   Call->setAttributes(F->getAttributes());
1661 }
1662 
isSPIRVBuiltinVariable(GlobalVariable * GV,SPIRVBuiltinVariableKind * Kind)1663 bool isSPIRVBuiltinVariable(GlobalVariable *GV,
1664                             SPIRVBuiltinVariableKind *Kind) {
1665   if (!GV->hasName() || !getSPIRVBuiltin(GV->getName().str(), *Kind))
1666     return false;
1667   return true;
1668 }
1669 
// Lowers every load of the SPIR-V builtin variable GV into call(s) of a
// function. The function name is produced by mangleOpenClBuiltin() from the
// name of GV; the function is declared in the module if it does not exist yet.
//
// NOTE(review): the callee names in the IR examples below (get_global_id,
// get_max_sub_group_size) look illustrative rather than exact -- confirm
// against what mangleOpenClBuiltin() yields for the variable names.
//
// Variable like GlobalInvocationId[x] -> get_global_id(x).
// Variable like WorkDim -> get_work_dim().
// Replace the following pattern:
// %a = addrspacecast i32 addrspace(1)* @__spirv_BuiltInSubgroupMaxSize to
// i32 addrspace(4)*
// %b = load i32, i32 addrspace(4)* %a, align 4
// %c = load i32, i32 addrspace(4)* %a, align 4
// With:
// %b = call spir_func i32 @_Z22get_max_sub_group_sizev()
// %c = call spir_func i32 @_Z22get_max_sub_group_sizev()

// And replace the following pattern:
// %a = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupId to
// <3 x i64> addrspace(4)*
// %b = load <3 x i64>, <3 x i64> addrspace(4)* %a, align 32
// %c = extractelement <3 x i64> %b, i32 idx
// %d = extractelement <3 x i64> %b, i32 idx
// With:
// %0 = call spir_func i64 @_Z13get_global_idj(i32 0) #1
// %1 = insertelement <3 x i64> undef, i64 %0, i32 0
// %2 = call spir_func i64 @_Z13get_global_idj(i32 1) #1
// %3 = insertelement <3 x i64> %1, i64 %2, i32 1
// %4 = call spir_func i64 @_Z13get_global_idj(i32 2) #1
// %5 = insertelement <3 x i64> %3, i64 %4, i32 2
// %c = extractelement <3 x i64> %5, i32 idx
// %d = extractelement <3 x i64> %5, i32 idx
//
// Replace the following pattern:
// %0 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to
// <3 x i64> addrspace(4)*
// %1 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %0, i64 0, i64 0
// %2 = load i64, i64 addrspace(4)* %1, align 32
// With:
// %0 = call spir_func i64 @_Z13get_global_idj(i32 0) #1
// %1 = insertelement <3 x i64> undef, i64 %0, i32 0
// %2 = call spir_func i64 @_Z13get_global_idj(i32 1) #1
// %3 = insertelement <3 x i64> %1, i64 %2, i32 1
// %4 = call spir_func i64 @_Z13get_global_idj(i32 2) #1
// %5 = insertelement <3 x i64> %3, i64 %4, i32 2
// %6 = extractelement <3 x i64> %5, i32 0
//
// Any other use pattern of GV hits llvm_unreachable. The function currently
// always returns true. GV itself is not erased here.
bool lowerBuiltinVariableToCall(GlobalVariable *GV,
                                SPIRVBuiltinVariableKind Kind) {
  // There might be dead constant users of GV (for example, SPIRVLowerConstExpr
  // replaces ConstExpr uses but those ConstExprs are not deleted, since LLVM
  // constants are created on demand as needed and never deleted).
  // Remove them first!
  GV->removeDeadConstantUsers();

  Module *M = GV->getParent();
  LLVMContext &C = M->getContext();
  std::string FuncName = GV->getName().str();
  Type *GVTy = GV->getType()->getPointerElementType();
  Type *ReturnTy = GVTy;
  // Some SPIR-V builtin variables are translated to a function with an index
  // argument: vector-typed variables are fetched element by element, except
  // for the kinds in the [BuiltInSubgroupEqMask, BuiltInSubgroupLtMask] range,
  // which are fetched as a whole vector by a single call.
  bool HasIndexArg =
      ReturnTy->isVectorTy() &&
      !(BuiltInSubgroupEqMask <= Kind && Kind <= BuiltInSubgroupLtMask);
  if (HasIndexArg)
    ReturnTy = cast<VectorType>(ReturnTy)->getElementType();
  std::vector<Type *> ArgTy;
  if (HasIndexArg)
    ArgTy.push_back(Type::getInt32Ty(C));
  std::string MangledName;
  mangleOpenClBuiltin(FuncName, ArgTy, MangledName);
  // Reuse an existing function of that name, otherwise declare it.
  Function *Func = M->getFunction(MangledName);
  if (!Func) {
    FunctionType *FT = FunctionType::get(ReturnTy, ArgTy, false);
    Func = Function::Create(FT, GlobalValue::ExternalLinkage, MangledName, M);
    Func->setCallingConv(CallingConv::SPIR_FUNC);
    Func->addFnAttr(Attribute::NoUnwind);
    Func->addFnAttr(Attribute::ReadNone);
    Func->addFnAttr(Attribute::WillReturn);
  }

  // Collect instructions in these containers to remove them later.
  std::vector<Instruction *> Loads;
  std::vector<Instruction *> Casts;
  std::vector<Instruction *> GEPs;

  // Inserts a call of Func with arguments Arg right before I, moves I's name
  // to the call and redirects all uses of I to it. I itself is not erased.
  auto Replace = [&](std::vector<Value *> Arg, Instruction *I) {
    auto *Call = CallInst::Create(Func, Arg, "", I);
    Call->takeName(I);
    setAttrByCalledFunc(Call);
    SPIRVDBG(dbgs() << "[lowerBuiltinVariableToCall] " << *I << " -> " << *Call
                    << '\n';)
    I->replaceAllUsesWith(Call);
  };

  // If HasIndexArg is true, we create one built-in call and one insertelement
  // per vector element to get a vector filled with ids and replace uses of
  // Load instruction with this vector.
  // If HasIndexArg is false, the result of the Load instruction is the value
  // which should be replaced with the Func.
  // Returns true if Load was replaced, false otherwise.
  auto ReplaceIfLoad = [&](User *I) {
    auto *LD = dyn_cast<LoadInst>(I);
    if (!LD)
      return false;
    std::vector<Value *> Vectors;
    // Remember the load; it is erased after all users have been processed.
    Loads.push_back(LD);
    if (HasIndexArg) {
      auto *VecTy = cast<FixedVectorType>(GVTy);
      Value *EmptyVec = UndefValue::get(VecTy);
      Vectors.push_back(EmptyVec);
      const DebugLoc &DLoc = LD->getDebugLoc();
      // Build the vector element by element; the new instructions inherit the
      // debug location of the load when it has one.
      // (The loop index below shadows the lambda parameter I.)
      for (unsigned I = 0; I < VecTy->getNumElements(); ++I) {
        auto *Idx = ConstantInt::get(Type::getInt32Ty(C), I);
        auto *Call = CallInst::Create(Func, {Idx}, "", LD);
        if (DLoc)
          Call->setDebugLoc(DLoc);
        setAttrByCalledFunc(Call);
        auto *Insert = InsertElementInst::Create(Vectors.back(), Call, Idx);
        if (DLoc)
          Insert->setDebugLoc(DLoc);
        Insert->insertAfter(Call);
        Vectors.push_back(Insert);
      }

      Value *Ptr = LD->getPointerOperand();

      if (isa<FixedVectorType>(Ptr->getType()->getPointerElementType())) {
        // The whole vector was loaded: the assembled vector replaces it.
        LD->replaceAllUsesWith(Vectors.back());
      } else {
        // A single element was loaded through a two-index GEP; operand 2 of
        // the GEP (its second index) selects the element.
        // NOTE(review): GEP is only checked by assert; with assertions
        // disabled a non-GEP pointer would be dereferenced as null.
        auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
        assert(GEP && "Unexpected pattern!");
        assert(GEP->getNumIndices() == 2 && "Unexpected pattern!");
        Value *Idx = GEP->getOperand(2);
        Value *Vec = Vectors.back();
        auto *NewExtract = ExtractElementInst::Create(Vec, Idx);
        NewExtract->insertAfter(cast<Instruction>(Vec));
        LD->replaceAllUsesWith(NewExtract);
      }

    } else {
      Replace({}, LD);
    }

    return true;
  };

  // Go over the GV users, find Load and ExtractElement instructions and
  // replace them with the corresponding function call.
  for (auto *UI : GV->users()) {
    // There might or might not be an addrspacecast instruction.
    if (auto *ASCast = dyn_cast<AddrSpaceCastInst>(UI)) {
      Casts.push_back(ASCast);
      for (auto *CastUser : ASCast->users()) {
        if (ReplaceIfLoad(CastUser))
          continue;
        // The only other accepted user of the cast is a GEP whose users are
        // all loads.
        if (auto *GEP = dyn_cast<GetElementPtrInst>(CastUser)) {
          GEPs.push_back(GEP);
          for (auto *GEPUser : GEP->users()) {
            if (!ReplaceIfLoad(GEPUser))
              llvm_unreachable("Unexpected pattern!");
          }
        } else {
          llvm_unreachable("Unexpected pattern!");
        }
      }
    } else if (!ReplaceIfLoad(UI)) {
      llvm_unreachable("Unexpected pattern!");
    }
  }

  // Erases the collected instructions; each must have no remaining uses.
  auto Erase = [](std::vector<Instruction *> &ToErase) {
    for (Instruction *I : ToErase) {
      assert(I->hasNUses(0));
      I->eraseFromParent();
    }
  };
  // Order of erasing is important: loads use GEPs and casts, GEPs use casts,
  // so users are erased before the values they use.
  Erase(Loads);
  Erase(GEPs);
  Erase(Casts);

  return true;
}
1848 
lowerBuiltinVariablesToCalls(Module * M)1849 bool lowerBuiltinVariablesToCalls(Module *M) {
1850   std::vector<GlobalVariable *> WorkList;
1851   for (auto I = M->global_begin(), E = M->global_end(); I != E; ++I) {
1852     SPIRVBuiltinVariableKind Kind;
1853     if (!isSPIRVBuiltinVariable(&(*I), &Kind))
1854       continue;
1855     if (!lowerBuiltinVariableToCall(&(*I), Kind))
1856       return false;
1857     WorkList.push_back(&(*I));
1858   }
1859   for (auto &I : WorkList) {
1860     I->eraseFromParent();
1861   }
1862 
1863   return true;
1864 }
1865 
postProcessBuiltinReturningStruct(Function * F)1866 bool postProcessBuiltinReturningStruct(Function *F) {
1867   Module *M = F->getParent();
1868   LLVMContext *Context = &M->getContext();
1869   std::string Name = F->getName().str();
1870   F->setName(Name + ".old");
1871   SmallVector<Instruction *, 32> InstToRemove;
1872   for (auto *U : F->users()) {
1873     if (auto *CI = dyn_cast<CallInst>(U)) {
1874       auto *ST = cast<StoreInst>(*(CI->user_begin()));
1875       std::vector<Type *> ArgTys;
1876       getFunctionTypeParameterTypes(F->getFunctionType(), ArgTys);
1877       ArgTys.insert(ArgTys.begin(),
1878                     PointerType::get(F->getReturnType(), SPIRAS_Private));
1879       auto *NewF =
1880           getOrCreateFunction(M, Type::getVoidTy(*Context), ArgTys, Name);
1881       NewF->addParamAttr(0, Attribute::get(*Context,
1882                                            Attribute::AttrKind::StructRet,
1883                                            F->getReturnType()));
1884       NewF->setCallingConv(F->getCallingConv());
1885       auto Args = getArguments(CI);
1886       Args.insert(Args.begin(), ST->getPointerOperand());
1887       auto *NewCI = CallInst::Create(NewF, Args, CI->getName(), CI);
1888       NewCI->setCallingConv(CI->getCallingConv());
1889       InstToRemove.push_back(ST);
1890       InstToRemove.push_back(CI);
1891     }
1892   }
1893   for (auto *Inst : InstToRemove) {
1894     Inst->dropAllReferences();
1895     Inst->eraseFromParent();
1896   }
1897   F->dropAllReferences();
1898   F->eraseFromParent();
1899   return true;
1900 }
1901 
postProcessBuiltinWithArrayArguments(Function * F,StringRef DemangledName)1902 bool postProcessBuiltinWithArrayArguments(Function *F,
1903                                           StringRef DemangledName) {
1904   LLVM_DEBUG(dbgs() << "[postProcessOCLBuiltinWithArrayArguments] " << *F
1905                     << '\n');
1906   auto Attrs = F->getAttributes();
1907   auto Name = F->getName();
1908   mutateFunction(
1909       F,
1910       [=](CallInst *CI, std::vector<Value *> &Args) {
1911         auto FBegin = CI->getFunction()->begin()->getFirstInsertionPt();
1912         for (auto &I : Args) {
1913           auto *T = I->getType();
1914           if (!T->isArrayTy())
1915             continue;
1916           auto *Alloca = new AllocaInst(T, 0, "", &(*FBegin));
1917           new StoreInst(I, Alloca, false, CI);
1918           auto *Zero =
1919               ConstantInt::getNullValue(Type::getInt32Ty(T->getContext()));
1920           Value *Index[] = {Zero, Zero};
1921           I = GetElementPtrInst::CreateInBounds(T, Alloca, Index, "", CI);
1922         }
1923         return Name.str();
1924       },
1925       nullptr, &Attrs);
1926   return true;
1927 }
1928 
postProcessBuiltinsReturningStruct(Module * M,bool IsCpp)1929 bool postProcessBuiltinsReturningStruct(Module *M, bool IsCpp) {
1930   StringRef DemangledName;
1931   // postProcessBuiltinReturningStruct may remove some functions from the
1932   // module, so use make_early_inc_range
1933   for (auto &F : make_early_inc_range(M->functions())) {
1934     if (F.hasName() && F.isDeclaration()) {
1935       LLVM_DEBUG(dbgs() << "[postProcess sret] " << F << '\n');
1936       if (F.getReturnType()->isStructTy() &&
1937           oclIsBuiltin(F.getName(), DemangledName, IsCpp)) {
1938         if (!postProcessBuiltinReturningStruct(&F))
1939           return false;
1940       }
1941     }
1942   }
1943   return true;
1944 }
1945 
postProcessBuiltinsWithArrayArguments(Module * M,bool IsCpp)1946 bool postProcessBuiltinsWithArrayArguments(Module *M, bool IsCpp) {
1947   StringRef DemangledName;
1948   // postProcessBuiltinWithArrayArguments may remove some functions from the
1949   // module, so use make_early_inc_range
1950   for (auto &F : make_early_inc_range(M->functions())) {
1951     if (F.hasName() && F.isDeclaration()) {
1952       LLVM_DEBUG(dbgs() << "[postProcess array arg] " << F << '\n');
1953       if (hasArrayArg(&F) && oclIsBuiltin(F.getName(), DemangledName, IsCpp))
1954         if (!postProcessBuiltinWithArrayArguments(&F, DemangledName))
1955           return false;
1956     }
1957   }
1958   return true;
1959 }
1960 
1961 } // namespace SPIRV
1962 
1963 namespace {
1964 class SPIRVFriendlyIRMangleInfo : public BuiltinFuncMangleInfo {
1965 public:
SPIRVFriendlyIRMangleInfo(spv::Op OC,ArrayRef<Type * > ArgTys)1966   SPIRVFriendlyIRMangleInfo(spv::Op OC, ArrayRef<Type *> ArgTys)
1967       : OC(OC), ArgTys(ArgTys) {}
1968 
init(StringRef UniqUnmangledName)1969   void init(StringRef UniqUnmangledName) override {
1970     UnmangledName = UniqUnmangledName.str();
1971     switch (OC) {
1972     case OpConvertUToF:
1973     case OpUConvert:
1974     case OpSatConvertUToS:
1975       // Treat all arguments as unsigned
1976       addUnsignedArg(-1);
1977       break;
1978     case OpSubgroupShuffleINTEL:
1979     case OpSubgroupShuffleXorINTEL:
1980       addUnsignedArg(1);
1981       break;
1982     case OpSubgroupShuffleDownINTEL:
1983     case OpSubgroupShuffleUpINTEL:
1984       addUnsignedArg(2);
1985       break;
1986     case OpSubgroupBlockWriteINTEL:
1987       addUnsignedArg(0);
1988       addUnsignedArg(1);
1989       break;
1990     case OpSubgroupImageBlockWriteINTEL:
1991       addUnsignedArg(2);
1992       break;
1993     case OpSubgroupBlockReadINTEL:
1994       setArgAttr(0, SPIR::ATTR_CONST);
1995       addUnsignedArg(0);
1996       break;
1997     case OpAtomicUMax:
1998     case OpAtomicUMin:
1999       addUnsignedArg(0);
2000       addUnsignedArg(3);
2001       break;
2002     case OpGroupUMax:
2003     case OpGroupUMin:
2004     case OpGroupNonUniformBroadcast:
2005     case OpGroupNonUniformBallotBitCount:
2006     case OpGroupNonUniformShuffle:
2007     case OpGroupNonUniformShuffleXor:
2008     case OpGroupNonUniformShuffleUp:
2009     case OpGroupNonUniformShuffleDown:
2010       addUnsignedArg(2);
2011       break;
2012     case OpGroupNonUniformInverseBallot:
2013     case OpGroupNonUniformBallotFindLSB:
2014     case OpGroupNonUniformBallotFindMSB:
2015       addUnsignedArg(1);
2016       break;
2017     case OpGroupNonUniformBallotBitExtract:
2018       addUnsignedArg(1);
2019       addUnsignedArg(2);
2020       break;
2021     case OpGroupNonUniformIAdd:
2022     case OpGroupNonUniformFAdd:
2023     case OpGroupNonUniformIMul:
2024     case OpGroupNonUniformFMul:
2025     case OpGroupNonUniformSMin:
2026     case OpGroupNonUniformFMin:
2027     case OpGroupNonUniformSMax:
2028     case OpGroupNonUniformFMax:
2029     case OpGroupNonUniformBitwiseAnd:
2030     case OpGroupNonUniformBitwiseOr:
2031     case OpGroupNonUniformBitwiseXor:
2032     case OpGroupNonUniformLogicalAnd:
2033     case OpGroupNonUniformLogicalOr:
2034     case OpGroupNonUniformLogicalXor:
2035       addUnsignedArg(3);
2036       break;
2037     case OpGroupNonUniformUMax:
2038     case OpGroupNonUniformUMin:
2039       addUnsignedArg(2);
2040       addUnsignedArg(3);
2041       break;
2042     default:;
2043       // No special handling is needed
2044     }
2045   }
2046 
2047 private:
2048   spv::Op OC;
2049   ArrayRef<Type *> ArgTys;
2050 };
2051 class OpenCLStdToSPIRVFriendlyIRMangleInfo : public BuiltinFuncMangleInfo {
2052 public:
OpenCLStdToSPIRVFriendlyIRMangleInfo(OCLExtOpKind ExtOpId,ArrayRef<Type * > ArgTys,Type * RetTy)2053   OpenCLStdToSPIRVFriendlyIRMangleInfo(OCLExtOpKind ExtOpId,
2054                                        ArrayRef<Type *> ArgTys, Type *RetTy)
2055       : ExtOpId(ExtOpId), ArgTys(ArgTys) {
2056 
2057     std::string Postfix = "";
2058     if (needRetTypePostfix())
2059       Postfix = kSPIRVPostfix::Divider + getPostfixForReturnType(RetTy, true);
2060 
2061     UnmangledName = getSPIRVExtFuncName(SPIRVEIS_OpenCL, ExtOpId, Postfix);
2062   }
2063 
needRetTypePostfix()2064   bool needRetTypePostfix() {
2065     switch (ExtOpId) {
2066     case OpenCLLIB::Vload_half:
2067     case OpenCLLIB::Vload_halfn:
2068     case OpenCLLIB::Vloada_halfn:
2069     case OpenCLLIB::Vloadn:
2070       return true;
2071     default:
2072       return false;
2073     }
2074   }
2075 
init(StringRef)2076   void init(StringRef) override {
2077     switch (ExtOpId) {
2078     case OpenCLLIB::UAbs:
2079     case OpenCLLIB::UAbs_diff:
2080     case OpenCLLIB::UAdd_sat:
2081     case OpenCLLIB::UHadd:
2082     case OpenCLLIB::URhadd:
2083     case OpenCLLIB::UClamp:
2084     case OpenCLLIB::UMad_hi:
2085     case OpenCLLIB::UMad_sat:
2086     case OpenCLLIB::UMax:
2087     case OpenCLLIB::UMin:
2088     case OpenCLLIB::UMul_hi:
2089     case OpenCLLIB::USub_sat:
2090     case OpenCLLIB::U_Upsample:
2091     case OpenCLLIB::UMad24:
2092     case OpenCLLIB::UMul24:
2093       // Treat all arguments as unsigned
2094       addUnsignedArg(-1);
2095       break;
2096     case OpenCLLIB::S_Upsample:
2097       addUnsignedArg(1);
2098       break;
2099     default:;
2100       // No special handling is needed
2101     }
2102   }
2103 
2104 private:
2105   OCLExtOpKind ExtOpId;
2106   ArrayRef<Type *> ArgTys;
2107 };
2108 } // namespace
2109 
2110 namespace SPIRV {
getSPIRVFriendlyIRFunctionName(OCLExtOpKind ExtOpId,ArrayRef<Type * > ArgTys,Type * RetTy)2111 std::string getSPIRVFriendlyIRFunctionName(OCLExtOpKind ExtOpId,
2112                                            ArrayRef<Type *> ArgTys,
2113                                            Type *RetTy) {
2114   OpenCLStdToSPIRVFriendlyIRMangleInfo MangleInfo(ExtOpId, ArgTys, RetTy);
2115   return mangleBuiltin(MangleInfo.getUnmangledName(), ArgTys, &MangleInfo);
2116 }
2117 
getSPIRVFriendlyIRFunctionName(const std::string & UniqName,spv::Op OC,ArrayRef<Type * > ArgTys)2118 std::string getSPIRVFriendlyIRFunctionName(const std::string &UniqName,
2119                                            spv::Op OC,
2120                                            ArrayRef<Type *> ArgTys) {
2121   SPIRVFriendlyIRMangleInfo MangleInfo(OC, ArgTys);
2122   return mangleBuiltin(UniqName, ArgTys, &MangleInfo);
2123 }
2124 
2125 } // namespace SPIRV
2126