1 //===- AMDGPUEmitPrintf.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utility function to lower a printf call into a series of device
10 // library calls on the AMDGPU target.
11 //
12 // WARNING: This file knows about certain library functions. It recognizes them
13 // by name, and hardwires knowledge of their semantics.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/Transforms/Utils/AMDGPUEmitPrintf.h"
18 #include "llvm/ADT/SparseBitVector.h"
19 #include "llvm/Analysis/ValueTracking.h"
20 
21 using namespace llvm;
22 
23 #define DEBUG_TYPE "amdgpu-emit-printf"
24 
25 static bool isCString(const Value *Arg) {
26   auto Ty = Arg->getType();
27   auto PtrTy = dyn_cast<PointerType>(Ty);
28   if (!PtrTy)
29     return false;
30 
31   auto IntTy = dyn_cast<IntegerType>(PtrTy->getElementType());
32   if (!IntTy)
33     return false;
34 
35   return IntTy->getBitWidth() == 8;
36 }
37 
38 static Value *fitArgInto64Bits(IRBuilder<> &Builder, Value *Arg) {
39   auto Int64Ty = Builder.getInt64Ty();
40   auto Ty = Arg->getType();
41 
42   if (auto IntTy = dyn_cast<IntegerType>(Ty)) {
43     switch (IntTy->getBitWidth()) {
44     case 32:
45       return Builder.CreateZExt(Arg, Int64Ty);
46     case 64:
47       return Arg;
48     }
49   }
50 
51   if (Ty->getTypeID() == Type::DoubleTyID) {
52     return Builder.CreateBitCast(Arg, Int64Ty);
53   }
54 
55   if (isa<PointerType>(Ty)) {
56     return Builder.CreatePtrToInt(Arg, Int64Ty);
57   }
58 
59   llvm_unreachable("unexpected type");
60 }
61 
62 static Value *callPrintfBegin(IRBuilder<> &Builder, Value *Version) {
63   auto Int64Ty = Builder.getInt64Ty();
64   auto M = Builder.GetInsertBlock()->getModule();
65   auto Fn = M->getOrInsertFunction("__ockl_printf_begin", Int64Ty, Int64Ty);
66   return Builder.CreateCall(Fn, Version);
67 }
68 
69 static Value *callAppendArgs(IRBuilder<> &Builder, Value *Desc, int NumArgs,
70                              Value *Arg0, Value *Arg1, Value *Arg2, Value *Arg3,
71                              Value *Arg4, Value *Arg5, Value *Arg6,
72                              bool IsLast) {
73   auto Int64Ty = Builder.getInt64Ty();
74   auto Int32Ty = Builder.getInt32Ty();
75   auto M = Builder.GetInsertBlock()->getModule();
76   auto Fn = M->getOrInsertFunction("__ockl_printf_append_args", Int64Ty,
77                                    Int64Ty, Int32Ty, Int64Ty, Int64Ty, Int64Ty,
78                                    Int64Ty, Int64Ty, Int64Ty, Int64Ty, Int32Ty);
79   auto IsLastValue = Builder.getInt32(IsLast);
80   auto NumArgsValue = Builder.getInt32(NumArgs);
81   return Builder.CreateCall(Fn, {Desc, NumArgsValue, Arg0, Arg1, Arg2, Arg3,
82                                  Arg4, Arg5, Arg6, IsLastValue});
83 }
84 
85 static Value *appendArg(IRBuilder<> &Builder, Value *Desc, Value *Arg,
86                         bool IsLast) {
87   auto Arg0 = fitArgInto64Bits(Builder, Arg);
88   auto Zero = Builder.getInt64(0);
89   return callAppendArgs(Builder, Desc, 1, Arg0, Zero, Zero, Zero, Zero, Zero,
90                         Zero, IsLast);
91 }
92 
// Emits IR that computes the length of the string at Str, counting the
// terminating null byte, or 0 when Str is a null pointer.
//
// The device library does not provide strlen, so we build our own loop here.
// It loads one element at a time, starting at Str, until it finds a zero. The
// loaded value is compared against an i8 zero, so this presumably requires Str
// to be an i8 pointer. NOTE(review): confirm at call sites; the format string
// is not checked with isCString() before being passed here.
//
// Control-flow effects: the current insertion block ("Prev") ends with a
// conditional branch, and the function emits
//   Prev              -> strlen.join (null) | strlen.while (non-null)
//   strlen.while      -> strlen.while.done (zero byte) | strlen.while
//   strlen.while.done -> strlen.join
// On return, Builder is positioned in strlen.join just after the length phi,
// so the caller keeps emitting code after the loop.
static Value *getStrlenWithNull(IRBuilder<> &Builder, Value *Str) {
  auto *Prev = Builder.GetInsertBlock();
  Module *M = Prev->getModule();

  auto CharZero = Builder.getInt8(0);
  auto One = Builder.getInt64(1);
  auto Zero = Builder.getInt64(0);
  auto Int64Ty = Builder.getInt64Ty();

  // The length is either zero for a null pointer, or the computed value for an
  // actual string. We need a join block for a phi that represents the final
  // value.
  //
  // Strictly speaking, the zero does not matter since
  // __ockl_printf_append_string_n ignores the length if the pointer is null.
  //
  // If Prev is already terminated, split it at the insertion point: everything
  // from there on (including the old terminator) moves into strlen.join. The
  // unconditional branch that splitBasicBlock adds is then erased so that Prev
  // can receive the conditional branch emitted below. If Prev has no
  // terminator yet, an empty strlen.join is appended to the function instead.
  // NOTE(review): that path assumes the insertion point is at the end of Prev.
  BasicBlock *Join = nullptr;
  if (Prev->getTerminator()) {
    Join = Prev->splitBasicBlock(Builder.GetInsertPoint(),
                                 "strlen.join");
    Prev->getTerminator()->eraseFromParent();
  } else {
    Join = BasicBlock::Create(M->getContext(), "strlen.join",
                              Prev->getParent());
  }
  // Both new blocks are placed just before Join in the function's block list.
  BasicBlock *While =
      BasicBlock::Create(M->getContext(), "strlen.while",
                         Prev->getParent(), Join);
  BasicBlock *WhileDone = BasicBlock::Create(
      M->getContext(), "strlen.while.done",
      Prev->getParent(), Join);

  // A null pointer branches straight to the join block (length 0); otherwise
  // control enters the loop.
  Builder.SetInsertPoint(Prev);
  auto CmpNull =
      Builder.CreateICmpEQ(Str, Constant::getNullValue(Str->getType()));
  BranchInst::Create(Join, While, CmpNull, Prev);

  // Entry to the while loop. PtrPhi is the current position: Str on entry from
  // Prev, and advanced by one element on each back-edge.
  Builder.SetInsertPoint(While);

  auto PtrPhi = Builder.CreatePHI(Str->getType(), 2);
  PtrPhi->addIncoming(Str, Prev);
  auto PtrNext = Builder.CreateGEP(PtrPhi, One);
  PtrPhi->addIncoming(PtrNext, While);

  // Condition for the while loop: exit once the element at PtrPhi is zero.
  auto Data = Builder.CreateLoad(PtrPhi);
  auto Cmp = Builder.CreateICmpEQ(Data, CharZero);
  Builder.CreateCondBr(Cmp, WhileDone, While);

  // On exit, PtrPhi points at the terminator. End - Begin is the number of
  // bytes before it, and adding one counts the terminator itself.
  Builder.SetInsertPoint(WhileDone, WhileDone->begin());
  auto Begin = Builder.CreatePtrToInt(Str, Int64Ty);
  auto End = Builder.CreatePtrToInt(PtrPhi, Int64Ty);
  auto Len = Builder.CreateSub(End, Begin);
  Len = Builder.CreateAdd(Len, One);

  // Final join: the computed length when coming from the loop, 0 on the
  // null-pointer edge from Prev. The phi goes at the top of Join, and Builder
  // stays positioned just after it.
  BranchInst::Create(Join, WhileDone);
  Builder.SetInsertPoint(Join, Join->begin());
  auto LenPhi = Builder.CreatePHI(Len->getType(), 2);
  LenPhi->addIncoming(Len, WhileDone);
  LenPhi->addIncoming(Zero, Prev);

  return LenPhi;
}
161 
162 static Value *callAppendStringN(IRBuilder<> &Builder, Value *Desc, Value *Str,
163                                 Value *Length, bool isLast) {
164   auto Int64Ty = Builder.getInt64Ty();
165   auto CharPtrTy = Builder.getInt8PtrTy();
166   auto Int32Ty = Builder.getInt32Ty();
167   auto M = Builder.GetInsertBlock()->getModule();
168   auto Fn = M->getOrInsertFunction("__ockl_printf_append_string_n", Int64Ty,
169                                    Int64Ty, CharPtrTy, Int64Ty, Int32Ty);
170   auto IsLastInt32 = Builder.getInt32(isLast);
171   return Builder.CreateCall(Fn, {Desc, Str, Length, IsLastInt32});
172 }
173 
174 static Value *appendString(IRBuilder<> &Builder, Value *Desc, Value *Arg,
175                            bool IsLast) {
176   auto Length = getStrlenWithNull(Builder, Arg);
177   return callAppendStringN(Builder, Desc, Arg, Length, IsLast);
178 }
179 
180 static Value *processArg(IRBuilder<> &Builder, Value *Desc, Value *Arg,
181                          bool SpecIsCString, bool IsLast) {
182   if (SpecIsCString && isCString(Arg)) {
183     return appendString(Builder, Desc, Arg, IsLast);
184   }
185   // If the format specifies a string but the argument is not, the frontend will
186   // have printed a warning. We just rely on undefined behaviour and send the
187   // argument anyway.
188   return appendArg(Builder, Desc, Arg, IsLast);
189 }
190 
191 // Scan the format string to locate all specifiers, and mark the ones that
192 // specify a string, i.e, the "%s" specifier with optional '*' characters.
193 static void locateCStrings(SparseBitVector<8> &BV, Value *Fmt) {
194   StringRef Str;
195   if (!getConstantStringInfo(Fmt, Str) || Str.empty())
196     return;
197 
198   static const char ConvSpecifiers[] = "diouxXfFeEgGaAcspn";
199   size_t SpecPos = 0;
200   // Skip the first argument, the format string.
201   unsigned ArgIdx = 1;
202 
203   while ((SpecPos = Str.find_first_of('%', SpecPos)) != StringRef::npos) {
204     if (Str[SpecPos + 1] == '%') {
205       SpecPos += 2;
206       continue;
207     }
208     auto SpecEnd = Str.find_first_of(ConvSpecifiers, SpecPos);
209     if (SpecEnd == StringRef::npos)
210       return;
211     auto Spec = Str.slice(SpecPos, SpecEnd + 1);
212     ArgIdx += Spec.count('*');
213     if (Str[SpecEnd] == 's') {
214       BV.set(ArgIdx);
215     }
216     SpecPos = SpecEnd + 1;
217     ++ArgIdx;
218   }
219 }
220 
221 Value *llvm::emitAMDGPUPrintfCall(IRBuilder<> &Builder,
222                                   ArrayRef<Value *> Args) {
223   auto NumOps = Args.size();
224   assert(NumOps >= 1);
225 
226   auto Fmt = Args[0];
227   SparseBitVector<8> SpecIsCString;
228   locateCStrings(SpecIsCString, Fmt);
229 
230   auto Desc = callPrintfBegin(Builder, Builder.getIntN(64, 0));
231   Desc = appendString(Builder, Desc, Fmt, NumOps == 1);
232 
233   // FIXME: This invokes hostcall once for each argument. We can pack up to
234   // seven scalar printf arguments in a single hostcall. See the signature of
235   // callAppendArgs().
236   for (unsigned int i = 1; i != NumOps; ++i) {
237     bool IsLast = i == NumOps - 1;
238     bool IsCString = SpecIsCString.test(i);
239     Desc = processArg(Builder, Desc, Args[i], IsCString, IsLast);
240   }
241 
242   return Builder.CreateTrunc(Desc, Builder.getInt32Ty());
243 }
244