1 /*========================== begin_copyright_notice ============================
2 
3 Copyright (C) 2021 Intel Corporation
4 
5 SPDX-License-Identifier: MIT
6 
7 ============================= end_copyright_notice ===========================*/
8 
9 //
10 /// GenXPrintfResolution
11 /// --------------------
/// This pass finds every call to the printf function and replaces it with a
/// series of calls to printf implementation functions from BiF. The proper
/// version of the implementation (32/64 bit, cm/ocl/ze binary) is provided by
/// outer logic.
/// Before:
/// %p = call spir_func i32 (i8 as(2)*, ...) @printf(i8 as(2)* %str)
/// After:
/// %init = call <4 x i32> @__vc_printf_init(<N x i32> <constant args info>)
/// %fmt = call <4 x i32> @__vc_printf_fmt(<4 x i32> %init, i8 as(2)* %str)
/// %printf = call i32 @__vc_printf_ret(<4 x i32> %fmt)
/// For every printf argument one @__vc_printf_arg (or @__vc_printf_arg_str for
/// "%s") call is chained between the fmt call and the ret call. The init call
/// receives a constant vector describing the arguments (total count, numbers
/// of 64-bit, pointer and string arguments, format string size).
///
/// Vector <4 x i32> is passed between the functions to transfer their internal
/// data. This data is handled by the function implementations themselves;
/// this pass knows nothing about it.
25 //===----------------------------------------------------------------------===//
26 
27 #include "vc/BiF/PrintfIface.h"
28 #include "vc/BiF/Tools.h"
29 #include "vc/GenXOpts/GenXOpts.h"
30 #include "vc/Support/BackendConfig.h"
31 #include "vc/Utils/GenX/Printf.h"
32 #include "vc/Utils/General/BiF.h"
33 #include "vc/Utils/General/Types.h"
34 
35 #include <llvm/ADT/STLExtras.h>
36 #include <llvm/ADT/iterator_range.h>
37 #include <llvm/IR/Constants.h>
38 #include <llvm/IR/DataLayout.h>
39 #include <llvm/IR/IRBuilder.h>
40 #include <llvm/IR/InstIterator.h>
41 #include <llvm/IR/Instructions.h>
42 #include <llvm/IR/Module.h>
43 #include <llvm/Linker/Linker.h>
44 #include <llvm/Pass.h>
45 #include <llvm/Support/ErrorHandling.h>
46 
47 #include "llvmWrapper/IR/DerivedTypes.h"
48 
49 #include <algorithm>
50 #include <functional>
51 #include <numeric>
52 #include <sstream>
53 #include <vector>
54 
55 using namespace llvm;
56 using namespace vc;
57 using namespace vc::bif::printf;
58 
namespace PrintfImplFunc {
// Identifiers of the printf implementation functions expected in the BiF
// module. Size is the number of functions and sizes the per-function tables.
enum Enum { Init, Fmt, FmtLegacy, Arg, ArgStr, ArgStrLegacy, Ret, Size };

// Symbol name of each implementation function, indexed by Enum.
static constexpr const char *Name[Size] = {
    /* Init         */ "__vc_printf_init",
    /* Fmt          */ "__vc_printf_fmt",
    /* FmtLegacy    */ "__vc_printf_fmt_legacy",
    /* Arg          */ "__vc_printf_arg",
    /* ArgStr       */ "__vc_printf_arg_str",
    /* ArgStrLegacy */ "__vc_printf_arg_str_legacy",
    /* Ret          */ "__vc_printf_ret",
};
} // namespace PrintfImplFunc
66 
67 static constexpr int FormatStringAddrSpace = vc::AddrSpace::Constant;
68 static constexpr int LegacyFormatStringAddrSpace = vc::AddrSpace::Private;
69 
70 namespace {
71 class GenXPrintfResolution final : public ModulePass {
72   const DataLayout *DL = nullptr;
73   std::array<FunctionCallee, PrintfImplFunc::Size> PrintfImplDecl;
74 
75 public:
76   static char ID;
GenXPrintfResolution()77   GenXPrintfResolution() : ModulePass(ID) {}
getPassName() const78   StringRef getPassName() const override { return "GenX printf resolution"; }
79   void getAnalysisUsage(AnalysisUsage &AU) const override;
80   bool runOnModule(Module &M) override;
81 
82 private:
83   std::unique_ptr<Module> getBiFModule(LLVMContext &Ctx);
84   void handlePrintfCall(CallInst &OrigPrintf);
85   void addPrintfImplDeclarations(Module &M);
86   void updatePrintfImplDeclarations(Module &M);
87   void preparePrintfImplForInlining();
88   CallInst &createPrintfInitCall(CallInst &OrigPrintf, int FmtStrSize,
89                                  const PrintfArgInfoSeq &ArgsInfo);
90   CallInst &createPrintfFmtCall(CallInst &OrigPrintf, CallInst &InitCall);
91   CallInst &createPrintfArgCall(CallInst &OrigPrintf, CallInst &PrevCall,
92                                 Value &Arg, PrintfArgInfo Info);
93   CallInst &createPrintfArgStrCall(CallInst &OrigPrintf, CallInst &PrevCall,
94                                    Value &Arg);
95   CallInst &createPrintfRetCall(CallInst &OrigPrintf, CallInst &PrevCall);
96 };
97 } // namespace
98 
// Legacy pass manager identifier: LLVM identifies the pass by the address of
// this variable, not by its value.
char GenXPrintfResolution::ID = 0;
// Forward declaration of the registration function whose definition is
// generated by the INITIALIZE_PASS_* macros below; it is called from
// createGenXPrintfResolutionPass.
namespace llvm {
void initializeGenXPrintfResolutionPass(PassRegistry &);
}

