1 /*========================== begin_copyright_notice ============================
2
3 Copyright (C) 2021 Intel Corporation
4
5 SPDX-License-Identifier: MIT
6
7 ============================= end_copyright_notice ===========================*/
8
9 //
10 /// GenXPrintfResolution
11 /// --------------------
12 /// This pass finds every call to printf function and replaces it with a series
13 /// of printf implementation functions from BiF. A proper version of
14 /// implementation (32/64 bit, cm/ocl/ze binary) is provided by outer logic.
15 /// Before:
16 /// %p = call spir_func i32 (i8 as(2)*, ...) @printf(i8 as(2)* %str)
17 /// After:
18 /// %init = call <4 x i32> @__vc_printf_init(<4 x i32> zeroinitializer)
19 /// %fmt = call <4 x i32> @__vc_printf_fmt(<4 x i32> %init, i8 as(2)* %str)
20 /// %printf = call i32 @__vc_printf_ret(<4 x i32> %fmt)
21 ///
22 /// Vector <4 x i32> is passed between functions to transfer their internal
23 /// data. This data is handled by the function implementations themselves,
24 /// this pass knows nothing about it.
25 //===----------------------------------------------------------------------===//
26
27 #include "vc/BiF/PrintfIface.h"
28 #include "vc/BiF/Tools.h"
29 #include "vc/GenXOpts/GenXOpts.h"
30 #include "vc/Support/BackendConfig.h"
31 #include "vc/Utils/GenX/Printf.h"
32 #include "vc/Utils/General/BiF.h"
33 #include "vc/Utils/General/Types.h"
34
35 #include <llvm/ADT/STLExtras.h>
36 #include <llvm/ADT/iterator_range.h>
37 #include <llvm/IR/Constants.h>
38 #include <llvm/IR/DataLayout.h>
39 #include <llvm/IR/IRBuilder.h>
40 #include <llvm/IR/InstIterator.h>
41 #include <llvm/IR/Instructions.h>
42 #include <llvm/IR/Module.h>
43 #include <llvm/Linker/Linker.h>
44 #include <llvm/Pass.h>
45 #include <llvm/Support/ErrorHandling.h>
46
47 #include "llvmWrapper/IR/DerivedTypes.h"
48
#include <algorithm>
#include <array>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <sstream>
#include <utility>
#include <vector>
54
55 using namespace llvm;
56 using namespace vc;
57 using namespace vc::bif::printf;
58
// Identifiers of the printf implementation functions that are linked in from
// the VC printf BiF module. Size is the number of functions, not a function.
namespace PrintfImplFunc {
enum Enum { Init, Fmt, FmtLegacy, Arg, ArgStr, ArgStrLegacy, Ret, Size };

// Symbol name of every implementation function, indexed by Enum.
static constexpr const char *Name[Size] = {
    "__vc_printf_init",           // Init
    "__vc_printf_fmt",            // Fmt
    "__vc_printf_fmt_legacy",     // FmtLegacy
    "__vc_printf_arg",            // Arg
    "__vc_printf_arg_str",        // ArgStr
    "__vc_printf_arg_str_legacy", // ArgStrLegacy
    "__vc_printf_ret",            // Ret
};
} // namespace PrintfImplFunc
66
67 static constexpr int FormatStringAddrSpace = vc::AddrSpace::Constant;
68 static constexpr int LegacyFormatStringAddrSpace = vc::AddrSpace::Private;
69
70 namespace {
71 class GenXPrintfResolution final : public ModulePass {
72 const DataLayout *DL = nullptr;
73 std::array<FunctionCallee, PrintfImplFunc::Size> PrintfImplDecl;
74
75 public:
76 static char ID;
GenXPrintfResolution()77 GenXPrintfResolution() : ModulePass(ID) {}
getPassName() const78 StringRef getPassName() const override { return "GenX printf resolution"; }
79 void getAnalysisUsage(AnalysisUsage &AU) const override;
80 bool runOnModule(Module &M) override;
81
82 private:
83 std::unique_ptr<Module> getBiFModule(LLVMContext &Ctx);
84 void handlePrintfCall(CallInst &OrigPrintf);
85 void addPrintfImplDeclarations(Module &M);
86 void updatePrintfImplDeclarations(Module &M);
87 void preparePrintfImplForInlining();
88 CallInst &createPrintfInitCall(CallInst &OrigPrintf, int FmtStrSize,
89 const PrintfArgInfoSeq &ArgsInfo);
90 CallInst &createPrintfFmtCall(CallInst &OrigPrintf, CallInst &InitCall);
91 CallInst &createPrintfArgCall(CallInst &OrigPrintf, CallInst &PrevCall,
92 Value &Arg, PrintfArgInfo Info);
93 CallInst &createPrintfArgStrCall(CallInst &OrigPrintf, CallInst &PrevCall,
94 Value &Arg);
95 CallInst &createPrintfRetCall(CallInst &OrigPrintf, CallInst &PrevCall);
96 };
97 } // namespace
98
// Pass identification for the legacy pass manager, which identifies a pass by
// the address of its ID member rather than by its value.
char GenXPrintfResolution::ID = 0;
namespace llvm {
// Declaration of the initializer whose definition is produced by the
// INITIALIZE_PASS_BEGIN/END macros below.
void initializeGenXPrintfResolutionPass(PassRegistry &);
}

