1 //===- SPIRVUtil.cpp - SPIR-V Utilities -------------------------*- C++ -*-===//
2 //
3 // The LLVM/SPIRV Translator
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 // Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved.
9 //
10 // Permission is hereby granted, free of charge, to any person obtaining a
11 // copy of this software and associated documentation files (the "Software"),
12 // to deal with the Software without restriction, including without limitation
13 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
14 // and/or sell copies of the Software, and to permit persons to whom the
15 // Software is furnished to do so, subject to the following conditions:
16 //
17 // Redistributions of source code must retain the above copyright notice,
18 // this list of conditions and the following disclaimers.
19 // Redistributions in binary form must reproduce the above copyright notice,
20 // this list of conditions and the following disclaimers in the documentation
21 // and/or other materials provided with the distribution.
22 // Neither the names of Advanced Micro Devices, Inc., nor the names of its
23 // contributors may be used to endorse or promote products derived from this
24 // Software without specific prior written permission.
25 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28 // CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH
31 // THE SOFTWARE.
32 //
33 //===----------------------------------------------------------------------===//
34 /// \file
35 ///
36 /// This file defines utility classes and functions shared by SPIR-V
37 /// reader/writer.
38 ///
39 //===----------------------------------------------------------------------===//
40
41 #include "FunctionDescriptor.h"
42 #include "ManglingUtils.h"
43 #include "NameMangleAPI.h"
44 #include "OCLUtil.h"
45 #include "ParameterType.h"
46 #include "SPIRVInternal.h"
47 #include "SPIRVMDWalker.h"
48 #include "libSPIRV/SPIRVDecorate.h"
49 #include "libSPIRV/SPIRVValue.h"
50
51 #include "llvm/ADT/StringSwitch.h"
52 #include "llvm/Bitcode/BitcodeWriter.h"
53 #include "llvm/IR/IRBuilder.h"
54 #include "llvm/IR/IntrinsicInst.h"
55 #include "llvm/Support/CommandLine.h"
56 #include "llvm/Support/Debug.h"
57 #include "llvm/Support/ErrorHandling.h"
58 #include "llvm/Support/FileSystem.h"
59 #include "llvm/Support/ToolOutputFile.h"
60
61 #include <functional>
62 #include <sstream>
63
64 #define DEBUG_TYPE "spirv"
65
66 namespace SPIRV {
67
#ifdef _SPIRV_SUPPORT_TEXT_FMT
// Command-line flag "spirv-text". The parsed value is not kept in the option
// object itself; cl::location redirects it into SPIRVUseTextFormat.
cl::opt<bool, true>
    UseTextFormat("spirv-text", cl::location(SPIRVUseTextFormat),
                  cl::desc("Use text format for SPIR-V for debugging purpose"));
#endif
74
#ifdef _SPIRVDBG
// Command-line flag "spirv-debug". The parsed value is stored externally in
// SPIRVDbgEnable through cl::location.
cl::opt<bool, true> EnableDbgOutput("spirv-debug",
                                    cl::location(SPIRVDbgEnable),
                                    cl::desc("Enable SPIR-V debug output"));
#endif
80
isSupportedTriple(Triple T)81 bool isSupportedTriple(Triple T) { return T.isSPIR(); }
82
addFnAttr(CallInst * Call,Attribute::AttrKind Attr)83 void addFnAttr(CallInst *Call, Attribute::AttrKind Attr) {
84 Call->addAttribute(AttributeList::FunctionIndex, Attr);
85 }
86
removeFnAttr(CallInst * Call,Attribute::AttrKind Attr)87 void removeFnAttr(CallInst *Call, Attribute::AttrKind Attr) {
88 Call->removeAttribute(AttributeList::FunctionIndex, Attr);
89 }
90
removeCast(Value * V)91 Value *removeCast(Value *V) {
92 auto Cast = dyn_cast<ConstantExpr>(V);
93 if (Cast && Cast->isCast()) {
94 return removeCast(Cast->getOperand(0));
95 }
96 if (auto Cast = dyn_cast<CastInst>(V))
97 return removeCast(Cast->getOperand(0));
98 return V;
99 }
100
saveLLVMModule(Module * M,const std::string & OutputFile)101 void saveLLVMModule(Module *M, const std::string &OutputFile) {
102 std::error_code EC;
103 ToolOutputFile Out(OutputFile.c_str(), EC, sys::fs::OF_None);
104 if (EC) {
105 SPIRVDBG(errs() << "Fails to open output file: " << EC.message();)
106 return;
107 }
108
109 WriteBitcodeToFile(*M, Out.os());
110 Out.keep();
111 }
112
mapLLVMTypeToOCLType(const Type * Ty,bool Signed)113 std::string mapLLVMTypeToOCLType(const Type *Ty, bool Signed) {
114 if (Ty->isHalfTy())
115 return "half";
116 if (Ty->isFloatTy())
117 return "float";
118 if (Ty->isDoubleTy())
119 return "double";
120 if (auto IntTy = dyn_cast<IntegerType>(Ty)) {
121 std::string SignPrefix;
122 std::string Stem;
123 if (!Signed)
124 SignPrefix = "u";
125 switch (IntTy->getIntegerBitWidth()) {
126 case 8:
127 Stem = "char";
128 break;
129 case 16:
130 Stem = "short";
131 break;
132 case 32:
133 Stem = "int";
134 break;
135 case 64:
136 Stem = "long";
137 break;
138 default:
139 Stem = "invalid_type";
140 break;
141 }
142 return SignPrefix + Stem;
143 }
144 if (auto VecTy = dyn_cast<FixedVectorType>(Ty)) {
145 Type *EleTy = VecTy->getElementType();
146 unsigned Size = VecTy->getNumElements();
147 std::stringstream Ss;
148 Ss << mapLLVMTypeToOCLType(EleTy, Signed) << Size;
149 return Ss.str();
150 }
151 return "invalid_type";
152 }
153
mapSPIRVTypeToOCLType(SPIRVType * Ty,bool Signed)154 std::string mapSPIRVTypeToOCLType(SPIRVType *Ty, bool Signed) {
155 if (Ty->isTypeFloat()) {
156 auto W = Ty->getBitWidth();
157 switch (W) {
158 case 16:
159 return "half";
160 case 32:
161 return "float";
162 case 64:
163 return "double";
164 default:
165 assert(0 && "Invalid floating pointer type");
166 return std::string("float") + W + "_t";
167 }
168 }
169 if (Ty->isTypeInt()) {
170 std::string SignPrefix;
171 std::string Stem;
172 if (!Signed)
173 SignPrefix = "u";
174 auto W = Ty->getBitWidth();
175 switch (W) {
176 case 8:
177 Stem = "char";
178 break;
179 case 16:
180 Stem = "short";
181 break;
182 case 32:
183 Stem = "int";
184 break;
185 case 64:
186 Stem = "long";
187 break;
188 default:
189 llvm_unreachable("Invalid integer type");
190 Stem = std::string("int") + W + "_t";
191 break;
192 }
193 return SignPrefix + Stem;
194 }
195 if (Ty->isTypeVector()) {
196 auto EleTy = Ty->getVectorComponentType();
197 auto Size = Ty->getVectorComponentCount();
198 std::stringstream Ss;
199 Ss << mapSPIRVTypeToOCLType(EleTy, Signed) << Size;
200 return Ss.str();
201 }
202 llvm_unreachable("Invalid type");
203 return "unknown_type";
204 }
205
getOrCreateOpaquePtrType(Module * M,const std::string & Name,unsigned AddrSpace)206 PointerType *getOrCreateOpaquePtrType(Module *M, const std::string &Name,
207 unsigned AddrSpace) {
208 auto OpaqueType = StructType::getTypeByName(M->getContext(), Name);
209 if (!OpaqueType)
210 OpaqueType = StructType::create(M->getContext(), Name);
211 return PointerType::get(OpaqueType, AddrSpace);
212 }
213
getSamplerType(Module * M)214 PointerType *getSamplerType(Module *M) {
215 return getOrCreateOpaquePtrType(M, getSPIRVTypeName(kSPIRVTypeName::Sampler),
216 SPIRAS_Constant);
217 }
218
getPipeStorageType(Module * M)219 PointerType *getPipeStorageType(Module *M) {
220 return getOrCreateOpaquePtrType(
221 M, getSPIRVTypeName(kSPIRVTypeName::PipeStorage), SPIRAS_Constant);
222 }
223
getSPIRVOpaquePtrType(Module * M,Op OC)224 PointerType *getSPIRVOpaquePtrType(Module *M, Op OC) {
225 std::string Name = getSPIRVTypeName(SPIRVOpaqueTypeOpCodeMap::rmap(OC));
226 return getOrCreateOpaquePtrType(M, Name, getOCLOpaqueTypeAddrSpace(OC));
227 }
228
getFunctionTypeParameterTypes(llvm::FunctionType * FT,std::vector<Type * > & ArgTys)229 void getFunctionTypeParameterTypes(llvm::FunctionType *FT,
230 std::vector<Type *> &ArgTys) {
231 for (auto I = FT->param_begin(), E = FT->param_end(); I != E; ++I) {
232 ArgTys.push_back(*I);
233 }
234 }
235
isVoidFuncTy(FunctionType * FT)236 bool isVoidFuncTy(FunctionType *FT) {
237 return FT->getReturnType()->isVoidTy() && FT->getNumParams() == 0;
238 }
239
isPointerToOpaqueStructType(llvm::Type * Ty)240 bool isPointerToOpaqueStructType(llvm::Type *Ty) {
241 if (auto PT = dyn_cast<PointerType>(Ty))
242 if (auto ST = dyn_cast<StructType>(PT->getElementType()))
243 if (ST->isOpaque())
244 return true;
245 return false;
246 }
247
isPointerToOpaqueStructType(llvm::Type * Ty,const std::string & Name)248 bool isPointerToOpaqueStructType(llvm::Type *Ty, const std::string &Name) {
249 if (auto PT = dyn_cast<PointerType>(Ty))
250 if (auto ST = dyn_cast<StructType>(PT->getElementType()))
251 if (ST->isOpaque() && ST->getName() == Name)
252 return true;
253 return false;
254 }
255
isOCLImageType(llvm::Type * Ty,StringRef * Name)256 bool isOCLImageType(llvm::Type *Ty, StringRef *Name) {
257 if (auto PT = dyn_cast<PointerType>(Ty))
258 if (auto ST = dyn_cast<StructType>(PT->getElementType()))
259 if (ST->isOpaque()) {
260 auto FullName = ST->getName();
261 if (FullName.find(kSPR2TypeName::ImagePrefix) == 0) {
262 if (Name)
263 *Name = FullName.drop_front(strlen(kSPR2TypeName::OCLPrefix));
264 return true;
265 }
266 }
267 return false;
268 }
269
270 /// \param BaseTyName is the type Name as in spirv.BaseTyName.Postfixes
271 /// \param Postfix contains postfixes extracted from the SPIR-V image
272 /// type Name as spirv.BaseTyName.Postfixes.
isSPIRVType(llvm::Type * Ty,StringRef BaseTyName,StringRef * Postfix)273 bool isSPIRVType(llvm::Type *Ty, StringRef BaseTyName, StringRef *Postfix) {
274 if (auto PT = dyn_cast<PointerType>(Ty))
275 if (auto ST = dyn_cast<StructType>(PT->getElementType()))
276 if (ST->isOpaque()) {
277 auto FullName = ST->getName();
278 std::string N =
279 std::string(kSPIRVTypeName::PrefixAndDelim) + BaseTyName.str();
280 if (FullName != N)
281 N = N + kSPIRVTypeName::Delimiter;
282 if (FullName.startswith(N)) {
283 if (Postfix)
284 *Postfix = FullName.drop_front(N.size());
285 return true;
286 }
287 }
288 return false;
289 }
290
getOrCreateFunction(Module * M,Type * RetTy,ArrayRef<Type * > ArgTypes,StringRef Name,BuiltinFuncMangleInfo * Mangle,AttributeList * Attrs,bool TakeName)291 Function *getOrCreateFunction(Module *M, Type *RetTy, ArrayRef<Type *> ArgTypes,
292 StringRef Name, BuiltinFuncMangleInfo *Mangle,
293 AttributeList *Attrs, bool TakeName) {
294 std::string MangledName{Name};
295 bool IsVarArg = false;
296 if (Mangle) {
297 MangledName = mangleBuiltin(Name, ArgTypes, Mangle);
298 IsVarArg = 0 <= Mangle->getVarArg();
299 if (IsVarArg)
300 ArgTypes = ArgTypes.slice(0, Mangle->getVarArg());
301 }
302 FunctionType *FT = FunctionType::get(RetTy, ArgTypes, IsVarArg);
303 Function *F = M->getFunction(MangledName);
304 if (!TakeName && F && F->getFunctionType() != FT && Mangle != nullptr) {
305 std::string S;
306 raw_string_ostream SS(S);
307 SS << "Error: Attempt to redefine function: " << *F << " => " << *FT
308 << '\n';
309 report_fatal_error(SS.str(), false);
310 }
311 if (!F || F->getFunctionType() != FT) {
312 auto NewF =
313 Function::Create(FT, GlobalValue::ExternalLinkage, MangledName, M);
314 if (F && TakeName) {
315 NewF->takeName(F);
316 LLVM_DEBUG(
317 dbgs() << "[getOrCreateFunction] Warning: taking function Name\n");
318 }
319 if (NewF->getName() != MangledName) {
320 LLVM_DEBUG(
321 dbgs() << "[getOrCreateFunction] Warning: function Name changed\n");
322 }
323 LLVM_DEBUG(dbgs() << "[getOrCreateFunction] ";
324 if (F) dbgs() << *F << " => "; dbgs() << *NewF << '\n';);
325 F = NewF;
326 F->setCallingConv(CallingConv::SPIR_FUNC);
327 if (Attrs)
328 F->setAttributes(*Attrs);
329 }
330 return F;
331 }
332
getArguments(CallInst * CI,unsigned Start,unsigned End)333 std::vector<Value *> getArguments(CallInst *CI, unsigned Start, unsigned End) {
334 std::vector<Value *> Args;
335 if (End == 0)
336 End = CI->getNumArgOperands();
337 for (; Start != End; ++Start) {
338 Args.push_back(CI->getArgOperand(Start));
339 }
340 return Args;
341 }
342
getArgAsInt(CallInst * CI,unsigned I)343 uint64_t getArgAsInt(CallInst *CI, unsigned I) {
344 return cast<ConstantInt>(CI->getArgOperand(I))->getZExtValue();
345 }
346
getArgAsScope(CallInst * CI,unsigned I)347 Scope getArgAsScope(CallInst *CI, unsigned I) {
348 return static_cast<Scope>(getArgAsInt(CI, I));
349 }
350
getArgAsDecoration(CallInst * CI,unsigned I)351 Decoration getArgAsDecoration(CallInst *CI, unsigned I) {
352 return static_cast<Decoration>(getArgAsInt(CI, I));
353 }
354
decorateSPIRVFunction(const std::string & S)355 std::string decorateSPIRVFunction(const std::string &S) {
356 return std::string(kSPIRVName::Prefix) + S + kSPIRVName::Postfix;
357 }
358
undecorateSPIRVFunction(StringRef S)359 StringRef undecorateSPIRVFunction(StringRef S) {
360 assert(S.find(kSPIRVName::Prefix) == 0);
361 const size_t Start = strlen(kSPIRVName::Prefix);
362 auto End = S.rfind(kSPIRVName::Postfix);
363 return S.substr(Start, End - Start);
364 }
365
prefixSPIRVName(const std::string & S)366 std::string prefixSPIRVName(const std::string &S) {
367 return std::string(kSPIRVName::Prefix) + S;
368 }
369
dePrefixSPIRVName(StringRef R,SmallVectorImpl<StringRef> & Postfix)370 StringRef dePrefixSPIRVName(StringRef R, SmallVectorImpl<StringRef> &Postfix) {
371 const size_t Start = strlen(kSPIRVName::Prefix);
372 if (!R.startswith(kSPIRVName::Prefix))
373 return R;
374 R = R.drop_front(Start);
375 R.split(Postfix, "_", -1, false);
376 auto Name = Postfix.front();
377 Postfix.erase(Postfix.begin());
378 return Name;
379 }
380
getSPIRVFuncName(Op OC,StringRef PostFix)381 std::string getSPIRVFuncName(Op OC, StringRef PostFix) {
382 return prefixSPIRVName(getName(OC) + PostFix.str());
383 }
384
getSPIRVFuncName(Op OC,const Type * PRetTy,bool IsSigned)385 std::string getSPIRVFuncName(Op OC, const Type *PRetTy, bool IsSigned) {
386 return prefixSPIRVName(getName(OC) + kSPIRVPostfix::Divider +
387 getPostfixForReturnType(PRetTy, IsSigned));
388 }
389
getSPIRVFuncName(SPIRVBuiltinVariableKind BVKind)390 std::string getSPIRVFuncName(SPIRVBuiltinVariableKind BVKind) {
391 return prefixSPIRVName(getName(BVKind));
392 }
393
getSPIRVExtFuncName(SPIRVExtInstSetKind Set,unsigned ExtOp,StringRef PostFix)394 std::string getSPIRVExtFuncName(SPIRVExtInstSetKind Set, unsigned ExtOp,
395 StringRef PostFix) {
396 std::string ExtOpName;
397 switch (Set) {
398 default:
399 llvm_unreachable("invalid extended instruction set");
400 ExtOpName = "unknown";
401 break;
402 case SPIRVEIS_OpenCL:
403 ExtOpName = getName(static_cast<OCLExtOpKind>(ExtOp));
404 break;
405 }
406 return prefixSPIRVName(SPIRVExtSetShortNameMap::map(Set) + '_' + ExtOpName +
407 PostFix.str());
408 }
409
mapPostfixToDecorate(StringRef Postfix,SPIRVEntry * Target)410 SPIRVDecorate *mapPostfixToDecorate(StringRef Postfix, SPIRVEntry *Target) {
411 if (Postfix == kSPIRVPostfix::Sat)
412 return new SPIRVDecorate(spv::DecorationSaturatedConversion, Target);
413
414 if (Postfix.startswith(kSPIRVPostfix::Rt))
415 return new SPIRVDecorate(spv::DecorationFPRoundingMode, Target,
416 map<SPIRVFPRoundingModeKind>(Postfix.str()));
417
418 return nullptr;
419 }
420
addDecorations(SPIRVValue * Target,const SmallVectorImpl<std::string> & Decs)421 SPIRVValue *addDecorations(SPIRVValue *Target,
422 const SmallVectorImpl<std::string> &Decs) {
423 for (auto &I : Decs)
424 if (auto Dec = mapPostfixToDecorate(I, Target))
425 Target->addDecorate(Dec);
426 return Target;
427 }
428
getPostfix(Decoration Dec,unsigned Value)429 std::string getPostfix(Decoration Dec, unsigned Value) {
430 switch (Dec) {
431 default:
432 llvm_unreachable("not implemented");
433 return "unknown";
434 case spv::DecorationSaturatedConversion:
435 return kSPIRVPostfix::Sat;
436 case spv::DecorationFPRoundingMode:
437 return rmap<std::string>(static_cast<SPIRVFPRoundingModeKind>(Value));
438 }
439 }
440
getPostfixForReturnType(CallInst * CI,bool IsSigned)441 std::string getPostfixForReturnType(CallInst *CI, bool IsSigned) {
442 return getPostfixForReturnType(CI->getType(), IsSigned);
443 }
444
getPostfixForReturnType(const Type * PRetTy,bool IsSigned)445 std::string getPostfixForReturnType(const Type *PRetTy, bool IsSigned) {
446 return std::string(kSPIRVPostfix::Return) +
447 mapLLVMTypeToOCLType(PRetTy, IsSigned);
448 }
449
450 // Enqueue kernel, kernel query, pipe and address space cast built-ins
451 // are not mangled.
isNonMangledOCLBuiltin(StringRef Name)452 bool isNonMangledOCLBuiltin(StringRef Name) {
453 if (!Name.startswith("__"))
454 return false;
455
456 return isEnqueueKernelBI(Name) || isKernelQueryBI(Name) ||
457 isPipeOrAddressSpaceCastBI(Name.drop_front(2));
458 }
459
getSPIRVFuncOC(StringRef S,SmallVectorImpl<std::string> * Dec)460 Op getSPIRVFuncOC(StringRef S, SmallVectorImpl<std::string> *Dec) {
461 Op OC;
462 SmallVector<StringRef, 2> Postfix;
463 StringRef Name;
464 if (!oclIsBuiltin(S, Name))
465 Name = S;
466 StringRef R(Name);
467 if ((!Name.startswith(kSPIRVName::Prefix) && !isNonMangledOCLBuiltin(S)) ||
468 !getByName(dePrefixSPIRVName(R, Postfix).str(), OC)) {
469 return OpNop;
470 }
471 if (Dec)
472 for (auto &I : Postfix)
473 Dec->push_back(I.str());
474 return OC;
475 }
476
getSPIRVBuiltin(const std::string & OrigName,spv::BuiltIn & B)477 bool getSPIRVBuiltin(const std::string &OrigName, spv::BuiltIn &B) {
478 SmallVector<StringRef, 2> Postfix;
479 StringRef R(OrigName);
480 R = dePrefixSPIRVName(R, Postfix);
481 if (!Postfix.empty())
482 return false;
483 return getByName(R.str(), B);
484 }
485
486 // Demangled name is a substring of the name. The DemangledName is updated only
487 // if true is returned
oclIsBuiltin(StringRef Name,StringRef & DemangledName,bool IsCpp)488 bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp) {
489 if (Name == "printf") {
490 DemangledName = Name;
491 return true;
492 }
493 if (isNonMangledOCLBuiltin(Name)) {
494 DemangledName = Name.drop_front(2);
495 return true;
496 }
497 if (!Name.startswith("_Z"))
498 return false;
499 // OpenCL C++ built-ins are declared in cl namespace.
500 // TODO: consider using 'St' abbriviation for cl namespace mangling.
501 // Similar to ::std:: in C++.
502 if (IsCpp) {
503 if (!Name.startswith("_ZN"))
504 return false;
505 // Skip CV and ref qualifiers.
506 size_t NameSpaceStart = Name.find_first_not_of("rVKRO", 3);
507 // All built-ins are in the ::cl:: namespace.
508 if (Name.substr(NameSpaceStart, 11) != "2cl7__spirv")
509 return false;
510 size_t DemangledNameLenStart = NameSpaceStart + 11;
511 size_t Start = Name.find_first_not_of("0123456789", DemangledNameLenStart);
512 size_t Len = 0;
513 Name.substr(DemangledNameLenStart, Start - DemangledNameLenStart)
514 .getAsInteger(10, Len);
515 DemangledName = Name.substr(Start, Len);
516 } else {
517 size_t Start = Name.find_first_not_of("0123456789", 2);
518 size_t Len = 0;
519 Name.substr(2, Start - 2).getAsInteger(10, Len);
520 DemangledName = Name.substr(Start, Len);
521 }
522 return true;
523 }
524
// Check if a mangled type Name is unsigned:
// 'h' uchar, 't' ushort, 'j' uint, 'm' ulong.
bool isMangledTypeUnsigned(char Mangled) {
  switch (Mangled) {
  case 'h': /* uchar */
  case 't': /* ushort */
  case 'j': /* uint */
  case 'm': /* ulong */
    return true;
  default:
    return false;
  }
}
532
// Check if a mangled type Name is signed:
// 'c' char, 'a' signed char, 's' short, 'i' int, 'l' long.
bool isMangledTypeSigned(char Mangled) {
  switch (Mangled) {
  case 'c': /* char */
  case 'a': /* signed char */
  case 's': /* short */
  case 'i': /* int */
  case 'l': /* long */
    return true;
  default:
    return false;
  }
}
541
// Check if a mangled type Name is floating point (excludes half):
// 'f' float, 'd' double.
bool isMangledTypeFP(char Mangled) {
  if (Mangled == 'f') /* float */
    return true;
  if (Mangled == 'd') /* double */
    return true;
  return false;
}
547
// Check if a mangled type Name is half, i.e. exactly the string "Dh".
bool isMangledTypeHalf(std::string Mangled) {
  return Mangled.compare("Dh") == 0; /* half */
}
552
/// Removes every trailing "S_" from \p MangledName in place. Occurrences that
/// are not part of the unbroken run at the end of the string are left alone.
void eraseSubstitutionFromMangledName(std::string &MangledName) {
  size_t Keep = MangledName.size();
  while (Keep >= 2 && MangledName.compare(Keep - 2, 2, "S_") == 0)
    Keep -= 2;
  MangledName.resize(Keep);
}
560
lastFuncParamType(StringRef MangledName)561 ParamType lastFuncParamType(StringRef MangledName) {
562 std::string Copy(MangledName);
563 eraseSubstitutionFromMangledName(Copy);
564 char Mangled = Copy.back();
565 std::string Mangled2 = Copy.substr(Copy.size() - 2);
566
567 if (isMangledTypeFP(Mangled) || isMangledTypeHalf(Mangled2)) {
568 return ParamType::FLOAT;
569 } else if (isMangledTypeUnsigned(Mangled)) {
570 return ParamType::UNSIGNED;
571 } else if (isMangledTypeSigned(Mangled)) {
572 return ParamType::SIGNED;
573 }
574
575 return ParamType::UNKNOWN;
576 }
577
578 // Check if the last argument is signed
isLastFuncParamSigned(StringRef MangledName)579 bool isLastFuncParamSigned(StringRef MangledName) {
580 return lastFuncParamType(MangledName) == ParamType::SIGNED;
581 }
582
583 // Check if a mangled function Name contains unsigned atomic type
containsUnsignedAtomicType(StringRef Name)584 bool containsUnsignedAtomicType(StringRef Name) {
585 auto Loc = Name.find(kMangledName::AtomicPrefixIncoming);
586 if (Loc == StringRef::npos)
587 return false;
588 return isMangledTypeUnsigned(
589 Name[Loc + strlen(kMangledName::AtomicPrefixIncoming)]);
590 }
591
isFunctionPointerType(Type * T)592 bool isFunctionPointerType(Type *T) {
593 if (isa<PointerType>(T) && isa<FunctionType>(T->getPointerElementType())) {
594 return true;
595 }
596 return false;
597 }
598
hasFunctionPointerArg(Function * F,Function::arg_iterator & AI)599 bool hasFunctionPointerArg(Function *F, Function::arg_iterator &AI) {
600 AI = F->arg_begin();
601 for (auto AE = F->arg_end(); AI != AE; ++AI) {
602 LLVM_DEBUG(dbgs() << "[hasFuncPointerArg] " << *AI << '\n');
603 if (isFunctionPointerType(AI->getType())) {
604 return true;
605 }
606 }
607 return false;
608 }
609
castToVoidFuncPtr(Function * F)610 Constant *castToVoidFuncPtr(Function *F) {
611 auto T = getVoidFuncPtrType(F->getParent());
612 return ConstantExpr::getBitCast(F, T);
613 }
614
hasArrayArg(Function * F)615 bool hasArrayArg(Function *F) {
616 for (auto I = F->arg_begin(), E = F->arg_end(); I != E; ++I) {
617 LLVM_DEBUG(dbgs() << "[hasArrayArg] " << *I << '\n');
618 if (I->getType()->isArrayTy()) {
619 return true;
620 }
621 }
622 return false;
623 }
624
mutateCallInst(Module * M,CallInst * CI,std::function<std::string (CallInst *,std::vector<Value * > &)> ArgMutate,BuiltinFuncMangleInfo * Mangle,AttributeList * Attrs,bool TakeFuncName)625 CallInst *mutateCallInst(
626 Module *M, CallInst *CI,
627 std::function<std::string(CallInst *, std::vector<Value *> &)> ArgMutate,
628 BuiltinFuncMangleInfo *Mangle, AttributeList *Attrs, bool TakeFuncName) {
629 LLVM_DEBUG(dbgs() << "[mutateCallInst] " << *CI);
630
631 auto Args = getArguments(CI);
632 auto NewName = ArgMutate(CI, Args);
633 std::string InstName;
634 if (!CI->getType()->isVoidTy() && CI->hasName()) {
635 InstName = CI->getName().str();
636 CI->setName(InstName + ".old");
637 }
638 auto NewCI = addCallInst(M, NewName, CI->getType(), Args, Attrs, CI, Mangle,
639 InstName, TakeFuncName);
640 NewCI->setDebugLoc(CI->getDebugLoc());
641 LLVM_DEBUG(dbgs() << " => " << *NewCI << '\n');
642 CI->replaceAllUsesWith(NewCI);
643 CI->eraseFromParent();
644 return NewCI;
645 }
646
mutateCallInst(Module * M,CallInst * CI,std::function<std::string (CallInst *,std::vector<Value * > &,Type * & RetTy)> ArgMutate,std::function<Instruction * (CallInst *)> RetMutate,BuiltinFuncMangleInfo * Mangle,AttributeList * Attrs,bool TakeFuncName)647 Instruction *mutateCallInst(
648 Module *M, CallInst *CI,
649 std::function<std::string(CallInst *, std::vector<Value *> &, Type *&RetTy)>
650 ArgMutate,
651 std::function<Instruction *(CallInst *)> RetMutate,
652 BuiltinFuncMangleInfo *Mangle, AttributeList *Attrs, bool TakeFuncName) {
653 LLVM_DEBUG(dbgs() << "[mutateCallInst] " << *CI);
654
655 auto Args = getArguments(CI);
656 Type *RetTy = CI->getType();
657 auto NewName = ArgMutate(CI, Args, RetTy);
658 StringRef InstName = CI->getName();
659 auto NewCI = addCallInst(M, NewName, RetTy, Args, Attrs, CI, Mangle, InstName,
660 TakeFuncName);
661 auto NewI = RetMutate(NewCI);
662 NewI->takeName(CI);
663 NewI->setDebugLoc(CI->getDebugLoc());
664 LLVM_DEBUG(dbgs() << " => " << *NewI << '\n');
665 if (!CI->getType()->isVoidTy())
666 CI->replaceAllUsesWith(NewI);
667 CI->eraseFromParent();
668 return NewI;
669 }
670
mutateFunction(Function * F,std::function<std::string (CallInst *,std::vector<Value * > &)> ArgMutate,BuiltinFuncMangleInfo * Mangle,AttributeList * Attrs,bool TakeFuncName)671 void mutateFunction(
672 Function *F,
673 std::function<std::string(CallInst *, std::vector<Value *> &)> ArgMutate,
674 BuiltinFuncMangleInfo *Mangle, AttributeList *Attrs, bool TakeFuncName) {
675 auto M = F->getParent();
676 for (auto I = F->user_begin(), E = F->user_end(); I != E;) {
677 if (auto CI = dyn_cast<CallInst>(*I++))
678 mutateCallInst(M, CI, ArgMutate, Mangle, Attrs, TakeFuncName);
679 }
680 if (F->use_empty())
681 F->eraseFromParent();
682 }
683
mutateCallInstSPIRV(Module * M,CallInst * CI,std::function<std::string (CallInst *,std::vector<Value * > &)> ArgMutate,AttributeList * Attrs)684 CallInst *mutateCallInstSPIRV(
685 Module *M, CallInst *CI,
686 std::function<std::string(CallInst *, std::vector<Value *> &)> ArgMutate,
687 AttributeList *Attrs) {
688 BuiltinFuncMangleInfo BtnInfo;
689 return mutateCallInst(M, CI, ArgMutate, &BtnInfo, Attrs);
690 }
691
mutateCallInstSPIRV(Module * M,CallInst * CI,std::function<std::string (CallInst *,std::vector<Value * > &,Type * & RetTy)> ArgMutate,std::function<Instruction * (CallInst *)> RetMutate,AttributeList * Attrs)692 Instruction *mutateCallInstSPIRV(
693 Module *M, CallInst *CI,
694 std::function<std::string(CallInst *, std::vector<Value *> &, Type *&RetTy)>
695 ArgMutate,
696 std::function<Instruction *(CallInst *)> RetMutate, AttributeList *Attrs) {
697 BuiltinFuncMangleInfo BtnInfo;
698 return mutateCallInst(M, CI, ArgMutate, RetMutate, &BtnInfo, Attrs);
699 }
700
addCallInst(Module * M,StringRef FuncName,Type * RetTy,ArrayRef<Value * > Args,AttributeList * Attrs,Instruction * Pos,BuiltinFuncMangleInfo * Mangle,StringRef InstName,bool TakeFuncName)701 CallInst *addCallInst(Module *M, StringRef FuncName, Type *RetTy,
702 ArrayRef<Value *> Args, AttributeList *Attrs,
703 Instruction *Pos, BuiltinFuncMangleInfo *Mangle,
704 StringRef InstName, bool TakeFuncName) {
705
706 auto F = getOrCreateFunction(M, RetTy, getTypes(Args), FuncName, Mangle,
707 Attrs, TakeFuncName);
708 // Cannot assign a Name to void typed values
709 auto CI = CallInst::Create(F, Args, RetTy->isVoidTy() ? "" : InstName, Pos);
710 CI->setCallingConv(F->getCallingConv());
711 CI->setAttributes(F->getAttributes());
712 return CI;
713 }
714
addCallInstSPIRV(Module * M,StringRef FuncName,Type * RetTy,ArrayRef<Value * > Args,AttributeList * Attrs,Instruction * Pos,StringRef InstName)715 CallInst *addCallInstSPIRV(Module *M, StringRef FuncName, Type *RetTy,
716 ArrayRef<Value *> Args, AttributeList *Attrs,
717 Instruction *Pos, StringRef InstName) {
718 BuiltinFuncMangleInfo BtnInfo;
719 return addCallInst(M, FuncName, RetTy, Args, Attrs, Pos, &BtnInfo, InstName);
720 }
721
/// Returns true when \p I is one of the accepted vector lengths:
/// 2, 3, 4, 8 or 16.
bool isValidVectorSize(unsigned I) {
  switch (I) {
  case 2:
  case 3:
  case 4:
  case 8:
  case 16:
    return true;
  default:
    return false;
  }
}
725
addVector(Instruction * InsPos,ValueVecRange Range)726 Value *addVector(Instruction *InsPos, ValueVecRange Range) {
727 size_t VecSize = Range.second - Range.first;
728 if (VecSize == 1)
729 return *Range.first;
730 assert(isValidVectorSize(VecSize) && "Invalid vector size");
731 IRBuilder<> Builder(InsPos);
732 auto Vec = Builder.CreateVectorSplat(VecSize, *Range.first);
733 unsigned Index = 1;
734 for (++Range.first; Range.first != Range.second; ++Range.first, ++Index)
735 Vec = Builder.CreateInsertElement(
736 Vec, *Range.first,
737 ConstantInt::get(Type::getInt32Ty(InsPos->getContext()), Index, false));
738 return Vec;
739 }
740
makeVector(Instruction * InsPos,std::vector<Value * > & Ops,ValueVecRange Range)741 void makeVector(Instruction *InsPos, std::vector<Value *> &Ops,
742 ValueVecRange Range) {
743 auto Vec = addVector(InsPos, Range);
744 Ops.erase(Range.first, Range.second);
745 Ops.push_back(Vec);
746 }
747
expandVector(Instruction * InsPos,std::vector<Value * > & Ops,size_t VecPos)748 void expandVector(Instruction *InsPos, std::vector<Value *> &Ops,
749 size_t VecPos) {
750 auto Vec = Ops[VecPos];
751 auto *VT = dyn_cast<FixedVectorType>(Vec->getType());
752 if (!VT)
753 return;
754 size_t N = VT->getNumElements();
755 IRBuilder<> Builder(InsPos);
756 for (size_t I = 0; I != N; ++I)
757 Ops.insert(Ops.begin() + VecPos + I,
758 Builder.CreateExtractElement(
759 Vec, ConstantInt::get(Type::getInt32Ty(InsPos->getContext()),
760 I, false)));
761 Ops.erase(Ops.begin() + VecPos + N);
762 }
763
castToInt8Ptr(Constant * V,unsigned Addr=0)764 Constant *castToInt8Ptr(Constant *V, unsigned Addr = 0) {
765 return ConstantExpr::getBitCast(V, Type::getInt8PtrTy(V->getContext(), Addr));
766 }
767
getInt8PtrTy(PointerType * T)768 PointerType *getInt8PtrTy(PointerType *T) {
769 return Type::getInt8PtrTy(T->getContext(), T->getAddressSpace());
770 }
771
castToInt8Ptr(Value * V,Instruction * Pos)772 Value *castToInt8Ptr(Value *V, Instruction *Pos) {
773 return CastInst::CreatePointerCast(
774 V, getInt8PtrTy(cast<PointerType>(V->getType())), "", Pos);
775 }
776
addBlockBind(Module * M,Function * InvokeFunc,Value * BlkCtx,Value * CtxLen,Value * CtxAlign,Instruction * InsPos,StringRef InstName)777 CallInst *addBlockBind(Module *M, Function *InvokeFunc, Value *BlkCtx,
778 Value *CtxLen, Value *CtxAlign, Instruction *InsPos,
779 StringRef InstName) {
780 auto BlkTy =
781 getOrCreateOpaquePtrType(M, SPIR_TYPE_NAME_BLOCK_T, SPIRAS_Private);
782 auto &Ctx = M->getContext();
783 Value *BlkArgs[] = {
784 castToInt8Ptr(InvokeFunc),
785 CtxLen ? CtxLen : UndefValue::get(Type::getInt32Ty(Ctx)),
786 CtxAlign ? CtxAlign : UndefValue::get(Type::getInt32Ty(Ctx)),
787 BlkCtx ? BlkCtx : UndefValue::get(Type::getInt8PtrTy(Ctx))};
788 return addCallInst(M, SPIR_INTRINSIC_BLOCK_BIND, BlkTy, BlkArgs, nullptr,
789 InsPos, nullptr, InstName);
790 }
791
getSizetType(Module * M)792 IntegerType *getSizetType(Module *M) {
793 return IntegerType::getIntNTy(M->getContext(),
794 M->getDataLayout().getPointerSizeInBits(0));
795 }
796
getVoidFuncType(Module * M)797 Type *getVoidFuncType(Module *M) {
798 return FunctionType::get(Type::getVoidTy(M->getContext()), false);
799 }
800
getVoidFuncPtrType(Module * M,unsigned AddrSpace)801 Type *getVoidFuncPtrType(Module *M, unsigned AddrSpace) {
802 return PointerType::get(getVoidFuncType(M), AddrSpace);
803 }
804
getInt64(Module * M,int64_t Value)805 ConstantInt *getInt64(Module *M, int64_t Value) {
806 return ConstantInt::getSigned(Type::getInt64Ty(M->getContext()), Value);
807 }
808
getUInt64(Module * M,uint64_t Value)809 ConstantInt *getUInt64(Module *M, uint64_t Value) {
810 return ConstantInt::get(Type::getInt64Ty(M->getContext()), Value, false);
811 }
812
getFloat32(Module * M,float Value)813 Constant *getFloat32(Module *M, float Value) {
814 return ConstantFP::get(Type::getFloatTy(M->getContext()), Value);
815 }
816
getInt32(Module * M,int Value)817 ConstantInt *getInt32(Module *M, int Value) {
818 return ConstantInt::get(Type::getInt32Ty(M->getContext()), Value, true);
819 }
820
getUInt32(Module * M,unsigned Value)821 ConstantInt *getUInt32(Module *M, unsigned Value) {
822 return ConstantInt::get(Type::getInt32Ty(M->getContext()), Value, false);
823 }
824
getInt(Module * M,int64_t Value)825 ConstantInt *getInt(Module *M, int64_t Value) {
826 return Value >> 32 ? getInt64(M, Value)
827 : getInt32(M, static_cast<int32_t>(Value));
828 }
829
getUInt(Module * M,uint64_t Value)830 ConstantInt *getUInt(Module *M, uint64_t Value) {
831 return Value >> 32 ? getUInt64(M, Value)
832 : getUInt32(M, static_cast<uint32_t>(Value));
833 }
834
getUInt16(Module * M,unsigned short Value)835 ConstantInt *getUInt16(Module *M, unsigned short Value) {
836 return ConstantInt::get(Type::getInt16Ty(M->getContext()), Value, false);
837 }
838
getInt32(Module * M,const std::vector<int> & Values)839 std::vector<Value *> getInt32(Module *M, const std::vector<int> &Values) {
840 std::vector<Value *> V;
841 for (auto &I : Values)
842 V.push_back(getInt32(M, I));
843 return V;
844 }
845
getSizet(Module * M,uint64_t Value)846 ConstantInt *getSizet(Module *M, uint64_t Value) {
847 return ConstantInt::get(getSizetType(M), Value, false);
848 }
849
850 ///////////////////////////////////////////////////////////////////////////////
851 //
852 // Functions for getting metadata
853 //
854 ///////////////////////////////////////////////////////////////////////////////
getMDOperandAsInt(MDNode * N,unsigned I)855 int64_t getMDOperandAsInt(MDNode *N, unsigned I) {
856 return mdconst::dyn_extract<ConstantInt>(N->getOperand(I))->getZExtValue();
857 }
858
859 // Additional helper function to be reused by getMDOperandAs* helpers
getMDOperandOrNull(MDNode * N,unsigned I)860 Metadata *getMDOperandOrNull(MDNode *N, unsigned I) {
861 if (!N)
862 return nullptr;
863 return N->getOperand(I);
864 }
865
getMDOperandAsString(MDNode * N,unsigned I)866 std::string getMDOperandAsString(MDNode *N, unsigned I) {
867 if (auto *Str = dyn_cast_or_null<MDString>(getMDOperandOrNull(N, I)))
868 return Str->getString().str();
869 return "";
870 }
871
getMDOperandAsMDNode(MDNode * N,unsigned I)872 MDNode *getMDOperandAsMDNode(MDNode *N, unsigned I) {
873 return dyn_cast_or_null<MDNode>(getMDOperandOrNull(N, I));
874 }
875
getMDOperandAsType(MDNode * N,unsigned I)876 Type *getMDOperandAsType(MDNode *N, unsigned I) {
877 return cast<ValueAsMetadata>(N->getOperand(I))->getType();
878 }
879
getNamedMDAsStringSet(Module * M,const std::string & MDName)880 std::set<std::string> getNamedMDAsStringSet(Module *M,
881 const std::string &MDName) {
882 NamedMDNode *NamedMD = M->getNamedMetadata(MDName);
883 std::set<std::string> StrSet;
884 if (!NamedMD)
885 return StrSet;
886
887 assert(NamedMD->getNumOperands() > 0 && "Invalid SPIR");
888
889 for (unsigned I = 0, E = NamedMD->getNumOperands(); I != E; ++I) {
890 MDNode *MD = NamedMD->getOperand(I);
891 if (!MD || MD->getNumOperands() == 0)
892 continue;
893 for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
894 StrSet.insert(getMDOperandAsString(MD, J));
895 }
896
897 return StrSet;
898 }
899
getSPIRVSource(Module * M)900 std::tuple<unsigned, unsigned, std::string> getSPIRVSource(Module *M) {
901 std::tuple<unsigned, unsigned, std::string> Tup;
902 if (auto N = SPIRVMDWalker(*M).getNamedMD(kSPIRVMD::Source).nextOp())
903 N.get(std::get<0>(Tup))
904 .get(std::get<1>(Tup))
905 .setQuiet(true)
906 .get(std::get<2>(Tup));
907 return Tup;
908 }
909
mapUInt(Module * M,ConstantInt * I,std::function<unsigned (unsigned)> F)910 ConstantInt *mapUInt(Module *M, ConstantInt *I,
911 std::function<unsigned(unsigned)> F) {
912 return ConstantInt::get(I->getType(), F(I->getZExtValue()), false);
913 }
914
mapSInt(Module * M,ConstantInt * I,std::function<int (int)> F)915 ConstantInt *mapSInt(Module *M, ConstantInt *I, std::function<int(int)> F) {
916 return ConstantInt::get(I->getType(), F(I->getSExtValue()), true);
917 }
918
isDecoratedSPIRVFunc(const Function * F,StringRef & UndecoratedName)919 bool isDecoratedSPIRVFunc(const Function *F, StringRef &UndecoratedName) {
920 if (!F->hasName() || !F->getName().startswith(kSPIRVName::Prefix))
921 return false;
922 UndecoratedName = F->getName();
923 return true;
924 }
925
926 /// Get TypePrimitiveEnum for special OpenCL type except opencl.block.
getOCLTypePrimitiveEnum(StringRef TyName)927 SPIR::TypePrimitiveEnum getOCLTypePrimitiveEnum(StringRef TyName) {
928 return StringSwitch<SPIR::TypePrimitiveEnum>(TyName)
929 .Case("opencl.image1d_ro_t", SPIR::PRIMITIVE_IMAGE1D_RO_T)
930 .Case("opencl.image1d_array_ro_t", SPIR::PRIMITIVE_IMAGE1D_ARRAY_RO_T)
931 .Case("opencl.image1d_buffer_ro_t", SPIR::PRIMITIVE_IMAGE1D_BUFFER_RO_T)
932 .Case("opencl.image2d_ro_t", SPIR::PRIMITIVE_IMAGE2D_RO_T)
933 .Case("opencl.image2d_array_ro_t", SPIR::PRIMITIVE_IMAGE2D_ARRAY_RO_T)
934 .Case("opencl.image2d_depth_ro_t", SPIR::PRIMITIVE_IMAGE2D_DEPTH_RO_T)
935 .Case("opencl.image2d_array_depth_ro_t",
936 SPIR::PRIMITIVE_IMAGE2D_ARRAY_DEPTH_RO_T)
937 .Case("opencl.image2d_msaa_ro_t", SPIR::PRIMITIVE_IMAGE2D_MSAA_RO_T)
938 .Case("opencl.image2d_array_msaa_ro_t",
939 SPIR::PRIMITIVE_IMAGE2D_ARRAY_MSAA_RO_T)
940 .Case("opencl.image2d_msaa_depth_ro_t",
941 SPIR::PRIMITIVE_IMAGE2D_MSAA_DEPTH_RO_T)
942 .Case("opencl.image2d_array_msaa_depth_ro_t",
943 SPIR::PRIMITIVE_IMAGE2D_ARRAY_MSAA_DEPTH_RO_T)
944 .Case("opencl.image3d_ro_t", SPIR::PRIMITIVE_IMAGE3D_RO_T)
945 .Case("opencl.image1d_wo_t", SPIR::PRIMITIVE_IMAGE1D_WO_T)
946 .Case("opencl.image1d_array_wo_t", SPIR::PRIMITIVE_IMAGE1D_ARRAY_WO_T)
947 .Case("opencl.image1d_buffer_wo_t", SPIR::PRIMITIVE_IMAGE1D_BUFFER_WO_T)
948 .Case("opencl.image2d_wo_t", SPIR::PRIMITIVE_IMAGE2D_WO_T)
949 .Case("opencl.image2d_array_wo_t", SPIR::PRIMITIVE_IMAGE2D_ARRAY_WO_T)
950 .Case("opencl.image2d_depth_wo_t", SPIR::PRIMITIVE_IMAGE2D_DEPTH_WO_T)
951 .Case("opencl.image2d_array_depth_wo_t",
952 SPIR::PRIMITIVE_IMAGE2D_ARRAY_DEPTH_WO_T)
953 .Case("opencl.image2d_msaa_wo_t", SPIR::PRIMITIVE_IMAGE2D_MSAA_WO_T)
954 .Case("opencl.image2d_array_msaa_wo_t",
955 SPIR::PRIMITIVE_IMAGE2D_ARRAY_MSAA_WO_T)
956 .Case("opencl.image2d_msaa_depth_wo_t",
957 SPIR::PRIMITIVE_IMAGE2D_MSAA_DEPTH_WO_T)
958 .Case("opencl.image2d_array_msaa_depth_wo_t",
959 SPIR::PRIMITIVE_IMAGE2D_ARRAY_MSAA_DEPTH_WO_T)
960 .Case("opencl.image3d_wo_t", SPIR::PRIMITIVE_IMAGE3D_WO_T)
961 .Case("opencl.image1d_rw_t", SPIR::PRIMITIVE_IMAGE1D_RW_T)
962 .Case("opencl.image1d_array_rw_t", SPIR::PRIMITIVE_IMAGE1D_ARRAY_RW_T)
963 .Case("opencl.image1d_buffer_rw_t", SPIR::PRIMITIVE_IMAGE1D_BUFFER_RW_T)
964 .Case("opencl.image2d_rw_t", SPIR::PRIMITIVE_IMAGE2D_RW_T)
965 .Case("opencl.image2d_array_rw_t", SPIR::PRIMITIVE_IMAGE2D_ARRAY_RW_T)
966 .Case("opencl.image2d_depth_rw_t", SPIR::PRIMITIVE_IMAGE2D_DEPTH_RW_T)
967 .Case("opencl.image2d_array_depth_rw_t",
968 SPIR::PRIMITIVE_IMAGE2D_ARRAY_DEPTH_RW_T)
969 .Case("opencl.image2d_msaa_rw_t", SPIR::PRIMITIVE_IMAGE2D_MSAA_RW_T)
970 .Case("opencl.image2d_array_msaa_rw_t",
971 SPIR::PRIMITIVE_IMAGE2D_ARRAY_MSAA_RW_T)
972 .Case("opencl.image2d_msaa_depth_rw_t",
973 SPIR::PRIMITIVE_IMAGE2D_MSAA_DEPTH_RW_T)
974 .Case("opencl.image2d_array_msaa_depth_rw_t",
975 SPIR::PRIMITIVE_IMAGE2D_ARRAY_MSAA_DEPTH_RW_T)
976 .Case("opencl.image3d_rw_t", SPIR::PRIMITIVE_IMAGE3D_RW_T)
977 .Case("opencl.event_t", SPIR::PRIMITIVE_EVENT_T)
978 .Case("opencl.pipe_ro_t", SPIR::PRIMITIVE_PIPE_RO_T)
979 .Case("opencl.pipe_wo_t", SPIR::PRIMITIVE_PIPE_WO_T)
980 .Case("opencl.reserve_id_t", SPIR::PRIMITIVE_RESERVE_ID_T)
981 .Case("opencl.queue_t", SPIR::PRIMITIVE_QUEUE_T)
982 .Case("opencl.clk_event_t", SPIR::PRIMITIVE_CLK_EVENT_T)
983 .Case("opencl.sampler_t", SPIR::PRIMITIVE_SAMPLER_T)
984 .Case("struct.ndrange_t", SPIR::PRIMITIVE_NDRANGE_T)
985 .Case("opencl.intel_sub_group_avc_mce_payload_t",
986 SPIR::PRIMITIVE_SUB_GROUP_AVC_MCE_PAYLOAD_T)
987 .Case("opencl.intel_sub_group_avc_ime_payload_t",
988 SPIR::PRIMITIVE_SUB_GROUP_AVC_IME_PAYLOAD_T)
989 .Case("opencl.intel_sub_group_avc_ref_payload_t",
990 SPIR::PRIMITIVE_SUB_GROUP_AVC_REF_PAYLOAD_T)
991 .Case("opencl.intel_sub_group_avc_sic_payload_t",
992 SPIR::PRIMITIVE_SUB_GROUP_AVC_SIC_PAYLOAD_T)
993 .Case("opencl.intel_sub_group_avc_mce_result_t",
994 SPIR::PRIMITIVE_SUB_GROUP_AVC_MCE_RESULT_T)
995 .Case("opencl.intel_sub_group_avc_ime_result_t",
996 SPIR::PRIMITIVE_SUB_GROUP_AVC_IME_RESULT_T)
997 .Case("opencl.intel_sub_group_avc_ref_result_t",
998 SPIR::PRIMITIVE_SUB_GROUP_AVC_REF_RESULT_T)
999 .Case("opencl.intel_sub_group_avc_sic_result_t",
1000 SPIR::PRIMITIVE_SUB_GROUP_AVC_SIC_RESULT_T)
1001 .Case(
1002 "opencl.intel_sub_group_avc_ime_result_single_reference_streamout_t",
1003 SPIR::PRIMITIVE_SUB_GROUP_AVC_IME_SINGLE_REF_STREAMOUT_T)
1004 .Case("opencl.intel_sub_group_avc_ime_result_dual_reference_streamout_t",
1005 SPIR::PRIMITIVE_SUB_GROUP_AVC_IME_DUAL_REF_STREAMOUT_T)
1006 .Case("opencl.intel_sub_group_avc_ime_single_reference_streamin_t",
1007 SPIR::PRIMITIVE_SUB_GROUP_AVC_IME_SINGLE_REF_STREAMIN_T)
1008 .Case("opencl.intel_sub_group_avc_ime_dual_reference_streamin_t",
1009 SPIR::PRIMITIVE_SUB_GROUP_AVC_IME_DUAL_REF_STREAMIN_T)
1010 .Default(SPIR::PRIMITIVE_NONE);
1011 }
1012 /// Translates LLVM type to descriptor for mangler.
1013 /// \param Signed indicates integer type should be translated as signed.
1014 /// \param VoidPtr indicates i8* should be translated as void*.
transTypeDesc(Type * Ty,const BuiltinArgTypeMangleInfo & Info)1015 static SPIR::RefParamType transTypeDesc(Type *Ty,
1016 const BuiltinArgTypeMangleInfo &Info) {
1017 bool Signed = Info.IsSigned;
1018 unsigned Attr = Info.Attr;
1019 bool VoidPtr = Info.IsVoidPtr;
1020 if (Info.IsEnum)
1021 return SPIR::RefParamType(new SPIR::PrimitiveType(Info.Enum));
1022 if (Info.IsSampler)
1023 return SPIR::RefParamType(
1024 new SPIR::PrimitiveType(SPIR::PRIMITIVE_SAMPLER_T));
1025 if (Info.IsAtomic && !Ty->isPointerTy()) {
1026 BuiltinArgTypeMangleInfo DTInfo = Info;
1027 DTInfo.IsAtomic = false;
1028 return SPIR::RefParamType(new SPIR::AtomicType(transTypeDesc(Ty, DTInfo)));
1029 }
1030 if (auto *IntTy = dyn_cast<IntegerType>(Ty)) {
1031 switch (IntTy->getBitWidth()) {
1032 case 1:
1033 return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_BOOL));
1034 case 8:
1035 return SPIR::RefParamType(new SPIR::PrimitiveType(
1036 Signed ? SPIR::PRIMITIVE_CHAR : SPIR::PRIMITIVE_UCHAR));
1037 case 16:
1038 return SPIR::RefParamType(new SPIR::PrimitiveType(
1039 Signed ? SPIR::PRIMITIVE_SHORT : SPIR::PRIMITIVE_USHORT));
1040 case 32:
1041 return SPIR::RefParamType(new SPIR::PrimitiveType(
1042 Signed ? SPIR::PRIMITIVE_INT : SPIR::PRIMITIVE_UINT));
1043 case 64:
1044 return SPIR::RefParamType(new SPIR::PrimitiveType(
1045 Signed ? SPIR::PRIMITIVE_LONG : SPIR::PRIMITIVE_ULONG));
1046 default:
1047 llvm_unreachable("invliad int size");
1048 }
1049 }
1050 if (Ty->isVoidTy())
1051 return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_VOID));
1052 if (Ty->isHalfTy())
1053 return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_HALF));
1054 if (Ty->isFloatTy())
1055 return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_FLOAT));
1056 if (Ty->isDoubleTy())
1057 return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_DOUBLE));
1058 if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) {
1059 return SPIR::RefParamType(new SPIR::VectorType(
1060 transTypeDesc(VecTy->getElementType(), Info), VecTy->getNumElements()));
1061 }
1062 if (Ty->isArrayTy()) {
1063 return transTypeDesc(PointerType::get(Ty->getArrayElementType(), 0), Info);
1064 }
1065 if (Ty->isStructTy()) {
1066 auto Name = Ty->getStructName();
1067 std::string Tmp;
1068
1069 if (Name.startswith(kLLVMTypeName::StructPrefix))
1070 Name = Name.drop_front(strlen(kLLVMTypeName::StructPrefix));
1071 if (Name.startswith(kSPIRVTypeName::PrefixAndDelim)) {
1072 Name = Name.substr(sizeof(kSPIRVTypeName::PrefixAndDelim) - 1);
1073 Tmp = Name.str();
1074 auto Pos = Tmp.find(kSPIRVTypeName::Delimiter); // first dot
1075 while (Pos != std::string::npos) {
1076 Tmp[Pos] = '_';
1077 Pos = Tmp.find(kSPIRVTypeName::Delimiter, Pos);
1078 }
1079 Name = Tmp = kSPIRVName::Prefix + Tmp;
1080 }
1081 // ToDo: Create a better unique Name for struct without Name
1082 if (Name.empty()) {
1083 std::ostringstream OS;
1084 OS << reinterpret_cast<size_t>(Ty);
1085 Name = Tmp = std::string("struct_") + OS.str();
1086 }
1087 return SPIR::RefParamType(new SPIR::UserDefinedType(Name.str()));
1088 }
1089
1090 if (Ty->isPointerTy()) {
1091 auto ET = Ty->getPointerElementType();
1092 SPIR::ParamType *EPT = nullptr;
1093 if (isa<FunctionType>(ET)) {
1094 assert(isVoidFuncTy(cast<FunctionType>(ET)) && "Not supported");
1095 EPT = new SPIR::BlockType;
1096 } else if (auto StructTy = dyn_cast<StructType>(ET)) {
1097 LLVM_DEBUG(dbgs() << "ptr to struct: " << *Ty << '\n');
1098 auto TyName = StructTy->getStructName();
1099 if (TyName.startswith(kSPR2TypeName::OCLPrefix)) {
1100 auto DelimPos = TyName.find_first_of(kSPR2TypeName::Delimiter,
1101 strlen(kSPR2TypeName::OCLPrefix));
1102 if (DelimPos != StringRef::npos)
1103 TyName = TyName.substr(0, DelimPos);
1104 }
1105 LLVM_DEBUG(dbgs() << " type Name: " << TyName << '\n');
1106
1107 auto Prim = getOCLTypePrimitiveEnum(TyName);
1108 if (StructTy->isOpaque()) {
1109 if (TyName == "opencl.block") {
1110 auto BlockTy = new SPIR::BlockType;
1111 // Handle block with local memory arguments according to OpenCL 2.0
1112 // spec.
1113 if (Info.IsLocalArgBlock) {
1114 SPIR::RefParamType VoidTyRef(
1115 new SPIR::PrimitiveType(SPIR::PRIMITIVE_VOID));
1116 auto VoidPtrTy = new SPIR::PointerType(VoidTyRef);
1117 VoidPtrTy->setAddressSpace(SPIR::ATTR_LOCAL);
1118 // "__local void *"
1119 BlockTy->setParam(0, SPIR::RefParamType(VoidPtrTy));
1120 // "..."
1121 BlockTy->setParam(1, SPIR::RefParamType(new SPIR::PrimitiveType(
1122 SPIR::PRIMITIVE_VAR_ARG)));
1123 }
1124 EPT = BlockTy;
1125 } else if (Prim != SPIR::PRIMITIVE_NONE) {
1126 if (Prim == SPIR::PRIMITIVE_PIPE_RO_T ||
1127 Prim == SPIR::PRIMITIVE_PIPE_WO_T) {
1128 SPIR::RefParamType OpaqueTyRef(new SPIR::PrimitiveType(Prim));
1129 auto OpaquePtrTy = new SPIR::PointerType(OpaqueTyRef);
1130 OpaquePtrTy->setAddressSpace(getOCLOpaqueTypeAddrSpace(Prim));
1131 EPT = OpaquePtrTy;
1132 } else {
1133 EPT = new SPIR::PrimitiveType(Prim);
1134 }
1135 }
1136 } else if (Prim == SPIR::PRIMITIVE_NDRANGE_T)
1137 // ndrange_t is not opaque type
1138 EPT = new SPIR::PrimitiveType(SPIR::PRIMITIVE_NDRANGE_T);
1139 }
1140 if (EPT)
1141 return SPIR::RefParamType(EPT);
1142
1143 if (VoidPtr && ET->isIntegerTy(8))
1144 ET = Type::getVoidTy(ET->getContext());
1145 auto PT = new SPIR::PointerType(transTypeDesc(ET, Info));
1146 PT->setAddressSpace(static_cast<SPIR::TypeAttributeEnum>(
1147 Ty->getPointerAddressSpace() + (unsigned)SPIR::ATTR_ADDR_SPACE_FIRST));
1148 for (unsigned I = SPIR::ATTR_QUALIFIER_FIRST, E = SPIR::ATTR_QUALIFIER_LAST;
1149 I <= E; ++I)
1150 PT->setQualifier(static_cast<SPIR::TypeAttributeEnum>(I), I & Attr);
1151 return SPIR::RefParamType(PT);
1152 }
1153 LLVM_DEBUG(dbgs() << "[transTypeDesc] " << *Ty << '\n');
1154 assert(0 && "not implemented");
1155 return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_INT));
1156 }
1157
getScalarOrArray(Value * V,unsigned Size,Instruction * Pos)1158 Value *getScalarOrArray(Value *V, unsigned Size, Instruction *Pos) {
1159 if (!V->getType()->isPointerTy())
1160 return V;
1161 auto GEP = cast<GEPOperator>(V);
1162 assert(GEP->getNumOperands() == 3 && "must be a GEP from an array");
1163 assert(GEP->getSourceElementType()->getArrayNumElements() == Size);
1164 assert(dyn_cast<ConstantInt>(GEP->getOperand(1))->getZExtValue() == 0);
1165 assert(dyn_cast<ConstantInt>(GEP->getOperand(2))->getZExtValue() == 0);
1166 return new LoadInst(GEP->getSourceElementType(), GEP->getOperand(0), "", Pos);
1167 }
1168
getScalarOrVectorConstantInt(Type * T,uint64_t V,bool IsSigned)1169 Constant *getScalarOrVectorConstantInt(Type *T, uint64_t V, bool IsSigned) {
1170 if (auto IT = dyn_cast<IntegerType>(T))
1171 return ConstantInt::get(IT, V);
1172 if (auto VT = dyn_cast<FixedVectorType>(T)) {
1173 std::vector<Constant *> EV(
1174 VT->getNumElements(),
1175 getScalarOrVectorConstantInt(VT->getElementType(), V, IsSigned));
1176 return ConstantVector::get(EV);
1177 }
1178 llvm_unreachable("Invalid type");
1179 return nullptr;
1180 }
1181
getScalarOrArrayConstantInt(Instruction * Pos,Type * T,unsigned Len,uint64_t V,bool IsSigned)1182 Value *getScalarOrArrayConstantInt(Instruction *Pos, Type *T, unsigned Len,
1183 uint64_t V, bool IsSigned) {
1184 if (auto IT = dyn_cast<IntegerType>(T)) {
1185 assert(Len == 1 && "Invalid length");
1186 return ConstantInt::get(IT, V, IsSigned);
1187 }
1188 if (auto PT = dyn_cast<PointerType>(T)) {
1189 auto ET = PT->getPointerElementType();
1190 auto AT = ArrayType::get(ET, Len);
1191 std::vector<Constant *> EV(Len, ConstantInt::get(ET, V, IsSigned));
1192 auto CA = ConstantArray::get(AT, EV);
1193 auto Alloca = new AllocaInst(AT, 0, "", Pos);
1194 new StoreInst(CA, Alloca, Pos);
1195 auto Zero = ConstantInt::getNullValue(Type::getInt32Ty(T->getContext()));
1196 Value *Index[] = {Zero, Zero};
1197 auto *Ret = GetElementPtrInst::CreateInBounds(AT, Alloca, Index, "", Pos);
1198 LLVM_DEBUG(dbgs() << "[getScalarOrArrayConstantInt] Alloca: " << *Alloca
1199 << ", Return: " << *Ret << '\n');
1200 return Ret;
1201 }
1202 if (auto AT = dyn_cast<ArrayType>(T)) {
1203 auto ET = AT->getArrayElementType();
1204 assert(AT->getArrayNumElements() == Len);
1205 std::vector<Constant *> EV(Len, ConstantInt::get(ET, V, IsSigned));
1206 auto Ret = ConstantArray::get(AT, EV);
1207 LLVM_DEBUG(dbgs() << "[getScalarOrArrayConstantInt] Array type: " << *AT
1208 << ", Return: " << *Ret << '\n');
1209 return Ret;
1210 }
1211 llvm_unreachable("Invalid type");
1212 return nullptr;
1213 }
1214
dumpUsers(Value * V,StringRef Prompt)1215 void dumpUsers(Value *V, StringRef Prompt) {
1216 if (!V)
1217 return;
1218 LLVM_DEBUG(dbgs() << Prompt << " Users of " << *V << " :\n");
1219 for (auto UI = V->user_begin(), UE = V->user_end(); UI != UE; ++UI)
1220 LLVM_DEBUG(dbgs() << " " << **UI << '\n');
1221 }
1222
getSPIRVTypeName(StringRef BaseName,StringRef Postfixes)1223 std::string getSPIRVTypeName(StringRef BaseName, StringRef Postfixes) {
1224 assert(!BaseName.empty() && "Invalid SPIR-V type Name");
1225 auto TN = std::string(kSPIRVTypeName::PrefixAndDelim) + BaseName.str();
1226 if (Postfixes.empty())
1227 return TN;
1228 return TN + kSPIRVTypeName::Delimiter + Postfixes.str();
1229 }
1230
isSPIRVConstantName(StringRef TyName)1231 bool isSPIRVConstantName(StringRef TyName) {
1232 if (TyName == getSPIRVTypeName(kSPIRVTypeName::ConstantSampler) ||
1233 TyName == getSPIRVTypeName(kSPIRVTypeName::ConstantPipeStorage))
1234 return true;
1235
1236 return false;
1237 }
1238
getSPIRVTypeByChangeBaseTypeName(Module * M,Type * T,StringRef OldName,StringRef NewName)1239 Type *getSPIRVTypeByChangeBaseTypeName(Module *M, Type *T, StringRef OldName,
1240 StringRef NewName) {
1241 StringRef Postfixes;
1242 if (isSPIRVType(T, OldName, &Postfixes))
1243 return getOrCreateOpaquePtrType(M, getSPIRVTypeName(NewName, Postfixes));
1244 LLVM_DEBUG(dbgs() << " Invalid SPIR-V type " << *T << '\n');
1245 llvm_unreachable("Invalid SPIR-V type");
1246 return nullptr;
1247 }
1248
getSPIRVImageTypePostfixes(StringRef SampledType,SPIRVTypeImageDescriptor Desc,SPIRVAccessQualifierKind Acc)1249 std::string getSPIRVImageTypePostfixes(StringRef SampledType,
1250 SPIRVTypeImageDescriptor Desc,
1251 SPIRVAccessQualifierKind Acc) {
1252 std::string S;
1253 raw_string_ostream OS(S);
1254 OS << kSPIRVTypeName::PostfixDelim << SampledType
1255 << kSPIRVTypeName::PostfixDelim << Desc.Dim << kSPIRVTypeName::PostfixDelim
1256 << Desc.Depth << kSPIRVTypeName::PostfixDelim << Desc.Arrayed
1257 << kSPIRVTypeName::PostfixDelim << Desc.MS << kSPIRVTypeName::PostfixDelim
1258 << Desc.Sampled << kSPIRVTypeName::PostfixDelim << Desc.Format
1259 << kSPIRVTypeName::PostfixDelim << Acc;
1260 return OS.str();
1261 }
1262
getSPIRVImageSampledTypeName(SPIRVType * Ty)1263 std::string getSPIRVImageSampledTypeName(SPIRVType *Ty) {
1264 switch (Ty->getOpCode()) {
1265 case OpTypeVoid:
1266 return kSPIRVImageSampledTypeName::Void;
1267 case OpTypeInt:
1268 if (Ty->getIntegerBitWidth() == 32) {
1269 if (static_cast<SPIRVTypeInt *>(Ty)->isSigned())
1270 return kSPIRVImageSampledTypeName::Int;
1271 else
1272 return kSPIRVImageSampledTypeName::UInt;
1273 }
1274 break;
1275 case OpTypeFloat:
1276 switch (Ty->getFloatBitWidth()) {
1277 case 16:
1278 return kSPIRVImageSampledTypeName::Half;
1279 case 32:
1280 return kSPIRVImageSampledTypeName::Float;
1281 default:
1282 break;
1283 }
1284 break;
1285 default:
1286 break;
1287 }
1288 llvm_unreachable("Invalid sampled type for image");
1289 return std::string();
1290 }
1291
1292 // ToDo: Find a way to represent uint sampled type in LLVM, maybe an
1293 // opaque type.
getLLVMTypeForSPIRVImageSampledTypePostfix(StringRef Postfix,LLVMContext & Ctx)1294 Type *getLLVMTypeForSPIRVImageSampledTypePostfix(StringRef Postfix,
1295 LLVMContext &Ctx) {
1296 if (Postfix == kSPIRVImageSampledTypeName::Void)
1297 return Type::getVoidTy(Ctx);
1298 if (Postfix == kSPIRVImageSampledTypeName::Float)
1299 return Type::getFloatTy(Ctx);
1300 if (Postfix == kSPIRVImageSampledTypeName::Half)
1301 return Type::getHalfTy(Ctx);
1302 if (Postfix == kSPIRVImageSampledTypeName::Int ||
1303 Postfix == kSPIRVImageSampledTypeName::UInt)
1304 return Type::getInt32Ty(Ctx);
1305 llvm_unreachable("Invalid sampled type postfix");
1306 return nullptr;
1307 }
1308
getImageBaseTypeName(StringRef Name)1309 std::string getImageBaseTypeName(StringRef Name) {
1310
1311 SmallVector<StringRef, 4> SubStrs;
1312 const char Delims[] = {kSPR2TypeName::Delimiter, 0};
1313 Name.split(SubStrs, Delims);
1314 if (Name.startswith(kSPR2TypeName::OCLPrefix)) {
1315 Name = SubStrs[1];
1316 } else {
1317 Name = SubStrs[0];
1318 }
1319
1320 std::string ImageTyName{Name};
1321 if (hasAccessQualifiedName(Name))
1322 ImageTyName.erase(ImageTyName.size() - 5, 3);
1323
1324 return ImageTyName;
1325 }
1326
mapOCLTypeNameToSPIRV(StringRef Name,StringRef Acc)1327 std::string mapOCLTypeNameToSPIRV(StringRef Name, StringRef Acc) {
1328 std::string BaseTy;
1329 std::string Postfixes;
1330 raw_string_ostream OS(Postfixes);
1331 if (Name.startswith(kSPR2TypeName::ImagePrefix)) {
1332 std::string ImageTyName = getImageBaseTypeName(Name);
1333 auto Desc = map<SPIRVTypeImageDescriptor>(ImageTyName);
1334 LLVM_DEBUG(dbgs() << "[trans image type] " << Name << " => "
1335 << "(" << (unsigned)Desc.Dim << ", " << Desc.Depth << ", "
1336 << Desc.Arrayed << ", " << Desc.MS << ", " << Desc.Sampled
1337 << ", " << Desc.Format << ")\n");
1338
1339 BaseTy = kSPIRVTypeName::Image;
1340 OS << getSPIRVImageTypePostfixes(
1341 kSPIRVImageSampledTypeName::Void, Desc,
1342 SPIRSPIRVAccessQualifierMap::map(Acc.str()));
1343 } else {
1344 LLVM_DEBUG(dbgs() << "Mapping of " << Name << " is not implemented\n");
1345 llvm_unreachable("Not implemented");
1346 }
1347 return getSPIRVTypeName(BaseTy, OS.str());
1348 }
1349
eraseIfNoUse(Function * F)1350 bool eraseIfNoUse(Function *F) {
1351 bool Changed = false;
1352 if (!F)
1353 return Changed;
1354 if (!GlobalValue::isInternalLinkage(F->getLinkage()) && !F->isDeclaration())
1355 return Changed;
1356
1357 dumpUsers(F, "[eraseIfNoUse] ");
1358 for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
1359 auto U = *UI++;
1360 if (auto CE = dyn_cast<ConstantExpr>(U)) {
1361 if (CE->use_empty()) {
1362 CE->dropAllReferences();
1363 Changed = true;
1364 }
1365 }
1366 }
1367 if (F->use_empty()) {
1368 LLVM_DEBUG(dbgs() << "Erase "; F->printAsOperand(dbgs()); dbgs() << '\n');
1369 F->eraseFromParent();
1370 Changed = true;
1371 }
1372 return Changed;
1373 }
1374
eraseIfNoUse(Value * V)1375 void eraseIfNoUse(Value *V) {
1376 if (!V->use_empty())
1377 return;
1378 if (Constant *C = dyn_cast<Constant>(V)) {
1379 C->destroyConstant();
1380 return;
1381 }
1382 if (Instruction *I = dyn_cast<Instruction>(V)) {
1383 if (!I->mayHaveSideEffects())
1384 I->eraseFromParent();
1385 }
1386 eraseIfNoUse(dyn_cast<Function>(V));
1387 }
1388
eraseUselessFunctions(Module * M)1389 bool eraseUselessFunctions(Module *M) {
1390 bool Changed = false;
1391 for (auto I = M->begin(), E = M->end(); I != E;)
1392 Changed |= eraseIfNoUse(&(*I++));
1393 return Changed;
1394 }
1395
1396 // The mangling algorithm follows OpenCL pipe built-ins clang 3.8 CodeGen rules.
1397 static SPIR::MangleError
manglePipeOrAddressSpaceCastBuiltin(const SPIR::FunctionDescriptor & Fd,std::string & MangledName)1398 manglePipeOrAddressSpaceCastBuiltin(const SPIR::FunctionDescriptor &Fd,
1399 std::string &MangledName) {
1400 assert(OCLUtil::isPipeOrAddressSpaceCastBI(Fd.Name) &&
1401 "Method is expected to be called only for pipe and address space cast "
1402 "builtins!");
1403 if (Fd.isNull()) {
1404 MangledName.assign(SPIR::FunctionDescriptor::nullString());
1405 return SPIR::MANGLE_NULL_FUNC_DESCRIPTOR;
1406 }
1407 MangledName.assign("__" + Fd.Name);
1408 return SPIR::MANGLE_SUCCESS;
1409 }
1410
mangleBuiltin(StringRef UniqName,ArrayRef<Type * > ArgTypes,BuiltinFuncMangleInfo * BtnInfo)1411 std::string mangleBuiltin(StringRef UniqName, ArrayRef<Type *> ArgTypes,
1412 BuiltinFuncMangleInfo *BtnInfo) {
1413 if (!BtnInfo)
1414 return std::string(UniqName);
1415 BtnInfo->init(UniqName);
1416 std::string MangledName;
1417 LLVM_DEBUG(dbgs() << "[mangle] " << UniqName << " => ");
1418 SPIR::FunctionDescriptor FD;
1419 FD.Name = BtnInfo->getUnmangledName();
1420 bool BIVarArgNegative = BtnInfo->getVarArg() < 0;
1421
1422 if (ArgTypes.empty()) {
1423 // Function signature cannot be ()(void, ...) so if there is an ellipsis
1424 // it must be ()(...)
1425 if (BIVarArgNegative) {
1426 FD.Parameters.emplace_back(
1427 SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_VOID)));
1428 }
1429 } else {
1430 for (unsigned I = 0, E = BIVarArgNegative ? ArgTypes.size()
1431 : (unsigned)BtnInfo->getVarArg();
1432 I != E; ++I) {
1433 auto T = ArgTypes[I];
1434 FD.Parameters.emplace_back(
1435 transTypeDesc(T, BtnInfo->getTypeMangleInfo(I)));
1436 }
1437 }
1438 // Ellipsis must be the last argument of any function
1439 if (!BIVarArgNegative) {
1440 assert((unsigned)BtnInfo->getVarArg() <= ArgTypes.size() &&
1441 "invalid index of an ellipsis");
1442 FD.Parameters.emplace_back(
1443 SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_VAR_ARG)));
1444 }
1445
1446 #if defined(SPIRV_SPIR20_MANGLING_REQUIREMENTS)
1447 SPIR::NameMangler Mangler(SPIR::SPIR20);
1448 Mangler.mangle(FD, MangledName);
1449 #else
1450 if (OCLUtil::isPipeOrAddressSpaceCastBI(BtnInfo->getUnmangledName())) {
1451 manglePipeOrAddressSpaceCastBuiltin(FD, MangledName);
1452 } else {
1453 SPIR::NameMangler Mangler(SPIR::SPIR20);
1454 Mangler.mangle(FD, MangledName);
1455 }
1456 #endif
1457
1458 LLVM_DEBUG(dbgs() << MangledName << '\n');
1459 return MangledName;
1460 }
1461
1462 /// Check if access qualifier is encoded in the type Name.
hasAccessQualifiedName(StringRef TyName)1463 bool hasAccessQualifiedName(StringRef TyName) {
1464 if (TyName.size() < 5)
1465 return false;
1466 auto Acc = TyName.substr(TyName.size() - 5, 3);
1467 return llvm::StringSwitch<bool>(Acc)
1468 .Case(kAccessQualPostfix::ReadOnly, true)
1469 .Case(kAccessQualPostfix::WriteOnly, true)
1470 .Case(kAccessQualPostfix::ReadWrite, true)
1471 .Default(false);
1472 }
1473
getAccessQualifier(StringRef TyName)1474 SPIRVAccessQualifierKind getAccessQualifier(StringRef TyName) {
1475 return SPIRSPIRVAccessQualifierMap::map(
1476 getAccessQualifierFullName(TyName).str());
1477 }
1478
getAccessQualifierPostfix(SPIRVAccessQualifierKind Access)1479 StringRef getAccessQualifierPostfix(SPIRVAccessQualifierKind Access) {
1480 switch (Access) {
1481 case AccessQualifierReadOnly:
1482 return kAccessQualPostfix::ReadOnly;
1483 case AccessQualifierWriteOnly:
1484 return kAccessQualPostfix::WriteOnly;
1485 case AccessQualifierReadWrite:
1486 return kAccessQualPostfix::ReadWrite;
1487 default:
1488 assert(false && "Unrecognized access qualifier!");
1489 return kAccessQualPostfix::ReadWrite;
1490 }
1491 }
1492
1493 /// Get access qualifier from the type Name.
getAccessQualifierFullName(StringRef TyName)1494 StringRef getAccessQualifierFullName(StringRef TyName) {
1495 assert(hasAccessQualifiedName(TyName) &&
1496 "Type is not qualified with access.");
1497 auto Acc = TyName.substr(TyName.size() - 5, 3);
1498 return llvm::StringSwitch<StringRef>(Acc)
1499 .Case(kAccessQualPostfix::ReadOnly, kAccessQualName::ReadOnly)
1500 .Case(kAccessQualPostfix::WriteOnly, kAccessQualName::WriteOnly)
1501 .Case(kAccessQualPostfix::ReadWrite, kAccessQualName::ReadWrite);
1502 }
1503
1504 /// Translates OpenCL image type names to SPIR-V.
getSPIRVImageTypeFromOCL(Module * M,Type * ImageTy)1505 Type *getSPIRVImageTypeFromOCL(Module *M, Type *ImageTy) {
1506 assert(isOCLImageType(ImageTy) && "Unsupported type");
1507 auto ImageTypeName = ImageTy->getPointerElementType()->getStructName();
1508 StringRef Acc = kAccessQualName::ReadOnly;
1509 if (hasAccessQualifiedName(ImageTypeName))
1510 Acc = getAccessQualifierFullName(ImageTypeName);
1511 return getOrCreateOpaquePtrType(M, mapOCLTypeNameToSPIRV(ImageTypeName, Acc));
1512 }
1513
getOCLClkEventType(Module * M)1514 llvm::PointerType *getOCLClkEventType(Module *M) {
1515 return getOrCreateOpaquePtrType(M, SPIR_TYPE_NAME_CLK_EVENT_T,
1516 SPIRAS_Private);
1517 }
1518
getOCLClkEventPtrType(Module * M)1519 llvm::PointerType *getOCLClkEventPtrType(Module *M) {
1520 return PointerType::get(getOCLClkEventType(M), SPIRAS_Generic);
1521 }
1522
getOCLNullClkEventPtr(Module * M)1523 llvm::Constant *getOCLNullClkEventPtr(Module *M) {
1524 return Constant::getNullValue(getOCLClkEventPtrType(M));
1525 }
1526
hasLoopMetadata(const Module * M)1527 bool hasLoopMetadata(const Module *M) {
1528 for (const Function &F : *M)
1529 for (const BasicBlock &BB : F) {
1530 const Instruction *Term = BB.getTerminator();
1531 if (Term && Term->getMetadata("llvm.loop"))
1532 return true;
1533 }
1534 return false;
1535 }
1536
isSPIRVOCLExtInst(const CallInst * CI,OCLExtOpKind * ExtOp)1537 bool isSPIRVOCLExtInst(const CallInst *CI, OCLExtOpKind *ExtOp) {
1538 StringRef DemangledName;
1539 if (!oclIsBuiltin(CI->getCalledFunction()->getName(), DemangledName))
1540 return false;
1541 StringRef S = DemangledName;
1542 if (!S.startswith(kSPIRVName::Prefix))
1543 return false;
1544 S = S.drop_front(strlen(kSPIRVName::Prefix));
1545 auto Loc = S.find(kSPIRVPostfix::Divider);
1546 auto ExtSetName = S.substr(0, Loc);
1547 SPIRVExtInstSetKind Set = SPIRVEIS_Count;
1548 if (!SPIRVExtSetShortNameMap::rfind(ExtSetName.str(), &Set))
1549 return false;
1550
1551 if (Set != SPIRVEIS_OpenCL)
1552 return false;
1553
1554 auto ExtOpName = S.substr(Loc + 1);
1555 auto PostFixPos = ExtOpName.find("_R");
1556 ExtOpName = ExtOpName.substr(0, PostFixPos);
1557
1558 OCLExtOpKind EOC;
1559 if (!OCLExtOpMap::rfind(ExtOpName.str(), &EOC))
1560 return false;
1561
1562 *ExtOp = EOC;
1563 return true;
1564 }
1565
decodeSPIRVTypeName(StringRef Name,SmallVectorImpl<std::string> & Strs)1566 std::string decodeSPIRVTypeName(StringRef Name,
1567 SmallVectorImpl<std::string> &Strs) {
1568 SmallVector<StringRef, 4> SubStrs;
1569 const char Delim[] = {kSPIRVTypeName::Delimiter, 0};
1570 Name.split(SubStrs, Delim, -1, true);
1571 assert(SubStrs.size() >= 2 && "Invalid SPIRV type name");
1572 assert(SubStrs[0] == kSPIRVTypeName::Prefix && "Invalid prefix");
1573 assert((SubStrs.size() == 2 || !SubStrs[2].empty()) && "Invalid postfix");
1574
1575 if (SubStrs.size() > 2) {
1576 const char PostDelim[] = {kSPIRVTypeName::PostfixDelim, 0};
1577 SmallVector<StringRef, 4> Postfixes;
1578 SubStrs[2].split(Postfixes, PostDelim, -1, true);
1579 assert(Postfixes.size() > 1 && Postfixes[0].empty() && "Invalid postfix");
1580 for (unsigned I = 1, E = Postfixes.size(); I != E; ++I)
1581 Strs.push_back(std::string(Postfixes[I]).c_str());
1582 }
1583 return SubStrs[1].str();
1584 }
1585
1586 // Returns true if type(s) and number of elements (if vector) is valid
checkTypeForSPIRVExtendedInstLowering(IntrinsicInst * II,SPIRVModule * BM)1587 bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) {
1588 switch (II->getIntrinsicID()) {
1589 case Intrinsic::ceil:
1590 case Intrinsic::copysign:
1591 case Intrinsic::cos:
1592 case Intrinsic::exp:
1593 case Intrinsic::exp2:
1594 case Intrinsic::fabs:
1595 case Intrinsic::floor:
1596 case Intrinsic::fma:
1597 case Intrinsic::log:
1598 case Intrinsic::log10:
1599 case Intrinsic::log2:
1600 case Intrinsic::maximum:
1601 case Intrinsic::maxnum:
1602 case Intrinsic::minimum:
1603 case Intrinsic::minnum:
1604 case Intrinsic::nearbyint:
1605 case Intrinsic::pow:
1606 case Intrinsic::powi:
1607 case Intrinsic::rint:
1608 case Intrinsic::round:
1609 case Intrinsic::roundeven:
1610 case Intrinsic::sin:
1611 case Intrinsic::sqrt:
1612 case Intrinsic::trunc: {
1613 // Although some of the intrinsics above take multiple arguments, it is
1614 // sufficient to check arg 0 because the LLVM Verifier will have checked
1615 // that all floating point operands have the same type and the second
1616 // argument of powi is i32.
1617 Type *Ty = II->getType();
1618 if (II->getArgOperand(0)->getType() != Ty)
1619 return false;
1620 int NumElems = 1;
1621 if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) {
1622 NumElems = VecTy->getNumElements();
1623 Ty = VecTy->getElementType();
1624 }
1625 if ((!Ty->isFloatTy() && !Ty->isDoubleTy() && !Ty->isHalfTy()) ||
1626 ((NumElems > 4) && (NumElems != 8) && (NumElems != 16))) {
1627 BM->SPIRVCK(
1628 false, InvalidFunctionCall, II->getCalledOperand()->getName().str());
1629 return false;
1630 }
1631 break;
1632 }
1633 case Intrinsic::abs: {
1634 Type *Ty = II->getType();
1635 int NumElems = 1;
1636 if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) {
1637 NumElems = VecTy->getNumElements();
1638 Ty = VecTy->getElementType();
1639 }
1640 if ((!Ty->isIntegerTy()) ||
1641 ((NumElems > 4) && (NumElems != 8) && (NumElems != 16))) {
1642 BM->SPIRVCK(
1643 false, InvalidFunctionCall, II->getCalledOperand()->getName().str());
1644 }
1645 break;
1646 }
1647 default:
1648 break;
1649 }
1650 return true;
1651 }
1652
setAttrByCalledFunc(CallInst * Call)1653 void setAttrByCalledFunc(CallInst *Call) {
1654 Function *F = Call->getCalledFunction();
1655 assert(F);
1656 if (F->isIntrinsic()) {
1657 return;
1658 }
1659 Call->setCallingConv(F->getCallingConv());
1660 Call->setAttributes(F->getAttributes());
1661 }
1662
isSPIRVBuiltinVariable(GlobalVariable * GV,SPIRVBuiltinVariableKind * Kind)1663 bool isSPIRVBuiltinVariable(GlobalVariable *GV,
1664 SPIRVBuiltinVariableKind *Kind) {
1665 if (!GV->hasName() || !getSPIRVBuiltin(GV->getName().str(), *Kind))
1666 return false;
1667 return true;
1668 }
1669
// Variable like GlobalInvocationId[x] -> get_global_id(x).
1671 // Variable like WorkDim -> get_work_dim().
1672 // Replace the following pattern:
1673 // %a = addrspacecast i32 addrspace(1)* @__spirv_BuiltInSubgroupMaxSize to
1674 // i32 addrspace(4)*
1675 // %b = load i32, i32 addrspace(4)* %a, align 4
1676 // %c = load i32, i32 addrspace(4)* %a, align 4
1677 // With:
1678 // %b = call spir_func i32 @_Z22get_max_sub_group_sizev()
1679 // %c = call spir_func i32 @_Z22get_max_sub_group_sizev()
1680
1681 // And replace the following pattern:
1682 // %a = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupId to
1683 // <3 x i64> addrspace(4)*
1684 // %b = load <3 x i64>, <3 x i64> addrspace(4)* %a, align 32
1685 // %c = extractelement <3 x i64> %b, i32 idx
1686 // %d = extractelement <3 x i64> %b, i32 idx
1687 // With:
1688 // %0 = call spir_func i64 @_Z13get_global_idj(i32 0) #1
1689 // %1 = insertelement <3 x i64> undef, i64 %0, i32 0
1690 // %2 = call spir_func i64 @_Z13get_global_idj(i32 1) #1
1691 // %3 = insertelement <3 x i64> %1, i64 %2, i32 1
1692 // %4 = call spir_func i64 @_Z13get_global_idj(i32 2) #1
1693 // %5 = insertelement <3 x i64> %3, i64 %4, i32 2
1694 // %c = extractelement <3 x i64> %5, i32 idx
1695 // %d = extractelement <3 x i64> %5, i32 idx
1696 //
1697 // Replace the following pattern:
1698 // %0 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to
1699 // <3 x i64> addrspace(4)*
1700 // %1 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %0, i64 0, i64 0
1701 // %2 = load i64, i64 addrspace(4)* %1, align 32
1702 // With:
1703 // %0 = call spir_func i64 @_Z13get_global_idj(i32 0) #1
1704 // %1 = insertelement <3 x i64> undef, i64 %0, i32 0
1705 // %2 = call spir_func i64 @_Z13get_global_idj(i32 1) #1
1706 // %3 = insertelement <3 x i64> %1, i64 %2, i32 1
1707 // %4 = call spir_func i64 @_Z13get_global_idj(i32 2) #1
1708 // %5 = insertelement <3 x i64> %3, i64 %4, i32 2
1709 // %6 = extractelement <3 x i64> %5, i32 0
lowerBuiltinVariableToCall(GlobalVariable * GV,SPIRVBuiltinVariableKind Kind)1710 bool lowerBuiltinVariableToCall(GlobalVariable *GV,
1711 SPIRVBuiltinVariableKind Kind) {
1712 // There might be dead constant users of GV (for example, SPIRVLowerConstExpr
1713 // replaces ConstExpr uses but those ConstExprs are not deleted, since LLVM
1714 // constants are created on demand as needed and never deleted).
1715 // Remove them first!
1716 GV->removeDeadConstantUsers();
1717
1718 Module *M = GV->getParent();
1719 LLVMContext &C = M->getContext();
1720 std::string FuncName = GV->getName().str();
1721 Type *GVTy = GV->getType()->getPointerElementType();
1722 Type *ReturnTy = GVTy;
1723 // Some SPIR-V builtin variables are translated to a function with an index
1724 // argument.
1725 bool HasIndexArg =
1726 ReturnTy->isVectorTy() &&
1727 !(BuiltInSubgroupEqMask <= Kind && Kind <= BuiltInSubgroupLtMask);
1728 if (HasIndexArg)
1729 ReturnTy = cast<VectorType>(ReturnTy)->getElementType();
1730 std::vector<Type *> ArgTy;
1731 if (HasIndexArg)
1732 ArgTy.push_back(Type::getInt32Ty(C));
1733 std::string MangledName;
1734 mangleOpenClBuiltin(FuncName, ArgTy, MangledName);
1735 Function *Func = M->getFunction(MangledName);
1736 if (!Func) {
1737 FunctionType *FT = FunctionType::get(ReturnTy, ArgTy, false);
1738 Func = Function::Create(FT, GlobalValue::ExternalLinkage, MangledName, M);
1739 Func->setCallingConv(CallingConv::SPIR_FUNC);
1740 Func->addFnAttr(Attribute::NoUnwind);
1741 Func->addFnAttr(Attribute::ReadNone);
1742 Func->addFnAttr(Attribute::WillReturn);
1743 }
1744
1745 // Collect instructions in these containers to remove them later.
1746 std::vector<Instruction *> Loads;
1747 std::vector<Instruction *> Casts;
1748 std::vector<Instruction *> GEPs;
1749
1750 auto Replace = [&](std::vector<Value *> Arg, Instruction *I) {
1751 auto *Call = CallInst::Create(Func, Arg, "", I);
1752 Call->takeName(I);
1753 setAttrByCalledFunc(Call);
1754 SPIRVDBG(dbgs() << "[lowerBuiltinVariableToCall] " << *I << " -> " << *Call
1755 << '\n';)
1756 I->replaceAllUsesWith(Call);
1757 };
1758
1759 // If HasIndexArg is true, we create 3 built-in calls and insertelement to
1760 // get 3-element vector filled with ids and replace uses of Load instruction
1761 // with this vector.
1762 // If HasIndexArg is false, the result of the Load instruction is the value
1763 // which should be replaced with the Func.
1764 // Returns true if Load was replaced, false otherwise.
1765 auto ReplaceIfLoad = [&](User *I) {
1766 auto *LD = dyn_cast<LoadInst>(I);
1767 if (!LD)
1768 return false;
1769 std::vector<Value *> Vectors;
1770 Loads.push_back(LD);
1771 if (HasIndexArg) {
1772 auto *VecTy = cast<FixedVectorType>(GVTy);
1773 Value *EmptyVec = UndefValue::get(VecTy);
1774 Vectors.push_back(EmptyVec);
1775 const DebugLoc &DLoc = LD->getDebugLoc();
1776 for (unsigned I = 0; I < VecTy->getNumElements(); ++I) {
1777 auto *Idx = ConstantInt::get(Type::getInt32Ty(C), I);
1778 auto *Call = CallInst::Create(Func, {Idx}, "", LD);
1779 if (DLoc)
1780 Call->setDebugLoc(DLoc);
1781 setAttrByCalledFunc(Call);
1782 auto *Insert = InsertElementInst::Create(Vectors.back(), Call, Idx);
1783 if (DLoc)
1784 Insert->setDebugLoc(DLoc);
1785 Insert->insertAfter(Call);
1786 Vectors.push_back(Insert);
1787 }
1788
1789 Value *Ptr = LD->getPointerOperand();
1790
1791 if (isa<FixedVectorType>(Ptr->getType()->getPointerElementType())) {
1792 LD->replaceAllUsesWith(Vectors.back());
1793 } else {
1794 auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
1795 assert(GEP && "Unexpected pattern!");
1796 assert(GEP->getNumIndices() == 2 && "Unexpected pattern!");
1797 Value *Idx = GEP->getOperand(2);
1798 Value *Vec = Vectors.back();
1799 auto *NewExtract = ExtractElementInst::Create(Vec, Idx);
1800 NewExtract->insertAfter(cast<Instruction>(Vec));
1801 LD->replaceAllUsesWith(NewExtract);
1802 }
1803
1804 } else {
1805 Replace({}, LD);
1806 }
1807
1808 return true;
1809 };
1810
1811 // Go over the GV users, find Load and ExtractElement instructions and
1812 // replace them with the corresponding function call.
1813 for (auto *UI : GV->users()) {
1814 // There might or might not be an addrspacecast instruction.
1815 if (auto *ASCast = dyn_cast<AddrSpaceCastInst>(UI)) {
1816 Casts.push_back(ASCast);
1817 for (auto *CastUser : ASCast->users()) {
1818 if (ReplaceIfLoad(CastUser))
1819 continue;
1820 if (auto *GEP = dyn_cast<GetElementPtrInst>(CastUser)) {
1821 GEPs.push_back(GEP);
1822 for (auto *GEPUser : GEP->users()) {
1823 if (!ReplaceIfLoad(GEPUser))
1824 llvm_unreachable("Unexpected pattern!");
1825 }
1826 } else {
1827 llvm_unreachable("Unexpected pattern!");
1828 }
1829 }
1830 } else if (!ReplaceIfLoad(UI)) {
1831 llvm_unreachable("Unexpected pattern!");
1832 }
1833 }
1834
1835 auto Erase = [](std::vector<Instruction *> &ToErase) {
1836 for (Instruction *I : ToErase) {
1837 assert(I->hasNUses(0));
1838 I->eraseFromParent();
1839 }
1840 };
1841 // Order of erasing is important.
1842 Erase(Loads);
1843 Erase(GEPs);
1844 Erase(Casts);
1845
1846 return true;
1847 }
1848
lowerBuiltinVariablesToCalls(Module * M)1849 bool lowerBuiltinVariablesToCalls(Module *M) {
1850 std::vector<GlobalVariable *> WorkList;
1851 for (auto I = M->global_begin(), E = M->global_end(); I != E; ++I) {
1852 SPIRVBuiltinVariableKind Kind;
1853 if (!isSPIRVBuiltinVariable(&(*I), &Kind))
1854 continue;
1855 if (!lowerBuiltinVariableToCall(&(*I), Kind))
1856 return false;
1857 WorkList.push_back(&(*I));
1858 }
1859 for (auto &I : WorkList) {
1860 I->eraseFromParent();
1861 }
1862
1863 return true;
1864 }
1865
postProcessBuiltinReturningStruct(Function * F)1866 bool postProcessBuiltinReturningStruct(Function *F) {
1867 Module *M = F->getParent();
1868 LLVMContext *Context = &M->getContext();
1869 std::string Name = F->getName().str();
1870 F->setName(Name + ".old");
1871 SmallVector<Instruction *, 32> InstToRemove;
1872 for (auto *U : F->users()) {
1873 if (auto *CI = dyn_cast<CallInst>(U)) {
1874 auto *ST = cast<StoreInst>(*(CI->user_begin()));
1875 std::vector<Type *> ArgTys;
1876 getFunctionTypeParameterTypes(F->getFunctionType(), ArgTys);
1877 ArgTys.insert(ArgTys.begin(),
1878 PointerType::get(F->getReturnType(), SPIRAS_Private));
1879 auto *NewF =
1880 getOrCreateFunction(M, Type::getVoidTy(*Context), ArgTys, Name);
1881 NewF->addParamAttr(0, Attribute::get(*Context,
1882 Attribute::AttrKind::StructRet,
1883 F->getReturnType()));
1884 NewF->setCallingConv(F->getCallingConv());
1885 auto Args = getArguments(CI);
1886 Args.insert(Args.begin(), ST->getPointerOperand());
1887 auto *NewCI = CallInst::Create(NewF, Args, CI->getName(), CI);
1888 NewCI->setCallingConv(CI->getCallingConv());
1889 InstToRemove.push_back(ST);
1890 InstToRemove.push_back(CI);
1891 }
1892 }
1893 for (auto *Inst : InstToRemove) {
1894 Inst->dropAllReferences();
1895 Inst->eraseFromParent();
1896 }
1897 F->dropAllReferences();
1898 F->eraseFromParent();
1899 return true;
1900 }
1901
postProcessBuiltinWithArrayArguments(Function * F,StringRef DemangledName)1902 bool postProcessBuiltinWithArrayArguments(Function *F,
1903 StringRef DemangledName) {
1904 LLVM_DEBUG(dbgs() << "[postProcessOCLBuiltinWithArrayArguments] " << *F
1905 << '\n');
1906 auto Attrs = F->getAttributes();
1907 auto Name = F->getName();
1908 mutateFunction(
1909 F,
1910 [=](CallInst *CI, std::vector<Value *> &Args) {
1911 auto FBegin = CI->getFunction()->begin()->getFirstInsertionPt();
1912 for (auto &I : Args) {
1913 auto *T = I->getType();
1914 if (!T->isArrayTy())
1915 continue;
1916 auto *Alloca = new AllocaInst(T, 0, "", &(*FBegin));
1917 new StoreInst(I, Alloca, false, CI);
1918 auto *Zero =
1919 ConstantInt::getNullValue(Type::getInt32Ty(T->getContext()));
1920 Value *Index[] = {Zero, Zero};
1921 I = GetElementPtrInst::CreateInBounds(T, Alloca, Index, "", CI);
1922 }
1923 return Name.str();
1924 },
1925 nullptr, &Attrs);
1926 return true;
1927 }
1928
postProcessBuiltinsReturningStruct(Module * M,bool IsCpp)1929 bool postProcessBuiltinsReturningStruct(Module *M, bool IsCpp) {
1930 StringRef DemangledName;
1931 // postProcessBuiltinReturningStruct may remove some functions from the
1932 // module, so use make_early_inc_range
1933 for (auto &F : make_early_inc_range(M->functions())) {
1934 if (F.hasName() && F.isDeclaration()) {
1935 LLVM_DEBUG(dbgs() << "[postProcess sret] " << F << '\n');
1936 if (F.getReturnType()->isStructTy() &&
1937 oclIsBuiltin(F.getName(), DemangledName, IsCpp)) {
1938 if (!postProcessBuiltinReturningStruct(&F))
1939 return false;
1940 }
1941 }
1942 }
1943 return true;
1944 }
1945
postProcessBuiltinsWithArrayArguments(Module * M,bool IsCpp)1946 bool postProcessBuiltinsWithArrayArguments(Module *M, bool IsCpp) {
1947 StringRef DemangledName;
1948 // postProcessBuiltinWithArrayArguments may remove some functions from the
1949 // module, so use make_early_inc_range
1950 for (auto &F : make_early_inc_range(M->functions())) {
1951 if (F.hasName() && F.isDeclaration()) {
1952 LLVM_DEBUG(dbgs() << "[postProcess array arg] " << F << '\n');
1953 if (hasArrayArg(&F) && oclIsBuiltin(F.getName(), DemangledName, IsCpp))
1954 if (!postProcessBuiltinWithArrayArguments(&F, DemangledName))
1955 return false;
1956 }
1957 }
1958 return true;
1959 }
1960
1961 } // namespace SPIRV
1962
1963 namespace {
1964 class SPIRVFriendlyIRMangleInfo : public BuiltinFuncMangleInfo {
1965 public:
SPIRVFriendlyIRMangleInfo(spv::Op OC,ArrayRef<Type * > ArgTys)1966 SPIRVFriendlyIRMangleInfo(spv::Op OC, ArrayRef<Type *> ArgTys)
1967 : OC(OC), ArgTys(ArgTys) {}
1968
init(StringRef UniqUnmangledName)1969 void init(StringRef UniqUnmangledName) override {
1970 UnmangledName = UniqUnmangledName.str();
1971 switch (OC) {
1972 case OpConvertUToF:
1973 case OpUConvert:
1974 case OpSatConvertUToS:
1975 // Treat all arguments as unsigned
1976 addUnsignedArg(-1);
1977 break;
1978 case OpSubgroupShuffleINTEL:
1979 case OpSubgroupShuffleXorINTEL:
1980 addUnsignedArg(1);
1981 break;
1982 case OpSubgroupShuffleDownINTEL:
1983 case OpSubgroupShuffleUpINTEL:
1984 addUnsignedArg(2);
1985 break;
1986 case OpSubgroupBlockWriteINTEL:
1987 addUnsignedArg(0);
1988 addUnsignedArg(1);
1989 break;
1990 case OpSubgroupImageBlockWriteINTEL:
1991 addUnsignedArg(2);
1992 break;
1993 case OpSubgroupBlockReadINTEL:
1994 setArgAttr(0, SPIR::ATTR_CONST);
1995 addUnsignedArg(0);
1996 break;
1997 case OpAtomicUMax:
1998 case OpAtomicUMin:
1999 addUnsignedArg(0);
2000 addUnsignedArg(3);
2001 break;
2002 case OpGroupUMax:
2003 case OpGroupUMin:
2004 case OpGroupNonUniformBroadcast:
2005 case OpGroupNonUniformBallotBitCount:
2006 case OpGroupNonUniformShuffle:
2007 case OpGroupNonUniformShuffleXor:
2008 case OpGroupNonUniformShuffleUp:
2009 case OpGroupNonUniformShuffleDown:
2010 addUnsignedArg(2);
2011 break;
2012 case OpGroupNonUniformInverseBallot:
2013 case OpGroupNonUniformBallotFindLSB:
2014 case OpGroupNonUniformBallotFindMSB:
2015 addUnsignedArg(1);
2016 break;
2017 case OpGroupNonUniformBallotBitExtract:
2018 addUnsignedArg(1);
2019 addUnsignedArg(2);
2020 break;
2021 case OpGroupNonUniformIAdd:
2022 case OpGroupNonUniformFAdd:
2023 case OpGroupNonUniformIMul:
2024 case OpGroupNonUniformFMul:
2025 case OpGroupNonUniformSMin:
2026 case OpGroupNonUniformFMin:
2027 case OpGroupNonUniformSMax:
2028 case OpGroupNonUniformFMax:
2029 case OpGroupNonUniformBitwiseAnd:
2030 case OpGroupNonUniformBitwiseOr:
2031 case OpGroupNonUniformBitwiseXor:
2032 case OpGroupNonUniformLogicalAnd:
2033 case OpGroupNonUniformLogicalOr:
2034 case OpGroupNonUniformLogicalXor:
2035 addUnsignedArg(3);
2036 break;
2037 case OpGroupNonUniformUMax:
2038 case OpGroupNonUniformUMin:
2039 addUnsignedArg(2);
2040 addUnsignedArg(3);
2041 break;
2042 default:;
2043 // No special handling is needed
2044 }
2045 }
2046
2047 private:
2048 spv::Op OC;
2049 ArrayRef<Type *> ArgTys;
2050 };
2051 class OpenCLStdToSPIRVFriendlyIRMangleInfo : public BuiltinFuncMangleInfo {
2052 public:
OpenCLStdToSPIRVFriendlyIRMangleInfo(OCLExtOpKind ExtOpId,ArrayRef<Type * > ArgTys,Type * RetTy)2053 OpenCLStdToSPIRVFriendlyIRMangleInfo(OCLExtOpKind ExtOpId,
2054 ArrayRef<Type *> ArgTys, Type *RetTy)
2055 : ExtOpId(ExtOpId), ArgTys(ArgTys) {
2056
2057 std::string Postfix = "";
2058 if (needRetTypePostfix())
2059 Postfix = kSPIRVPostfix::Divider + getPostfixForReturnType(RetTy, true);
2060
2061 UnmangledName = getSPIRVExtFuncName(SPIRVEIS_OpenCL, ExtOpId, Postfix);
2062 }
2063
needRetTypePostfix()2064 bool needRetTypePostfix() {
2065 switch (ExtOpId) {
2066 case OpenCLLIB::Vload_half:
2067 case OpenCLLIB::Vload_halfn:
2068 case OpenCLLIB::Vloada_halfn:
2069 case OpenCLLIB::Vloadn:
2070 return true;
2071 default:
2072 return false;
2073 }
2074 }
2075
init(StringRef)2076 void init(StringRef) override {
2077 switch (ExtOpId) {
2078 case OpenCLLIB::UAbs:
2079 case OpenCLLIB::UAbs_diff:
2080 case OpenCLLIB::UAdd_sat:
2081 case OpenCLLIB::UHadd:
2082 case OpenCLLIB::URhadd:
2083 case OpenCLLIB::UClamp:
2084 case OpenCLLIB::UMad_hi:
2085 case OpenCLLIB::UMad_sat:
2086 case OpenCLLIB::UMax:
2087 case OpenCLLIB::UMin:
2088 case OpenCLLIB::UMul_hi:
2089 case OpenCLLIB::USub_sat:
2090 case OpenCLLIB::U_Upsample:
2091 case OpenCLLIB::UMad24:
2092 case OpenCLLIB::UMul24:
2093 // Treat all arguments as unsigned
2094 addUnsignedArg(-1);
2095 break;
2096 case OpenCLLIB::S_Upsample:
2097 addUnsignedArg(1);
2098 break;
2099 default:;
2100 // No special handling is needed
2101 }
2102 }
2103
2104 private:
2105 OCLExtOpKind ExtOpId;
2106 ArrayRef<Type *> ArgTys;
2107 };
2108 } // namespace
2109
2110 namespace SPIRV {
getSPIRVFriendlyIRFunctionName(OCLExtOpKind ExtOpId,ArrayRef<Type * > ArgTys,Type * RetTy)2111 std::string getSPIRVFriendlyIRFunctionName(OCLExtOpKind ExtOpId,
2112 ArrayRef<Type *> ArgTys,
2113 Type *RetTy) {
2114 OpenCLStdToSPIRVFriendlyIRMangleInfo MangleInfo(ExtOpId, ArgTys, RetTy);
2115 return mangleBuiltin(MangleInfo.getUnmangledName(), ArgTys, &MangleInfo);
2116 }
2117
getSPIRVFriendlyIRFunctionName(const std::string & UniqName,spv::Op OC,ArrayRef<Type * > ArgTys)2118 std::string getSPIRVFriendlyIRFunctionName(const std::string &UniqName,
2119 spv::Op OC,
2120 ArrayRef<Type *> ArgTys) {
2121 SPIRVFriendlyIRMangleInfo MangleInfo(OC, ArgTys);
2122 return mangleBuiltin(UniqName, ArgTys, &MangleInfo);
2123 }
2124
2125 } // namespace SPIRV
2126