// Register the pass as "GenXPrintfResolution" with a dependency on
// GenXBackendConfig (also requested in getAnalysisUsage). The trailing
// `false, false` are the "CFG only" and "is analysis" flags.
INITIALIZE_PASS_BEGIN(GenXPrintfResolution, "GenXPrintfResolution",
                      "GenXPrintfResolution", false, false)
INITIALIZE_PASS_DEPENDENCY(GenXBackendConfig)
INITIALIZE_PASS_END(GenXPrintfResolution, "GenXPrintfResolution",
                    "GenXPrintfResolution", false, false)

// Factory: ensures the pass is registered in the global registry and returns
// a newly allocated instance. Ownership goes to the caller (typically a
// legacy PassManager).
ModulePass *llvm::createGenXPrintfResolutionPass() {
  initializeGenXPrintfResolutionPass(*PassRegistry::getPassRegistry());
  return new GenXPrintfResolution;
}
114 
getAnalysisUsage(AnalysisUsage & AU) const115 void GenXPrintfResolution::getAnalysisUsage(AnalysisUsage &AU) const {
116   AU.addRequired<GenXBackendConfig>();
117 }
118 
119 using CallInstRef = std::reference_wrapper<CallInst>;
120 
isPrintfCall(const CallInst & CI)121 static bool isPrintfCall(const CallInst &CI) {
122   auto *CalledFunc = CI.getCalledFunction();
123   if (!CalledFunc)
124     return false;
125   if (!CalledFunc->isDeclaration())
126     return false;
127   return CalledFunc->getName() == "printf" ||
128     CalledFunc->getName().contains("__spirv_ocl_printf");
129 }
130 
isPrintfCall(const Instruction & Inst)131 static bool isPrintfCall(const Instruction &Inst) {
132   if (!isa<CallInst>(Inst))
133     return false;
134   return isPrintfCall(cast<CallInst>(Inst));
135 }
136 
collectWorkload(Module & M)137 static std::vector<CallInstRef> collectWorkload(Module &M) {
138   std::vector<CallInstRef> Workload;
139   for (Function &F : M)
140     llvm::transform(
141         make_filter_range(instructions(F),
142                           [](Instruction &Inst) { return isPrintfCall(Inst); }),
143         std::back_inserter(Workload),
144         [](Instruction &Inst) { return std::ref(cast<CallInst>(Inst)); });
145   return Workload;
146 }
147 
runOnModule(Module & M)148 bool GenXPrintfResolution::runOnModule(Module &M) {
149   DL = &M.getDataLayout();
150 
151   std::vector<CallInstRef> Workload = collectWorkload(M);
152   if (Workload.empty())
153     return false;
154   addPrintfImplDeclarations(M);
155   for (CallInst &CI : Workload)
156     handlePrintfCall(CI);
157 
158   std::unique_ptr<Module> PrintfImplModule = getBiFModule(M.getContext());
159   PrintfImplModule->setDataLayout(M.getDataLayout());
160   PrintfImplModule->setTargetTriple(M.getTargetTriple());
161   if (Linker::linkModules(M, std::move(PrintfImplModule),
162                           Linker::Flags::LinkOnlyNeeded)) {
163     IGC_ASSERT_MESSAGE(0, "Error linking printf implementation builtin module");
164   }
165   updatePrintfImplDeclarations(M);
166   preparePrintfImplForInlining();
167   return true;
168 }
169 
getBiFModule(LLVMContext & Ctx)170 std::unique_ptr<Module> GenXPrintfResolution::getBiFModule(LLVMContext &Ctx) {
171   MemoryBufferRef PrintfBiFModuleBuffer =
172       getAnalysis<GenXBackendConfig>().getBiFModule(BiFKind::VCPrintf);
173   if (!PrintfBiFModuleBuffer.getBufferSize()) {
174     IGC_ASSERT_MESSAGE(
175         vc::bif::disabled(),
176         "printf bif module can be empty only if vc bif was disabled");
177     report_fatal_error("printf implementation module is absent");
178   }
179   return vc::getBiFModuleOrReportError(PrintfBiFModuleBuffer, Ctx);
180 }
181 
assertPrintfCall(const CallInst & CI)182 static void assertPrintfCall(const CallInst &CI) {
183   IGC_ASSERT_MESSAGE(isPrintfCall(CI), "printf call is expected");
184   IGC_ASSERT_MESSAGE(CI.arg_size() > 0,
185                      "printf call must have at least format string argument");
186   (void)CI;
187 }
188 
189 // Returns pair of format string size (including '\0') and argument information.
190 static std::pair<int, PrintfArgInfoSeq>
analyzeFormatString(const Value & FmtStrOp)191 analyzeFormatString(const Value &FmtStrOp) {
192   auto FmtStr = getConstStringFromOperandOptional(FmtStrOp);
193   if (!FmtStr)
194     report_fatal_error(
195         "printf resolution cannot access format string during compile time");
196   return {FmtStr.getValue().size() + 1, parseFormatString(FmtStr.getValue())};
197 }
198 
199 // Marks strings passed as "%s" arguments in printf.
200 // Recursive function, long instruction chains aren't expected.
markStringArgument(Value & Arg)201 static void markStringArgument(Value &Arg) {
202   if (isa<GEPOperator>(Arg)) {
203     auto *String = getConstStringGVFromOperandOptional(Arg);
204     if (!String)
205       report_fatal_error(PrintfStringAccessError);
206     String->addAttribute(PrintfStringVariable);
207     return;
208   }
209   if (isa<SelectInst>(Arg)) {
210     auto &SI = cast<SelectInst>(Arg);
211     // The same value can be potentially accessed by different paths. Though
212     // it is probably OK, since the same string can be marked several times
213     // and the most of the time cases would be simple so it is not that
214     // critical to pass same values several times in some rare complicated
215     // cases.
216     markStringArgument(*SI.getFalseValue());
217     markStringArgument(*SI.getTrueValue());
218     return;
219   }
220   // Only direct use and selection between strings is supported.
221   report_fatal_error(PrintfStringAccessError);
222 }
223 
224 // Marks printf strings: format strings, strings passed as "%s" arguments.
markPrintfStrings(CallInst & OrigPrintf,const PrintfArgInfoSeq & ArgsInfo)225 static void markPrintfStrings(CallInst &OrigPrintf,
226                               const PrintfArgInfoSeq &ArgsInfo) {
227   auto &FormatString =
228       getConstStringGVFromOperand(*OrigPrintf.getArgOperand(0));
229   FormatString.addAttribute(PrintfStringVariable);
230 
231   // Handle string arguments (%s).
232   auto StringArgs = make_filter_range(
233       zip(drop_begin(OrigPrintf.args(), 1), ArgsInfo), [](auto &&ArgWithInfo) {
234         return std::get<const PrintfArgInfo &>(ArgWithInfo).Type ==
235                PrintfArgInfo::String;
236       });
237   for (auto &&[Arg, ArgInfo] : StringArgs)
238     markStringArgument(*Arg.get());
239 }
240 
handlePrintfCall(CallInst & OrigPrintf)241 void GenXPrintfResolution::handlePrintfCall(CallInst &OrigPrintf) {
242   assertPrintfCall(OrigPrintf);
243   auto [FmtStrSize, ArgsInfo] =
244       analyzeFormatString(*OrigPrintf.getArgOperand(0));
245   if (ArgsInfo.size() != OrigPrintf.getNumArgOperands() - 1)
246     report_fatal_error("printf format string and arguments don't correspond");
247 
248   markPrintfStrings(OrigPrintf, ArgsInfo);
249 
250   auto &InitCall = createPrintfInitCall(OrigPrintf, FmtStrSize, ArgsInfo);
251   auto &FmtCall = createPrintfFmtCall(OrigPrintf, InitCall);
252 
253   // FIXME: combine LLVM call args type and format string info in more
254   // intelligent way.
255   auto ArgsWithInfo = zip(ArgsInfo, drop_begin(OrigPrintf.args(), 1));
256   // potentially FmtCall as there may be no arguments
257   auto &LastArgCall = *std::accumulate(
258       ArgsWithInfo.begin(), ArgsWithInfo.end(), &FmtCall,
259       [&OrigPrintf, this](CallInst *PrevCall, auto &&ArgWithInfo) {
260         return &createPrintfArgCall(OrigPrintf, *PrevCall,
261                                     *std::get<Use &>(ArgWithInfo).get(),
262                                     std::get<PrintfArgInfo &>(ArgWithInfo));
263       });
264   auto &RetCall = createPrintfRetCall(OrigPrintf, LastArgCall);
265   RetCall.takeName(&OrigPrintf);
266   OrigPrintf.replaceAllUsesWith(&RetCall);
267   OrigPrintf.eraseFromParent();
268 }
269 
270 using PrintfImplTypeStorage = std::array<FunctionType *, PrintfImplFunc::Size>;
271 
getPrintfImplTypes(LLVMContext & Ctx)272 static PrintfImplTypeStorage getPrintfImplTypes(LLVMContext &Ctx) {
273   auto *TransferDataTy =
274       IGCLLVM::FixedVectorType::get(Type::getInt32Ty(Ctx), TransferDataSize);
275   auto *ArgsInfoTy =
276       IGCLLVM::FixedVectorType::get(Type::getInt32Ty(Ctx), ArgsInfoVector::Size);
277   auto *ArgDataTy = IGCLLVM::FixedVectorType::get(Type::getInt32Ty(Ctx), ArgData::Size);
278   constexpr bool IsVarArg = false;
279 
280   PrintfImplTypeStorage FuncTys;
281   FuncTys[PrintfImplFunc::Init] =
282       FunctionType::get(TransferDataTy, ArgsInfoTy, IsVarArg);
283   FuncTys[PrintfImplFunc::Fmt] = FunctionType::get(
284       TransferDataTy,
285       {TransferDataTy,
286        PointerType::get(Type::getInt8Ty(Ctx), FormatStringAddrSpace)},
287       IsVarArg);
288   FuncTys[PrintfImplFunc::FmtLegacy] = FunctionType::get(
289       TransferDataTy,
290       {TransferDataTy,
291        PointerType::get(Type::getInt8Ty(Ctx), LegacyFormatStringAddrSpace)},
292       IsVarArg);
293   FuncTys[PrintfImplFunc::Arg] = FunctionType::get(
294       TransferDataTy, {TransferDataTy, Type::getInt32Ty(Ctx), ArgDataTy},
295       IsVarArg);
296   FuncTys[PrintfImplFunc::ArgStr] = FuncTys[PrintfImplFunc::Fmt];
297   FuncTys[PrintfImplFunc::ArgStrLegacy] = FuncTys[PrintfImplFunc::FmtLegacy];
298   FuncTys[PrintfImplFunc::Ret] =
299       FunctionType::get(Type::getInt32Ty(Ctx), TransferDataTy, IsVarArg);
300   return FuncTys;
301 }
302 
addPrintfImplDeclarations(Module & M)303 void GenXPrintfResolution::addPrintfImplDeclarations(Module &M) {
304   auto PrintfImplTy = getPrintfImplTypes(M.getContext());
305 
306   for (int FuncID = 0; FuncID != PrintfImplFunc::Size; ++FuncID)
307     PrintfImplDecl[FuncID] = M.getOrInsertFunction(PrintfImplFunc::Name[FuncID],
308                                                    PrintfImplTy[FuncID]);
309 }
310 
311 // The function must be internal and have always inline attribute for
312 // always-inline pass to inline it and remove the original function body
313 // (the both are critical for GenXPrintfLegalization to work correctly).
preparePrintfImplForInlining()314 void GenXPrintfResolution::preparePrintfImplForInlining() {
315   for (auto Callee : PrintfImplDecl) {
316     auto *Func = cast<Function>(Callee.getCallee());
317     Func->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
318     Func->addFnAttr(Attribute::AlwaysInline);
319   }
320 }
321 
updatePrintfImplDeclarations(Module & M)322 void GenXPrintfResolution::updatePrintfImplDeclarations(Module &M) {
323   std::transform(
324       std::begin(PrintfImplFunc::Name), std::end(PrintfImplFunc::Name),
325       PrintfImplDecl.begin(),
326       [&M](const char *Name) -> FunctionCallee { return M.getFunction(Name); });
327 }
328 
329 using ArgsInfoStorage = std::array<unsigned, ArgsInfoVector::Size>;
330 
331 // Returns arguments information required by init implementation function.
332 // FIXME: combine LLVM call args type and format string info before this
333 // function.
collectArgsInfo(CallInst & OrigPrintf,int FmtStrSize,const PrintfArgInfoSeq & FmtArgsInfo)334 static ArgsInfoStorage collectArgsInfo(CallInst &OrigPrintf, int FmtStrSize,
335                                        const PrintfArgInfoSeq &FmtArgsInfo) {
336   assertPrintfCall(OrigPrintf);
337 
338   ArgsInfoStorage ArgsInfo;
339   // It's not about format string.
340   ArgsInfo[ArgsInfoVector::NumTotal] = OrigPrintf.arg_size() - 1;
341   auto PrintfArgs =
342       make_range(std::next(OrigPrintf.arg_begin()), OrigPrintf.arg_end());
343 
344   ArgsInfo[ArgsInfoVector::Num64Bit] =
345       llvm::count_if(PrintfArgs, [](Value *Arg) {
346         return Arg->getType()->getPrimitiveSizeInBits() == 64;
347       });
348   ArgsInfo[ArgsInfoVector::NumPtr] =
349       llvm::count_if(FmtArgsInfo, [](PrintfArgInfo Info) {
350         return Info.Type == PrintfArgInfo::Pointer;
351       });
352   ArgsInfo[ArgsInfoVector::NumStr] =
353       llvm::count_if(FmtArgsInfo, [](PrintfArgInfo Info) {
354         return Info.Type == PrintfArgInfo::String;
355       });
356   ArgsInfo[ArgsInfoVector::FormatStrSize] = FmtStrSize;
357   return ArgsInfo;
358 }
359 
createPrintfInitCall(CallInst & OrigPrintf,int FmtStrSize,const PrintfArgInfoSeq & FmtArgsInfo)360 CallInst &GenXPrintfResolution::createPrintfInitCall(
361     CallInst &OrigPrintf, int FmtStrSize, const PrintfArgInfoSeq &FmtArgsInfo) {
362   assertPrintfCall(OrigPrintf);
363   auto ImplArgsInfo = collectArgsInfo(OrigPrintf, FmtStrSize, FmtArgsInfo);
364 
365   IRBuilder<> IRB{&OrigPrintf};
366   auto *ArgsInfoV =
367       ConstantDataVector::get(OrigPrintf.getContext(), ImplArgsInfo);
368   return *IRB.CreateCall(PrintfImplDecl[PrintfImplFunc::Init], ArgsInfoV,
369                          OrigPrintf.getName() + ".printf.init");
370 }
371 
createPrintfFmtCall(CallInst & OrigPrintf,CallInst & InitCall)372 CallInst &GenXPrintfResolution::createPrintfFmtCall(CallInst &OrigPrintf,
373                                                     CallInst &InitCall) {
374   assertPrintfCall(OrigPrintf);
375   IRBuilder<> IRB{&OrigPrintf};
376   auto FmtAS =
377       cast<PointerType>(OrigPrintf.getOperand(0)->getType())->getAddressSpace();
378   if (FmtAS == FormatStringAddrSpace)
379     return *IRB.CreateCall(PrintfImplDecl[PrintfImplFunc::Fmt],
380                            {&InitCall, OrigPrintf.getOperand(0)},
381                            OrigPrintf.getName() + ".printf.fmt");
382   IGC_ASSERT_MESSAGE(FmtAS == LegacyFormatStringAddrSpace,
383                      "unexpected address space for format string");
384   return *IRB.CreateCall(PrintfImplDecl[PrintfImplFunc::FmtLegacy],
385                          {&InitCall, OrigPrintf.getOperand(0)},
386                          OrigPrintf.getName() + ".printf.fmt");
387 }
388 
getIntegerArgKind(Type & ArgTy)389 static ArgKind::Enum getIntegerArgKind(Type &ArgTy) {
390   IGC_ASSERT_MESSAGE(ArgTy.isIntegerTy(),
391                      "wrong argument: integer type was expected");
392   auto BitWidth = ArgTy.getIntegerBitWidth();
393   switch (BitWidth) {
394   case 64:
395     return ArgKind::Long;
396   case 32:
397     return ArgKind::Int;
398   case 16:
399     return ArgKind::Short;
400   default:
401     IGC_ASSERT_MESSAGE(BitWidth == 8, "unexpected integer type");
402     return ArgKind::Char;
403   }
404 }
405 
getFloatingPointArgKind(Type & ArgTy)406 static ArgKind::Enum getFloatingPointArgKind(Type &ArgTy) {
407   IGC_ASSERT_MESSAGE(ArgTy.isFloatingPointTy(),
408                      "wrong argument: floating point type was expected");
409   if (ArgTy.isDoubleTy())
410     return ArgKind::Double;
411   // FIXME: what about half?
412   IGC_ASSERT_MESSAGE(ArgTy.isFloatTy(), "unexpected floating point type");
413   return ArgKind::Float;
414 }
415 
getPointerArgKind(Type & ArgTy,PrintfArgInfo Info)416 static ArgKind::Enum getPointerArgKind(Type &ArgTy, PrintfArgInfo Info) {
417   IGC_ASSERT_MESSAGE(ArgTy.isPointerTy(),
418                      "wrong argument: pointer type was expected");
419   IGC_ASSERT_MESSAGE(Info.Type == PrintfArgInfo::Pointer ||
420                          Info.Type == PrintfArgInfo::String,
421                      "only %s and %p should correspond to pointer argument");
422   (void)ArgTy;
423   if (Info.Type == PrintfArgInfo::String)
424     return ArgKind::String;
425   return ArgKind::Pointer;
426 }
427 
getArgKind(Type & ArgTy,PrintfArgInfo Info)428 static ArgKind::Enum getArgKind(Type &ArgTy, PrintfArgInfo Info) {
429   if (ArgTy.isIntegerTy())
430     return getIntegerArgKind(ArgTy);
431   if (ArgTy.isFloatingPointTy())
432     return getFloatingPointArgKind(ArgTy);
433   return getPointerArgKind(ArgTy, Info);
434 }
435 
436 // sizeof(<2 x i32>) == 64
437 static constexpr unsigned VecArgSize = 64;
438 static constexpr auto VecArgElementSize = VecArgSize / ArgData::Size;
439 
440 // Casts Arg to <2 x i32> vector. For pointers ptrtoint i64 should be generated
441 // first.
get64BitArgAsVector(Value & Arg,IRBuilder<> & IRB,const DataLayout & DL)442 Value &get64BitArgAsVector(Value &Arg, IRBuilder<> &IRB, const DataLayout &DL) {
443   IGC_ASSERT_MESSAGE(DL.getTypeSizeInBits(Arg.getType()) == 64,
444                      "64-bit argument was expected");
445   auto *VecArgTy =
446       IGCLLVM::FixedVectorType::get(IRB.getInt32Ty(), ArgData::Size);
447   Value *ArgToBitCast = &Arg;
448   if (Arg.getType()->isPointerTy())
449     ArgToBitCast =
450         IRB.CreatePtrToInt(&Arg, IRB.getInt64Ty(), Arg.getName() + ".arg.p2i");
451   return *IRB.CreateBitCast(ArgToBitCast, VecArgTy, Arg.getName() + ".arg.bc");
452 }
453 
454 // Just creates this instruction:
455 // insertelement <2 x i32> zeroinitializer, i32 %arg, i32 0
456 // \p Arg must be i32 type.
get32BitIntArgAsVector(Value & Arg,IRBuilder<> & IRB,const DataLayout & DL)457 Value &get32BitIntArgAsVector(Value &Arg, IRBuilder<> &IRB,
458                               const DataLayout &DL) {
459   IGC_ASSERT_MESSAGE(Arg.getType()->isIntegerTy(32),
460                      "i32 argument was expected");
461   auto *VecArgTy =
462       IGCLLVM::FixedVectorType::get(IRB.getInt32Ty(), ArgData::Size);
463   auto *BlankVec = ConstantAggregateZero::get(VecArgTy);
464   return *IRB.CreateInsertElement(BlankVec, &Arg, IRB.getInt32(0),
465                                   Arg.getName() + ".arg.insert");
466 }
467 
468 // Takes arg that is not greater than 32 bit and casts it to i32 with possible
469 // zero extension.
getArgAs32BitInt(Value & Arg,IRBuilder<> & IRB,const DataLayout & DL)470 static Value &getArgAs32BitInt(Value &Arg, IRBuilder<> &IRB,
471                                const DataLayout &DL) {
472   auto ArgSize = DL.getTypeSizeInBits(Arg.getType());
473   IGC_ASSERT_MESSAGE(ArgSize <= VecArgElementSize,
474                      "argument isn't expected to be greater than 32 bit");
475   if (ArgSize < VecArgElementSize) {
476     // FIXME: seems like there may be some problems with signed types, depending
477     // on our BiF and runtime implementation.
478     // FIXME: What about half?
479     IGC_ASSERT_MESSAGE(Arg.getType()->isIntegerTy(),
480                        "only integers are expected to be less than 32 bits");
481     return *IRB.CreateZExt(&Arg, IRB.getInt32Ty(), Arg.getName() + ".arg.zext");
482   }
483   if (Arg.getType()->isPointerTy())
484     return *IRB.CreatePtrToInt(&Arg, IRB.getInt32Ty(),
485                                Arg.getName() + ".arg.p2i");
486   if (!Arg.getType()->isIntegerTy())
487     return *IRB.CreateBitCast(&Arg, IRB.getInt32Ty(),
488                               Arg.getName() + ".arg.bc");
489   return Arg;
490 }
491 
492 // Args are passed via <2 x i32> vector. This function casts \p Arg to this
493 // vector type. \p Arg is zext if necessary (zext in common sense - writing
494 // top element of a vector with zeros is zero extending too).
getArgAsVector(Value & Arg,IRBuilder<> & IRB,const DataLayout & DL)495 static Value &getArgAsVector(Value &Arg, IRBuilder<> &IRB,
496                              const DataLayout &DL) {
497   IGC_ASSERT_MESSAGE(!isa<IGCLLVM::FixedVectorType>(Arg.getType()),
498                      "scalar type is expected");
499   auto ArgSize = DL.getTypeSizeInBits(Arg.getType());
500 
501   if (ArgSize == VecArgSize)
502     return get64BitArgAsVector(Arg, IRB, DL);
503   IGC_ASSERT_MESSAGE(ArgSize < VecArgSize,
504                      "arg is expected to be not greater than 64 bit");
505   Value &Arg32Bit = getArgAs32BitInt(Arg, IRB, DL);
506   return get32BitIntArgAsVector(Arg32Bit, IRB, DL);
507 }
508 
509 // Create call to printf argument handler implementation for string argument
510 // (%s). Strings require a separate implementation.
createPrintfArgStrCall(CallInst & OrigPrintf,CallInst & PrevCall,Value & Arg)511 CallInst &GenXPrintfResolution::createPrintfArgStrCall(CallInst &OrigPrintf,
512                                                        CallInst &PrevCall,
513                                                        Value &Arg) {
514   assertPrintfCall(OrigPrintf);
515   IRBuilder<> IRB{&OrigPrintf};
516   auto StrAS = cast<PointerType>(Arg.getType())->getAddressSpace();
517   if (StrAS == FormatStringAddrSpace)
518     return *IRB.CreateCall(PrintfImplDecl[PrintfImplFunc::ArgStr],
519                            {&PrevCall, &Arg},
520                            OrigPrintf.getName() + ".printf.arg");
521   IGC_ASSERT_MESSAGE(StrAS == LegacyFormatStringAddrSpace,
522                      "unexpected address space for a string argument");
523   return *IRB.CreateCall(PrintfImplDecl[PrintfImplFunc::ArgStrLegacy],
524                          {&PrevCall, &Arg},
525                          OrigPrintf.getName() + ".printf.arg");
526 }
527 
createPrintfArgCall(CallInst & OrigPrintf,CallInst & PrevCall,Value & Arg,PrintfArgInfo Info)528 CallInst &GenXPrintfResolution::createPrintfArgCall(CallInst &OrigPrintf,
529                                                     CallInst &PrevCall,
530                                                     Value &Arg,
531                                                     PrintfArgInfo Info) {
532   assertPrintfCall(OrigPrintf);
533   ArgKind::Enum Kind = getArgKind(*Arg.getType(), Info);
534   IRBuilder<> IRB{&OrigPrintf};
535   if (Kind == ArgKind::String)
536     return createPrintfArgStrCall(OrigPrintf, PrevCall, Arg);
537   Value &ArgVec = getArgAsVector(Arg, IRB, *DL);
538   return *IRB.CreateCall(PrintfImplDecl[PrintfImplFunc::Arg],
539                          {&PrevCall, IRB.getInt32(Kind), &ArgVec},
540                          OrigPrintf.getName() + ".printf.arg");
541 }
542 
createPrintfRetCall(CallInst & OrigPrintf,CallInst & PrevCall)543 CallInst &GenXPrintfResolution::createPrintfRetCall(CallInst &OrigPrintf,
544                                                     CallInst &PrevCall) {
545   assertPrintfCall(OrigPrintf);
546   IRBuilder<> IRB{&OrigPrintf};
547   return *IRB.CreateCall(PrintfImplDecl[PrintfImplFunc::Ret], &PrevCall);
548 }
549