1 //===- AMDGPUEmitPrintf.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utility function to lower a printf call into a series of device
10 // library calls on the AMDGPU target.
11 //
12 // WARNING: This file knows about certain library functions. It recognizes them
13 // by name, and hardwires knowledge of their semantics.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/Transforms/Utils/AMDGPUEmitPrintf.h"
18 #include "llvm/ADT/SparseBitVector.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/Analysis/ValueTracking.h"
21 #include "llvm/Support/DataExtractor.h"
22 #include "llvm/Support/MD5.h"
23 #include "llvm/Support/MathExtras.h"
24 
25 using namespace llvm;
26 
27 #define DEBUG_TYPE "amdgpu-emit-printf"
28 
29 static Value *fitArgInto64Bits(IRBuilder<> &Builder, Value *Arg) {
30   auto Int64Ty = Builder.getInt64Ty();
31   auto Ty = Arg->getType();
32 
33   if (auto IntTy = dyn_cast<IntegerType>(Ty)) {
34     switch (IntTy->getBitWidth()) {
35     case 32:
36       return Builder.CreateZExt(Arg, Int64Ty);
37     case 64:
38       return Arg;
39     }
40   }
41 
42   if (Ty->getTypeID() == Type::DoubleTyID) {
43     return Builder.CreateBitCast(Arg, Int64Ty);
44   }
45 
46   if (isa<PointerType>(Ty)) {
47     return Builder.CreatePtrToInt(Arg, Int64Ty);
48   }
49 
50   llvm_unreachable("unexpected type");
51 }
52 
53 static Value *callPrintfBegin(IRBuilder<> &Builder, Value *Version) {
54   auto Int64Ty = Builder.getInt64Ty();
55   auto M = Builder.GetInsertBlock()->getModule();
56   auto Fn = M->getOrInsertFunction("__ockl_printf_begin", Int64Ty, Int64Ty);
57   return Builder.CreateCall(Fn, Version);
58 }
59 
60 static Value *callAppendArgs(IRBuilder<> &Builder, Value *Desc, int NumArgs,
61                              Value *Arg0, Value *Arg1, Value *Arg2, Value *Arg3,
62                              Value *Arg4, Value *Arg5, Value *Arg6,
63                              bool IsLast) {
64   auto Int64Ty = Builder.getInt64Ty();
65   auto Int32Ty = Builder.getInt32Ty();
66   auto M = Builder.GetInsertBlock()->getModule();
67   auto Fn = M->getOrInsertFunction("__ockl_printf_append_args", Int64Ty,
68                                    Int64Ty, Int32Ty, Int64Ty, Int64Ty, Int64Ty,
69                                    Int64Ty, Int64Ty, Int64Ty, Int64Ty, Int32Ty);
70   auto IsLastValue = Builder.getInt32(IsLast);
71   auto NumArgsValue = Builder.getInt32(NumArgs);
72   return Builder.CreateCall(Fn, {Desc, NumArgsValue, Arg0, Arg1, Arg2, Arg3,
73                                  Arg4, Arg5, Arg6, IsLastValue});
74 }
75 
76 static Value *appendArg(IRBuilder<> &Builder, Value *Desc, Value *Arg,
77                         bool IsLast) {
78   auto Arg0 = fitArgInto64Bits(Builder, Arg);
79   auto Zero = Builder.getInt64(0);
80   return callAppendArgs(Builder, Desc, 1, Arg0, Zero, Zero, Zero, Zero, Zero,
81                         Zero, IsLast);
82 }
83 
84 // The device library does not provide strlen, so we build our own loop
85 // here. While we are at it, we also include the terminating null in the length.
86 static Value *getStrlenWithNull(IRBuilder<> &Builder, Value *Str) {
87   auto *Prev = Builder.GetInsertBlock();
88   Module *M = Prev->getModule();
89 
90   auto CharZero = Builder.getInt8(0);
91   auto One = Builder.getInt64(1);
92   auto Zero = Builder.getInt64(0);
93   auto Int64Ty = Builder.getInt64Ty();
94 
95   // The length is either zero for a null pointer, or the computed value for an
96   // actual string. We need a join block for a phi that represents the final
97   // value.
98   //
99   //  Strictly speaking, the zero does not matter since
100   // __ockl_printf_append_string_n ignores the length if the pointer is null.
101   BasicBlock *Join = nullptr;
102   if (Prev->getTerminator()) {
103     Join = Prev->splitBasicBlock(Builder.GetInsertPoint(),
104                                  "strlen.join");
105     Prev->getTerminator()->eraseFromParent();
106   } else {
107     Join = BasicBlock::Create(M->getContext(), "strlen.join",
108                               Prev->getParent());
109   }
110   BasicBlock *While =
111       BasicBlock::Create(M->getContext(), "strlen.while",
112                          Prev->getParent(), Join);
113   BasicBlock *WhileDone = BasicBlock::Create(
114       M->getContext(), "strlen.while.done",
115       Prev->getParent(), Join);
116 
117   // Emit an early return for when the pointer is null.
118   Builder.SetInsertPoint(Prev);
119   auto CmpNull =
120       Builder.CreateICmpEQ(Str, Constant::getNullValue(Str->getType()));
121   BranchInst::Create(Join, While, CmpNull, Prev);
122 
123   // Entry to the while loop.
124   Builder.SetInsertPoint(While);
125 
126   auto PtrPhi = Builder.CreatePHI(Str->getType(), 2);
127   PtrPhi->addIncoming(Str, Prev);
128   auto PtrNext = Builder.CreateGEP(Builder.getInt8Ty(), PtrPhi, One);
129   PtrPhi->addIncoming(PtrNext, While);
130 
131   // Condition for the while loop.
132   auto Data = Builder.CreateLoad(Builder.getInt8Ty(), PtrPhi);
133   auto Cmp = Builder.CreateICmpEQ(Data, CharZero);
134   Builder.CreateCondBr(Cmp, WhileDone, While);
135 
136   // Add one to the computed length.
137   Builder.SetInsertPoint(WhileDone, WhileDone->begin());
138   auto Begin = Builder.CreatePtrToInt(Str, Int64Ty);
139   auto End = Builder.CreatePtrToInt(PtrPhi, Int64Ty);
140   auto Len = Builder.CreateSub(End, Begin);
141   Len = Builder.CreateAdd(Len, One);
142 
143   // Final join.
144   BranchInst::Create(Join, WhileDone);
145   Builder.SetInsertPoint(Join, Join->begin());
146   auto LenPhi = Builder.CreatePHI(Len->getType(), 2);
147   LenPhi->addIncoming(Len, WhileDone);
148   LenPhi->addIncoming(Zero, Prev);
149 
150   return LenPhi;
151 }
152 
153 static Value *callAppendStringN(IRBuilder<> &Builder, Value *Desc, Value *Str,
154                                 Value *Length, bool isLast) {
155   auto Int64Ty = Builder.getInt64Ty();
156   auto PtrTy = Builder.getPtrTy();
157   auto Int32Ty = Builder.getInt32Ty();
158   auto M = Builder.GetInsertBlock()->getModule();
159   auto Fn = M->getOrInsertFunction("__ockl_printf_append_string_n", Int64Ty,
160                                    Int64Ty, PtrTy, Int64Ty, Int32Ty);
161   auto IsLastInt32 = Builder.getInt32(isLast);
162   return Builder.CreateCall(Fn, {Desc, Str, Length, IsLastInt32});
163 }
164 
165 static Value *appendString(IRBuilder<> &Builder, Value *Desc, Value *Arg,
166                            bool IsLast) {
167   auto Length = getStrlenWithNull(Builder, Arg);
168   return callAppendStringN(Builder, Desc, Arg, Length, IsLast);
169 }
170 
171 static Value *processArg(IRBuilder<> &Builder, Value *Desc, Value *Arg,
172                          bool SpecIsCString, bool IsLast) {
173   if (SpecIsCString && isa<PointerType>(Arg->getType())) {
174     return appendString(Builder, Desc, Arg, IsLast);
175   }
176   // If the format specifies a string but the argument is not, the frontend will
177   // have printed a warning. We just rely on undefined behaviour and send the
178   // argument anyway.
179   return appendArg(Builder, Desc, Arg, IsLast);
180 }
181 
182 // Scan the format string to locate all specifiers, and mark the ones that
183 // specify a string, i.e, the "%s" specifier with optional '*' characters.
184 static void locateCStrings(SparseBitVector<8> &BV, StringRef Str) {
185   static const char ConvSpecifiers[] = "diouxXfFeEgGaAcspn";
186   size_t SpecPos = 0;
187   // Skip the first argument, the format string.
188   unsigned ArgIdx = 1;
189 
190   while ((SpecPos = Str.find_first_of('%', SpecPos)) != StringRef::npos) {
191     if (Str[SpecPos + 1] == '%') {
192       SpecPos += 2;
193       continue;
194     }
195     auto SpecEnd = Str.find_first_of(ConvSpecifiers, SpecPos);
196     if (SpecEnd == StringRef::npos)
197       return;
198     auto Spec = Str.slice(SpecPos, SpecEnd + 1);
199     ArgIdx += Spec.count('*');
200     if (Str[SpecEnd] == 's') {
201       BV.set(ArgIdx);
202     }
203     SpecPos = SpecEnd + 1;
204     ++ArgIdx;
205   }
206 }
207 
208 // helper struct to package the string related data
209 struct StringData {
210   StringRef Str;
211   Value *RealSize = nullptr;
212   Value *AlignedSize = nullptr;
213   bool IsConst = true;
214 
215   StringData(StringRef ST, Value *RS, Value *AS, bool IC)
216       : Str(ST), RealSize(RS), AlignedSize(AS), IsConst(IC) {}
217 };
218 
219 // Calculates frame size required for current printf expansion and allocates
220 // space on printf buffer. Printf frame includes following contents
221 // [ ControlDWord , format string/Hash , Arguments (each aligned to 8 byte) ]
222 static Value *callBufferedPrintfStart(
223     IRBuilder<> &Builder, ArrayRef<Value *> Args, Value *Fmt,
224     bool isConstFmtStr, SparseBitVector<8> &SpecIsCString,
225     SmallVectorImpl<StringData> &StringContents, Value *&ArgSize) {
226   Module *M = Builder.GetInsertBlock()->getModule();
227   Value *NonConstStrLen = nullptr;
228   Value *LenWithNull = nullptr;
229   Value *LenWithNullAligned = nullptr;
230   Value *TempAdd = nullptr;
231 
232   // First 4 bytes to be reserved for control dword
233   size_t BufSize = 4;
234   if (isConstFmtStr)
235     // First 8 bytes of MD5 hash
236     BufSize += 8;
237   else {
238     LenWithNull = getStrlenWithNull(Builder, Fmt);
239 
240     // Align the computed length to next 8 byte boundary
241     TempAdd = Builder.CreateAdd(LenWithNull,
242                                 ConstantInt::get(LenWithNull->getType(), 7U));
243     NonConstStrLen = Builder.CreateAnd(
244         TempAdd, ConstantInt::get(LenWithNull->getType(), ~7U));
245 
246     StringContents.push_back(
247         StringData(StringRef(), LenWithNull, NonConstStrLen, false));
248   }
249 
250   for (size_t i = 1; i < Args.size(); i++) {
251     if (SpecIsCString.test(i)) {
252       StringRef ArgStr;
253       if (getConstantStringInfo(Args[i], ArgStr)) {
254         auto alignedLen = alignTo(ArgStr.size() + 1, 8);
255         StringContents.push_back(StringData(
256             ArgStr,
257             /*RealSize*/ nullptr, /*AlignedSize*/ nullptr, /*IsConst*/ true));
258         BufSize += alignedLen;
259       } else {
260         LenWithNull = getStrlenWithNull(Builder, Args[i]);
261 
262         // Align the computed length to next 8 byte boundary
263         TempAdd = Builder.CreateAdd(
264             LenWithNull, ConstantInt::get(LenWithNull->getType(), 7U));
265         LenWithNullAligned = Builder.CreateAnd(
266             TempAdd, ConstantInt::get(LenWithNull->getType(), ~7U));
267 
268         if (NonConstStrLen) {
269           auto Val = Builder.CreateAdd(LenWithNullAligned, NonConstStrLen,
270                                        "cumulativeAdd");
271           NonConstStrLen = Val;
272         } else
273           NonConstStrLen = LenWithNullAligned;
274 
275         StringContents.push_back(
276             StringData(StringRef(), LenWithNull, LenWithNullAligned, false));
277       }
278     } else {
279       int AllocSize = M->getDataLayout().getTypeAllocSize(Args[i]->getType());
280       // We end up expanding non string arguments to 8 bytes
281       // (args smaller than 8 bytes)
282       BufSize += std::max(AllocSize, 8);
283     }
284   }
285 
286   // calculate final size value to be passed to printf_alloc
287   Value *SizeToReserve = ConstantInt::get(Builder.getInt64Ty(), BufSize, false);
288   SmallVector<Value *, 1> Alloc_args;
289   if (NonConstStrLen)
290     SizeToReserve = Builder.CreateAdd(NonConstStrLen, SizeToReserve);
291 
292   ArgSize = Builder.CreateTrunc(SizeToReserve, Builder.getInt32Ty());
293   Alloc_args.push_back(ArgSize);
294 
295   // call the printf_alloc function
296   AttributeList Attr = AttributeList::get(
297       Builder.getContext(), AttributeList::FunctionIndex, Attribute::NoUnwind);
298 
299   Type *Tys_alloc[1] = {Builder.getInt32Ty()};
300   Type *PtrTy =
301       Builder.getPtrTy(M->getDataLayout().getDefaultGlobalsAddressSpace());
302   FunctionType *FTy_alloc = FunctionType::get(PtrTy, Tys_alloc, false);
303   auto PrintfAllocFn =
304       M->getOrInsertFunction(StringRef("__printf_alloc"), FTy_alloc, Attr);
305 
306   return Builder.CreateCall(PrintfAllocFn, Alloc_args, "printf_alloc_fn");
307 }
308 
309 // Prepare constant string argument to push onto the buffer
310 static void processConstantStringArg(StringData *SD, IRBuilder<> &Builder,
311                                      SmallVectorImpl<Value *> &WhatToStore) {
312   std::string Str(SD->Str.str() + '\0');
313 
314   DataExtractor Extractor(Str, /*IsLittleEndian=*/true, 8);
315   DataExtractor::Cursor Offset(0);
316   while (Offset && Offset.tell() < Str.size()) {
317     const uint64_t ReadSize = 4;
318     uint64_t ReadNow = std::min(ReadSize, Str.size() - Offset.tell());
319     uint64_t ReadBytes = 0;
320     switch (ReadNow) {
321     default:
322       llvm_unreachable("min(4, X) > 4?");
323     case 1:
324       ReadBytes = Extractor.getU8(Offset);
325       break;
326     case 2:
327       ReadBytes = Extractor.getU16(Offset);
328       break;
329     case 3:
330       ReadBytes = Extractor.getU24(Offset);
331       break;
332     case 4:
333       ReadBytes = Extractor.getU32(Offset);
334       break;
335     }
336     cantFail(Offset.takeError(), "failed to read bytes from constant array");
337 
338     APInt IntVal(8 * ReadSize, ReadBytes);
339 
340     // TODO: Should not bother aligning up.
341     if (ReadNow < ReadSize)
342       IntVal = IntVal.zext(8 * ReadSize);
343 
344     Type *IntTy = Type::getIntNTy(Builder.getContext(), IntVal.getBitWidth());
345     WhatToStore.push_back(ConstantInt::get(IntTy, IntVal));
346   }
347   // Additional padding for 8 byte alignment
348   int Rem = (Str.size() % 8);
349   if (Rem > 0 && Rem <= 4)
350     WhatToStore.push_back(ConstantInt::get(Builder.getInt32Ty(), 0));
351 }
352 
353 static Value *processNonStringArg(Value *Arg, IRBuilder<> &Builder) {
354   const DataLayout &DL = Builder.GetInsertBlock()->getModule()->getDataLayout();
355   auto Ty = Arg->getType();
356 
357   if (auto IntTy = dyn_cast<IntegerType>(Ty)) {
358     if (IntTy->getBitWidth() < 64) {
359       return Builder.CreateZExt(Arg, Builder.getInt64Ty());
360     }
361   }
362 
363   if (Ty->isFloatingPointTy()) {
364     if (DL.getTypeAllocSize(Ty) < 8) {
365       return Builder.CreateFPExt(Arg, Builder.getDoubleTy());
366     }
367   }
368 
369   return Arg;
370 }
371 
372 static void
373 callBufferedPrintfArgPush(IRBuilder<> &Builder, ArrayRef<Value *> Args,
374                           Value *PtrToStore, SparseBitVector<8> &SpecIsCString,
375                           SmallVectorImpl<StringData> &StringContents,
376                           bool IsConstFmtStr) {
377   Module *M = Builder.GetInsertBlock()->getModule();
378   const DataLayout &DL = M->getDataLayout();
379   auto StrIt = StringContents.begin();
380   size_t i = IsConstFmtStr ? 1 : 0;
381   for (; i < Args.size(); i++) {
382     SmallVector<Value *, 32> WhatToStore;
383     if ((i == 0) || SpecIsCString.test(i)) {
384       if (StrIt->IsConst) {
385         processConstantStringArg(StrIt, Builder, WhatToStore);
386         StrIt++;
387       } else {
388         // This copies the contents of the string, however the next offset
389         // is at aligned length, the extra space that might be created due
390         // to alignment padding is not populated with any specific value
391         // here. This would be safe as long as runtime is sync with
392         // the offsets.
393         Builder.CreateMemCpy(PtrToStore, /*DstAlign*/ Align(1), Args[i],
394                              /*SrcAlign*/ Args[i]->getPointerAlignment(DL),
395                              StrIt->RealSize);
396 
397         PtrToStore =
398             Builder.CreateInBoundsGEP(Builder.getInt8Ty(), PtrToStore,
399                                       {StrIt->AlignedSize}, "PrintBuffNextPtr");
400         LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:"
401                           << *PtrToStore << '\n');
402 
403         // done with current argument, move to next
404         StrIt++;
405         continue;
406       }
407     } else {
408       WhatToStore.push_back(processNonStringArg(Args[i], Builder));
409     }
410 
411     for (unsigned I = 0, E = WhatToStore.size(); I != E; ++I) {
412       Value *toStore = WhatToStore[I];
413 
414       StoreInst *StBuff = Builder.CreateStore(toStore, PtrToStore);
415       LLVM_DEBUG(dbgs() << "inserting store to printf buffer:" << *StBuff
416                         << '\n');
417       (void)StBuff;
418       PtrToStore = Builder.CreateConstInBoundsGEP1_32(
419           Builder.getInt8Ty(), PtrToStore,
420           M->getDataLayout().getTypeAllocSize(toStore->getType()),
421           "PrintBuffNextPtr");
422       LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:" << *PtrToStore
423                         << '\n');
424     }
425   }
426 }
427 
428 Value *llvm::emitAMDGPUPrintfCall(IRBuilder<> &Builder, ArrayRef<Value *> Args,
429                                   bool IsBuffered) {
430   auto NumOps = Args.size();
431   assert(NumOps >= 1);
432 
433   auto Fmt = Args[0];
434   SparseBitVector<8> SpecIsCString;
435   StringRef FmtStr;
436 
437   if (getConstantStringInfo(Fmt, FmtStr))
438     locateCStrings(SpecIsCString, FmtStr);
439 
440   if (IsBuffered) {
441     SmallVector<StringData, 8> StringContents;
442     Module *M = Builder.GetInsertBlock()->getModule();
443     LLVMContext &Ctx = Builder.getContext();
444     auto Int8Ty = Builder.getInt8Ty();
445     auto Int32Ty = Builder.getInt32Ty();
446     bool IsConstFmtStr = !FmtStr.empty();
447 
448     Value *ArgSize = nullptr;
449     Value *Ptr =
450         callBufferedPrintfStart(Builder, Args, Fmt, IsConstFmtStr,
451                                 SpecIsCString, StringContents, ArgSize);
452 
453     // The buffered version still follows OpenCL printf standards for
454     // printf return value, i.e 0 on success, -1 on failure.
455     ConstantPointerNull *zeroIntPtr =
456         ConstantPointerNull::get(cast<PointerType>(Ptr->getType()));
457 
458     auto *Cmp = cast<ICmpInst>(Builder.CreateICmpNE(Ptr, zeroIntPtr, ""));
459 
460     BasicBlock *End = BasicBlock::Create(Ctx, "end.block",
461                                          Builder.GetInsertBlock()->getParent());
462     BasicBlock *ArgPush = BasicBlock::Create(
463         Ctx, "argpush.block", Builder.GetInsertBlock()->getParent());
464 
465     BranchInst::Create(ArgPush, End, Cmp, Builder.GetInsertBlock());
466     Builder.SetInsertPoint(ArgPush);
467 
468     // Create controlDWord and store as the first entry, format as follows
469     // Bit 0 (LSB) -> stream (1 if stderr, 0 if stdout, printf always outputs to
470     // stdout) Bit 1 -> constant format string (1 if constant) Bits 2-31 -> size
471     // of printf data frame
472     auto ConstantTwo = Builder.getInt32(2);
473     auto ControlDWord = Builder.CreateShl(ArgSize, ConstantTwo);
474     if (IsConstFmtStr)
475       ControlDWord = Builder.CreateOr(ControlDWord, ConstantTwo);
476 
477     Builder.CreateStore(ControlDWord, Ptr);
478 
479     Ptr = Builder.CreateConstInBoundsGEP1_32(Int8Ty, Ptr, 4);
480 
481     // Create MD5 hash for costant format string, push low 64 bits of the
482     // same onto buffer and metadata.
483     NamedMDNode *metaD = M->getOrInsertNamedMetadata("llvm.printf.fmts");
484     if (IsConstFmtStr) {
485       MD5 Hasher;
486       MD5::MD5Result Hash;
487       Hasher.update(FmtStr);
488       Hasher.final(Hash);
489 
490       // Try sticking to llvm.printf.fmts format, although we are not going to
491       // use the ID and argument size fields while printing,
492       std::string MetadataStr =
493           "0:0:" + llvm::utohexstr(Hash.low(), /*LowerCase=*/true) + "," +
494           FmtStr.str();
495       MDString *fmtStrArray = MDString::get(Ctx, MetadataStr);
496       MDNode *myMD = MDNode::get(Ctx, fmtStrArray);
497       metaD->addOperand(myMD);
498 
499       Builder.CreateStore(Builder.getInt64(Hash.low()), Ptr);
500       Ptr = Builder.CreateConstInBoundsGEP1_32(Int8Ty, Ptr, 8);
501     } else {
502       // Include a dummy metadata instance in case of only non constant
503       // format string usage, This might be an absurd usecase but needs to
504       // be done for completeness
505       if (metaD->getNumOperands() == 0) {
506         MDString *fmtStrArray =
507             MDString::get(Ctx, "0:0:ffffffff,\"Non const format string\"");
508         MDNode *myMD = MDNode::get(Ctx, fmtStrArray);
509         metaD->addOperand(myMD);
510       }
511     }
512 
513     // Push The printf arguments onto buffer
514     callBufferedPrintfArgPush(Builder, Args, Ptr, SpecIsCString, StringContents,
515                               IsConstFmtStr);
516 
517     // End block, returns -1 on failure
518     BranchInst::Create(End, ArgPush);
519     Builder.SetInsertPoint(End);
520     return Builder.CreateSExt(Builder.CreateNot(Cmp), Int32Ty, "printf_result");
521   }
522 
523   auto Desc = callPrintfBegin(Builder, Builder.getIntN(64, 0));
524   Desc = appendString(Builder, Desc, Fmt, NumOps == 1);
525 
526   // FIXME: This invokes hostcall once for each argument. We can pack up to
527   // seven scalar printf arguments in a single hostcall. See the signature of
528   // callAppendArgs().
529   for (unsigned int i = 1; i != NumOps; ++i) {
530     bool IsLast = i == NumOps - 1;
531     bool IsCString = SpecIsCString.test(i);
532     Desc = processArg(Builder, Desc, Args[i], IsCString, IsLast);
533   }
534 
535   return Builder.CreateTrunc(Desc, Builder.getInt32Ty());
536 }
537