// Register the pass. GenXBackendConfig is declared as a dependency because
// getBiFModule queries it for the VC printf BiF module.
INITIALIZE_PASS_BEGIN(GenXPrintfResolution, "GenXPrintfResolution",
                      "GenXPrintfResolution", false, false)
INITIALIZE_PASS_DEPENDENCY(GenXBackendConfig)
INITIALIZE_PASS_END(GenXPrintfResolution, "GenXPrintfResolution",
                    "GenXPrintfResolution", false, false)
109
110 ModulePass *llvm::createGenXPrintfResolutionPass() {
111 initializeGenXPrintfResolutionPass(*PassRegistry::getPassRegistry());
112 return new GenXPrintfResolution;
113 }
114
getAnalysisUsage(AnalysisUsage & AU) const115 void GenXPrintfResolution::getAnalysisUsage(AnalysisUsage &AU) const {
116 AU.addRequired<GenXBackendConfig>();
117 }
118
119 using CallInstRef = std::reference_wrapper<CallInst>;
120
isPrintfCall(const CallInst & CI)121 static bool isPrintfCall(const CallInst &CI) {
122 auto *CalledFunc = CI.getCalledFunction();
123 if (!CalledFunc)
124 return false;
125 if (!CalledFunc->isDeclaration())
126 return false;
127 return CalledFunc->getName() == "printf" ||
128 CalledFunc->getName().contains("__spirv_ocl_printf");
129 }
130
isPrintfCall(const Instruction & Inst)131 static bool isPrintfCall(const Instruction &Inst) {
132 if (!isa<CallInst>(Inst))
133 return false;
134 return isPrintfCall(cast<CallInst>(Inst));
135 }
136
collectWorkload(Module & M)137 static std::vector<CallInstRef> collectWorkload(Module &M) {
138 std::vector<CallInstRef> Workload;
139 for (Function &F : M)
140 llvm::transform(
141 make_filter_range(instructions(F),
142 [](Instruction &Inst) { return isPrintfCall(Inst); }),
143 std::back_inserter(Workload),
144 [](Instruction &Inst) { return std::ref(cast<CallInst>(Inst)); });
145 return Workload;
146 }
147
runOnModule(Module & M)148 bool GenXPrintfResolution::runOnModule(Module &M) {
149 DL = &M.getDataLayout();
150
151 std::vector<CallInstRef> Workload = collectWorkload(M);
152 if (Workload.empty())
153 return false;
154 addPrintfImplDeclarations(M);
155 for (CallInst &CI : Workload)
156 handlePrintfCall(CI);
157
158 std::unique_ptr<Module> PrintfImplModule = getBiFModule(M.getContext());
159 PrintfImplModule->setDataLayout(M.getDataLayout());
160 PrintfImplModule->setTargetTriple(M.getTargetTriple());
161 if (Linker::linkModules(M, std::move(PrintfImplModule),
162 Linker::Flags::LinkOnlyNeeded)) {
163 IGC_ASSERT_MESSAGE(0, "Error linking printf implementation builtin module");
164 }
165 updatePrintfImplDeclarations(M);
166 preparePrintfImplForInlining();
167 return true;
168 }
169
getBiFModule(LLVMContext & Ctx)170 std::unique_ptr<Module> GenXPrintfResolution::getBiFModule(LLVMContext &Ctx) {
171 MemoryBufferRef PrintfBiFModuleBuffer =
172 getAnalysis<GenXBackendConfig>().getBiFModule(BiFKind::VCPrintf);
173 if (!PrintfBiFModuleBuffer.getBufferSize()) {
174 IGC_ASSERT_MESSAGE(
175 vc::bif::disabled(),
176 "printf bif module can be empty only if vc bif was disabled");
177 report_fatal_error("printf implementation module is absent");
178 }
179 return vc::getBiFModuleOrReportError(PrintfBiFModuleBuffer, Ctx);
180 }
181
assertPrintfCall(const CallInst & CI)182 static void assertPrintfCall(const CallInst &CI) {
183 IGC_ASSERT_MESSAGE(isPrintfCall(CI), "printf call is expected");
184 IGC_ASSERT_MESSAGE(CI.arg_size() > 0,
185 "printf call must have at least format string argument");
186 (void)CI;
187 }
188
189 // Returns pair of format string size (including '\0') and argument information.
190 static std::pair<int, PrintfArgInfoSeq>
analyzeFormatString(const Value & FmtStrOp)191 analyzeFormatString(const Value &FmtStrOp) {
192 auto FmtStr = getConstStringFromOperandOptional(FmtStrOp);
193 if (!FmtStr)
194 report_fatal_error(
195 "printf resolution cannot access format string during compile time");
196 return {FmtStr.getValue().size() + 1, parseFormatString(FmtStr.getValue())};
197 }
198
199 // Marks strings passed as "%s" arguments in printf.
200 // Recursive function, long instruction chains aren't expected.
markStringArgument(Value & Arg)201 static void markStringArgument(Value &Arg) {
202 if (isa<GEPOperator>(Arg)) {
203 auto *String = getConstStringGVFromOperandOptional(Arg);
204 if (!String)
205 report_fatal_error(PrintfStringAccessError);
206 String->addAttribute(PrintfStringVariable);
207 return;
208 }
209 if (isa<SelectInst>(Arg)) {
210 auto &SI = cast<SelectInst>(Arg);
211 // The same value can be potentially accessed by different paths. Though
212 // it is probably OK, since the same string can be marked several times
213 // and the most of the time cases would be simple so it is not that
214 // critical to pass same values several times in some rare complicated
215 // cases.
216 markStringArgument(*SI.getFalseValue());
217 markStringArgument(*SI.getTrueValue());
218 return;
219 }
220 // Only direct use and selection between strings is supported.
221 report_fatal_error(PrintfStringAccessError);
222 }
223
224 // Marks printf strings: format strings, strings passed as "%s" arguments.
markPrintfStrings(CallInst & OrigPrintf,const PrintfArgInfoSeq & ArgsInfo)225 static void markPrintfStrings(CallInst &OrigPrintf,
226 const PrintfArgInfoSeq &ArgsInfo) {
227 auto &FormatString =
228 getConstStringGVFromOperand(*OrigPrintf.getArgOperand(0));
229 FormatString.addAttribute(PrintfStringVariable);
230
231 // Handle string arguments (%s).
232 auto StringArgs = make_filter_range(
233 zip(drop_begin(OrigPrintf.args(), 1), ArgsInfo), [](auto &&ArgWithInfo) {
234 return std::get<const PrintfArgInfo &>(ArgWithInfo).Type ==
235 PrintfArgInfo::String;
236 });
237 for (auto &&[Arg, ArgInfo] : StringArgs)
238 markStringArgument(*Arg.get());
239 }
240
handlePrintfCall(CallInst & OrigPrintf)241 void GenXPrintfResolution::handlePrintfCall(CallInst &OrigPrintf) {
242 assertPrintfCall(OrigPrintf);
243 auto [FmtStrSize, ArgsInfo] =
244 analyzeFormatString(*OrigPrintf.getArgOperand(0));
245 if (ArgsInfo.size() != OrigPrintf.getNumArgOperands() - 1)
246 report_fatal_error("printf format string and arguments don't correspond");
247
248 markPrintfStrings(OrigPrintf, ArgsInfo);
249
250 auto &InitCall = createPrintfInitCall(OrigPrintf, FmtStrSize, ArgsInfo);
251 auto &FmtCall = createPrintfFmtCall(OrigPrintf, InitCall);
252
253 // FIXME: combine LLVM call args type and format string info in more
254 // intelligent way.
255 auto ArgsWithInfo = zip(ArgsInfo, drop_begin(OrigPrintf.args(), 1));
256 // potentially FmtCall as there may be no arguments
257 auto &LastArgCall = *std::accumulate(
258 ArgsWithInfo.begin(), ArgsWithInfo.end(), &FmtCall,
259 [&OrigPrintf, this](CallInst *PrevCall, auto &&ArgWithInfo) {
260 return &createPrintfArgCall(OrigPrintf, *PrevCall,
261 *std::get<Use &>(ArgWithInfo).get(),
262 std::get<PrintfArgInfo &>(ArgWithInfo));
263 });
264 auto &RetCall = createPrintfRetCall(OrigPrintf, LastArgCall);
265 RetCall.takeName(&OrigPrintf);
266 OrigPrintf.replaceAllUsesWith(&RetCall);
267 OrigPrintf.eraseFromParent();
268 }
269
270 using PrintfImplTypeStorage = std::array<FunctionType *, PrintfImplFunc::Size>;
271
getPrintfImplTypes(LLVMContext & Ctx)272 static PrintfImplTypeStorage getPrintfImplTypes(LLVMContext &Ctx) {
273 auto *TransferDataTy =
274 IGCLLVM::FixedVectorType::get(Type::getInt32Ty(Ctx), TransferDataSize);
275 auto *ArgsInfoTy =
276 IGCLLVM::FixedVectorType::get(Type::getInt32Ty(Ctx), ArgsInfoVector::Size);
277 auto *ArgDataTy = IGCLLVM::FixedVectorType::get(Type::getInt32Ty(Ctx), ArgData::Size);
278 constexpr bool IsVarArg = false;
279
280 PrintfImplTypeStorage FuncTys;
281 FuncTys[PrintfImplFunc::Init] =
282 FunctionType::get(TransferDataTy, ArgsInfoTy, IsVarArg);
283 FuncTys[PrintfImplFunc::Fmt] = FunctionType::get(
284 TransferDataTy,
285 {TransferDataTy,
286 PointerType::get(Type::getInt8Ty(Ctx), FormatStringAddrSpace)},
287 IsVarArg);
288 FuncTys[PrintfImplFunc::FmtLegacy] = FunctionType::get(
289 TransferDataTy,
290 {TransferDataTy,
291 PointerType::get(Type::getInt8Ty(Ctx), LegacyFormatStringAddrSpace)},
292 IsVarArg);
293 FuncTys[PrintfImplFunc::Arg] = FunctionType::get(
294 TransferDataTy, {TransferDataTy, Type::getInt32Ty(Ctx), ArgDataTy},
295 IsVarArg);
296 FuncTys[PrintfImplFunc::ArgStr] = FuncTys[PrintfImplFunc::Fmt];
297 FuncTys[PrintfImplFunc::ArgStrLegacy] = FuncTys[PrintfImplFunc::FmtLegacy];
298 FuncTys[PrintfImplFunc::Ret] =
299 FunctionType::get(Type::getInt32Ty(Ctx), TransferDataTy, IsVarArg);
300 return FuncTys;
301 }
302
addPrintfImplDeclarations(Module & M)303 void GenXPrintfResolution::addPrintfImplDeclarations(Module &M) {
304 auto PrintfImplTy = getPrintfImplTypes(M.getContext());
305
306 for (int FuncID = 0; FuncID != PrintfImplFunc::Size; ++FuncID)
307 PrintfImplDecl[FuncID] = M.getOrInsertFunction(PrintfImplFunc::Name[FuncID],
308 PrintfImplTy[FuncID]);
309 }
310
311 // The function must be internal and have always inline attribute for
312 // always-inline pass to inline it and remove the original function body
313 // (the both are critical for GenXPrintfLegalization to work correctly).
preparePrintfImplForInlining()314 void GenXPrintfResolution::preparePrintfImplForInlining() {
315 for (auto Callee : PrintfImplDecl) {
316 auto *Func = cast<Function>(Callee.getCallee());
317 Func->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
318 Func->addFnAttr(Attribute::AlwaysInline);
319 }
320 }
321
updatePrintfImplDeclarations(Module & M)322 void GenXPrintfResolution::updatePrintfImplDeclarations(Module &M) {
323 std::transform(
324 std::begin(PrintfImplFunc::Name), std::end(PrintfImplFunc::Name),
325 PrintfImplDecl.begin(),
326 [&M](const char *Name) -> FunctionCallee { return M.getFunction(Name); });
327 }
328
329 using ArgsInfoStorage = std::array<unsigned, ArgsInfoVector::Size>;
330
331 // Returns arguments information required by init implementation function.
332 // FIXME: combine LLVM call args type and format string info before this
333 // function.
collectArgsInfo(CallInst & OrigPrintf,int FmtStrSize,const PrintfArgInfoSeq & FmtArgsInfo)334 static ArgsInfoStorage collectArgsInfo(CallInst &OrigPrintf, int FmtStrSize,
335 const PrintfArgInfoSeq &FmtArgsInfo) {
336 assertPrintfCall(OrigPrintf);
337
338 ArgsInfoStorage ArgsInfo;
339 // It's not about format string.
340 ArgsInfo[ArgsInfoVector::NumTotal] = OrigPrintf.arg_size() - 1;
341 auto PrintfArgs =
342 make_range(std::next(OrigPrintf.arg_begin()), OrigPrintf.arg_end());
343
344 ArgsInfo[ArgsInfoVector::Num64Bit] =
345 llvm::count_if(PrintfArgs, [](Value *Arg) {
346 return Arg->getType()->getPrimitiveSizeInBits() == 64;
347 });
348 ArgsInfo[ArgsInfoVector::NumPtr] =
349 llvm::count_if(FmtArgsInfo, [](PrintfArgInfo Info) {
350 return Info.Type == PrintfArgInfo::Pointer;
351 });
352 ArgsInfo[ArgsInfoVector::NumStr] =
353 llvm::count_if(FmtArgsInfo, [](PrintfArgInfo Info) {
354 return Info.Type == PrintfArgInfo::String;
355 });
356 ArgsInfo[ArgsInfoVector::FormatStrSize] = FmtStrSize;
357 return ArgsInfo;
358 }
359
createPrintfInitCall(CallInst & OrigPrintf,int FmtStrSize,const PrintfArgInfoSeq & FmtArgsInfo)360 CallInst &GenXPrintfResolution::createPrintfInitCall(
361 CallInst &OrigPrintf, int FmtStrSize, const PrintfArgInfoSeq &FmtArgsInfo) {
362 assertPrintfCall(OrigPrintf);
363 auto ImplArgsInfo = collectArgsInfo(OrigPrintf, FmtStrSize, FmtArgsInfo);
364
365 IRBuilder<> IRB{&OrigPrintf};
366 auto *ArgsInfoV =
367 ConstantDataVector::get(OrigPrintf.getContext(), ImplArgsInfo);
368 return *IRB.CreateCall(PrintfImplDecl[PrintfImplFunc::Init], ArgsInfoV,
369 OrigPrintf.getName() + ".printf.init");
370 }
371
createPrintfFmtCall(CallInst & OrigPrintf,CallInst & InitCall)372 CallInst &GenXPrintfResolution::createPrintfFmtCall(CallInst &OrigPrintf,
373 CallInst &InitCall) {
374 assertPrintfCall(OrigPrintf);
375 IRBuilder<> IRB{&OrigPrintf};
376 auto FmtAS =
377 cast<PointerType>(OrigPrintf.getOperand(0)->getType())->getAddressSpace();
378 if (FmtAS == FormatStringAddrSpace)
379 return *IRB.CreateCall(PrintfImplDecl[PrintfImplFunc::Fmt],
380 {&InitCall, OrigPrintf.getOperand(0)},
381 OrigPrintf.getName() + ".printf.fmt");
382 IGC_ASSERT_MESSAGE(FmtAS == LegacyFormatStringAddrSpace,
383 "unexpected address space for format string");
384 return *IRB.CreateCall(PrintfImplDecl[PrintfImplFunc::FmtLegacy],
385 {&InitCall, OrigPrintf.getOperand(0)},
386 OrigPrintf.getName() + ".printf.fmt");
387 }
388
getIntegerArgKind(Type & ArgTy)389 static ArgKind::Enum getIntegerArgKind(Type &ArgTy) {
390 IGC_ASSERT_MESSAGE(ArgTy.isIntegerTy(),
391 "wrong argument: integer type was expected");
392 auto BitWidth = ArgTy.getIntegerBitWidth();
393 switch (BitWidth) {
394 case 64:
395 return ArgKind::Long;
396 case 32:
397 return ArgKind::Int;
398 case 16:
399 return ArgKind::Short;
400 default:
401 IGC_ASSERT_MESSAGE(BitWidth == 8, "unexpected integer type");
402 return ArgKind::Char;
403 }
404 }
405
getFloatingPointArgKind(Type & ArgTy)406 static ArgKind::Enum getFloatingPointArgKind(Type &ArgTy) {
407 IGC_ASSERT_MESSAGE(ArgTy.isFloatingPointTy(),
408 "wrong argument: floating point type was expected");
409 if (ArgTy.isDoubleTy())
410 return ArgKind::Double;
411 // FIXME: what about half?
412 IGC_ASSERT_MESSAGE(ArgTy.isFloatTy(), "unexpected floating point type");
413 return ArgKind::Float;
414 }
415
getPointerArgKind(Type & ArgTy,PrintfArgInfo Info)416 static ArgKind::Enum getPointerArgKind(Type &ArgTy, PrintfArgInfo Info) {
417 IGC_ASSERT_MESSAGE(ArgTy.isPointerTy(),
418 "wrong argument: pointer type was expected");
419 IGC_ASSERT_MESSAGE(Info.Type == PrintfArgInfo::Pointer ||
420 Info.Type == PrintfArgInfo::String,
421 "only %s and %p should correspond to pointer argument");
422 (void)ArgTy;
423 if (Info.Type == PrintfArgInfo::String)
424 return ArgKind::String;
425 return ArgKind::Pointer;
426 }
427
getArgKind(Type & ArgTy,PrintfArgInfo Info)428 static ArgKind::Enum getArgKind(Type &ArgTy, PrintfArgInfo Info) {
429 if (ArgTy.isIntegerTy())
430 return getIntegerArgKind(ArgTy);
431 if (ArgTy.isFloatingPointTy())
432 return getFloatingPointArgKind(ArgTy);
433 return getPointerArgKind(ArgTy, Info);
434 }
435
436 // sizeof(<2 x i32>) == 64
437 static constexpr unsigned VecArgSize = 64;
438 static constexpr auto VecArgElementSize = VecArgSize / ArgData::Size;
439
440 // Casts Arg to <2 x i32> vector. For pointers ptrtoint i64 should be generated
441 // first.
get64BitArgAsVector(Value & Arg,IRBuilder<> & IRB,const DataLayout & DL)442 Value &get64BitArgAsVector(Value &Arg, IRBuilder<> &IRB, const DataLayout &DL) {
443 IGC_ASSERT_MESSAGE(DL.getTypeSizeInBits(Arg.getType()) == 64,
444 "64-bit argument was expected");
445 auto *VecArgTy =
446 IGCLLVM::FixedVectorType::get(IRB.getInt32Ty(), ArgData::Size);
447 Value *ArgToBitCast = &Arg;
448 if (Arg.getType()->isPointerTy())
449 ArgToBitCast =
450 IRB.CreatePtrToInt(&Arg, IRB.getInt64Ty(), Arg.getName() + ".arg.p2i");
451 return *IRB.CreateBitCast(ArgToBitCast, VecArgTy, Arg.getName() + ".arg.bc");
452 }
453
454 // Just creates this instruction:
455 // insertelement <2 x i32> zeroinitializer, i32 %arg, i32 0
456 // \p Arg must be i32 type.
get32BitIntArgAsVector(Value & Arg,IRBuilder<> & IRB,const DataLayout & DL)457 Value &get32BitIntArgAsVector(Value &Arg, IRBuilder<> &IRB,
458 const DataLayout &DL) {
459 IGC_ASSERT_MESSAGE(Arg.getType()->isIntegerTy(32),
460 "i32 argument was expected");
461 auto *VecArgTy =
462 IGCLLVM::FixedVectorType::get(IRB.getInt32Ty(), ArgData::Size);
463 auto *BlankVec = ConstantAggregateZero::get(VecArgTy);
464 return *IRB.CreateInsertElement(BlankVec, &Arg, IRB.getInt32(0),
465 Arg.getName() + ".arg.insert");
466 }
467
468 // Takes arg that is not greater than 32 bit and casts it to i32 with possible
469 // zero extension.
getArgAs32BitInt(Value & Arg,IRBuilder<> & IRB,const DataLayout & DL)470 static Value &getArgAs32BitInt(Value &Arg, IRBuilder<> &IRB,
471 const DataLayout &DL) {
472 auto ArgSize = DL.getTypeSizeInBits(Arg.getType());
473 IGC_ASSERT_MESSAGE(ArgSize <= VecArgElementSize,
474 "argument isn't expected to be greater than 32 bit");
475 if (ArgSize < VecArgElementSize) {
476 // FIXME: seems like there may be some problems with signed types, depending
477 // on our BiF and runtime implementation.
478 // FIXME: What about half?
479 IGC_ASSERT_MESSAGE(Arg.getType()->isIntegerTy(),
480 "only integers are expected to be less than 32 bits");
481 return *IRB.CreateZExt(&Arg, IRB.getInt32Ty(), Arg.getName() + ".arg.zext");
482 }
483 if (Arg.getType()->isPointerTy())
484 return *IRB.CreatePtrToInt(&Arg, IRB.getInt32Ty(),
485 Arg.getName() + ".arg.p2i");
486 if (!Arg.getType()->isIntegerTy())
487 return *IRB.CreateBitCast(&Arg, IRB.getInt32Ty(),
488 Arg.getName() + ".arg.bc");
489 return Arg;
490 }
491
492 // Args are passed via <2 x i32> vector. This function casts \p Arg to this
493 // vector type. \p Arg is zext if necessary (zext in common sense - writing
494 // top element of a vector with zeros is zero extending too).
getArgAsVector(Value & Arg,IRBuilder<> & IRB,const DataLayout & DL)495 static Value &getArgAsVector(Value &Arg, IRBuilder<> &IRB,
496 const DataLayout &DL) {
497 IGC_ASSERT_MESSAGE(!isa<IGCLLVM::FixedVectorType>(Arg.getType()),
498 "scalar type is expected");
499 auto ArgSize = DL.getTypeSizeInBits(Arg.getType());
500
501 if (ArgSize == VecArgSize)
502 return get64BitArgAsVector(Arg, IRB, DL);
503 IGC_ASSERT_MESSAGE(ArgSize < VecArgSize,
504 "arg is expected to be not greater than 64 bit");
505 Value &Arg32Bit = getArgAs32BitInt(Arg, IRB, DL);
506 return get32BitIntArgAsVector(Arg32Bit, IRB, DL);
507 }
508
509 // Create call to printf argument handler implementation for string argument
510 // (%s). Strings require a separate implementation.
createPrintfArgStrCall(CallInst & OrigPrintf,CallInst & PrevCall,Value & Arg)511 CallInst &GenXPrintfResolution::createPrintfArgStrCall(CallInst &OrigPrintf,
512 CallInst &PrevCall,
513 Value &Arg) {
514 assertPrintfCall(OrigPrintf);
515 IRBuilder<> IRB{&OrigPrintf};
516 auto StrAS = cast<PointerType>(Arg.getType())->getAddressSpace();
517 if (StrAS == FormatStringAddrSpace)
518 return *IRB.CreateCall(PrintfImplDecl[PrintfImplFunc::ArgStr],
519 {&PrevCall, &Arg},
520 OrigPrintf.getName() + ".printf.arg");
521 IGC_ASSERT_MESSAGE(StrAS == LegacyFormatStringAddrSpace,
522 "unexpected address space for a string argument");
523 return *IRB.CreateCall(PrintfImplDecl[PrintfImplFunc::ArgStrLegacy],
524 {&PrevCall, &Arg},
525 OrigPrintf.getName() + ".printf.arg");
526 }
527
createPrintfArgCall(CallInst & OrigPrintf,CallInst & PrevCall,Value & Arg,PrintfArgInfo Info)528 CallInst &GenXPrintfResolution::createPrintfArgCall(CallInst &OrigPrintf,
529 CallInst &PrevCall,
530 Value &Arg,
531 PrintfArgInfo Info) {
532 assertPrintfCall(OrigPrintf);
533 ArgKind::Enum Kind = getArgKind(*Arg.getType(), Info);
534 IRBuilder<> IRB{&OrigPrintf};
535 if (Kind == ArgKind::String)
536 return createPrintfArgStrCall(OrigPrintf, PrevCall, Arg);
537 Value &ArgVec = getArgAsVector(Arg, IRB, *DL);
538 return *IRB.CreateCall(PrintfImplDecl[PrintfImplFunc::Arg],
539 {&PrevCall, IRB.getInt32(Kind), &ArgVec},
540 OrigPrintf.getName() + ".printf.arg");
541 }
542
createPrintfRetCall(CallInst & OrigPrintf,CallInst & PrevCall)543 CallInst &GenXPrintfResolution::createPrintfRetCall(CallInst &OrigPrintf,
544 CallInst &PrevCall) {
545 assertPrintfCall(OrigPrintf);
546 IRBuilder<> IRB{&OrigPrintf};
547 return *IRB.CreateCall(PrintfImplDecl[PrintfImplFunc::Ret], &PrevCall);
548 }
549