1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "NVPTXAsmPrinter.h"
15 #include "MCTargetDesc/NVPTXBaseInfo.h"
16 #include "MCTargetDesc/NVPTXInstPrinter.h"
17 #include "MCTargetDesc/NVPTXMCAsmInfo.h"
18 #include "MCTargetDesc/NVPTXTargetStreamer.h"
19 #include "NVPTX.h"
20 #include "NVPTXMCExpr.h"
21 #include "NVPTXMachineFunctionInfo.h"
22 #include "NVPTXRegisterInfo.h"
23 #include "NVPTXSubtarget.h"
24 #include "NVPTXTargetMachine.h"
25 #include "NVPTXUtilities.h"
26 #include "TargetInfo/NVPTXTargetInfo.h"
27 #include "cl_common_defines.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/Twine.h"
37 #include "llvm/Analysis/ConstantFolding.h"
38 #include "llvm/CodeGen/Analysis.h"
39 #include "llvm/CodeGen/MachineBasicBlock.h"
40 #include "llvm/CodeGen/MachineFrameInfo.h"
41 #include "llvm/CodeGen/MachineFunction.h"
42 #include "llvm/CodeGen/MachineInstr.h"
43 #include "llvm/CodeGen/MachineLoopInfo.h"
44 #include "llvm/CodeGen/MachineModuleInfo.h"
45 #include "llvm/CodeGen/MachineOperand.h"
46 #include "llvm/CodeGen/MachineRegisterInfo.h"
47 #include "llvm/CodeGen/MachineValueType.h"
48 #include "llvm/CodeGen/TargetRegisterInfo.h"
49 #include "llvm/CodeGen/ValueTypes.h"
50 #include "llvm/IR/Attributes.h"
51 #include "llvm/IR/BasicBlock.h"
52 #include "llvm/IR/Constant.h"
53 #include "llvm/IR/Constants.h"
54 #include "llvm/IR/DataLayout.h"
55 #include "llvm/IR/DebugInfo.h"
56 #include "llvm/IR/DebugInfoMetadata.h"
57 #include "llvm/IR/DebugLoc.h"
58 #include "llvm/IR/DerivedTypes.h"
59 #include "llvm/IR/Function.h"
60 #include "llvm/IR/GlobalValue.h"
61 #include "llvm/IR/GlobalVariable.h"
62 #include "llvm/IR/Instruction.h"
63 #include "llvm/IR/LLVMContext.h"
64 #include "llvm/IR/Module.h"
65 #include "llvm/IR/Operator.h"
66 #include "llvm/IR/Type.h"
67 #include "llvm/IR/User.h"
68 #include "llvm/MC/MCExpr.h"
69 #include "llvm/MC/MCInst.h"
70 #include "llvm/MC/MCInstrDesc.h"
71 #include "llvm/MC/MCStreamer.h"
72 #include "llvm/MC/MCSymbol.h"
73 #include "llvm/MC/TargetRegistry.h"
74 #include "llvm/Support/Casting.h"
75 #include "llvm/Support/CommandLine.h"
76 #include "llvm/Support/Endian.h"
77 #include "llvm/Support/ErrorHandling.h"
78 #include "llvm/Support/NativeFormatting.h"
79 #include "llvm/Support/Path.h"
80 #include "llvm/Support/raw_ostream.h"
81 #include "llvm/Target/TargetLoweringObjectFile.h"
82 #include "llvm/Target/TargetMachine.h"
83 #include "llvm/TargetParser/Triple.h"
84 #include "llvm/Transforms/Utils/UnrollLoop.h"
85 #include <cassert>
86 #include <cstdint>
87 #include <cstring>
88 #include <new>
89 #include <string>
90 #include <utility>
91 #include <vector>
92 
93 using namespace llvm;
94 
// Hidden command-line switch, off by default. While it is off,
// doInitialization() reports a fatal error for any module whose
// llvm.global_ctors / llvm.global_dtors list is non-empty.
static cl::opt<bool>
    LowerCtorDtor("nvptx-lower-global-ctor-dtor",
                  cl::desc("Lower GPU ctor / dtors to globals on the device."),
                  cl::init(false), cl::Hidden);

// Name prefix of the per-function frame symbol; getFunctionFrameSymbol()
// appends the current function number to it.
#define DEPOTNAME "__local_depot"
101 
/// DiscoverDependentGlobals - Add to \p Globals every GlobalVariable on which
/// \p V depends.
104 static void
105 DiscoverDependentGlobals(const Value *V,
106                          DenseSet<const GlobalVariable *> &Globals) {
107   if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
108     Globals.insert(GV);
109   else {
110     if (const User *U = dyn_cast<User>(V)) {
111       for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
112         DiscoverDependentGlobals(U->getOperand(i), Globals);
113       }
114     }
115   }
116 }
117 
/// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
/// instances to be emitted, but only after every global it depends on has
/// been added first.
121 static void
122 VisitGlobalVariableForEmission(const GlobalVariable *GV,
123                                SmallVectorImpl<const GlobalVariable *> &Order,
124                                DenseSet<const GlobalVariable *> &Visited,
125                                DenseSet<const GlobalVariable *> &Visiting) {
126   // Have we already visited this one?
127   if (Visited.count(GV))
128     return;
129 
130   // Do we have a circular dependency?
131   if (!Visiting.insert(GV).second)
132     report_fatal_error("Circular dependency found in global variable set");
133 
134   // Make sure we visit all dependents first
135   DenseSet<const GlobalVariable *> Others;
136   for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
137     DiscoverDependentGlobals(GV->getOperand(i), Others);
138 
139   for (const GlobalVariable *GV : Others)
140     VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
141 
142   // Now we can visit ourself
143   Order.push_back(GV);
144   Visited.insert(GV);
145   Visiting.erase(GV);
146 }
147 
148 void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
149   NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(),
150                                         getSubtargetInfo().getFeatureBits());
151 
152   MCInst Inst;
153   lowerToMCInst(MI, Inst);
154   EmitToStreamer(*OutStreamer, Inst);
155 }
156 
157 // Handle symbol backtracking for targets that do not support image handles
158 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
159                                            unsigned OpNo, MCOperand &MCOp) {
160   const MachineOperand &MO = MI->getOperand(OpNo);
161   const MCInstrDesc &MCID = MI->getDesc();
162 
163   if (MCID.TSFlags & NVPTXII::IsTexFlag) {
164     // This is a texture fetch, so operand 4 is a texref and operand 5 is
165     // a samplerref
166     if (OpNo == 4 && MO.isImm()) {
167       lowerImageHandleSymbol(MO.getImm(), MCOp);
168       return true;
169     }
170     if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
171       lowerImageHandleSymbol(MO.getImm(), MCOp);
172       return true;
173     }
174 
175     return false;
176   } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
177     unsigned VecSize =
178       1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
179 
180     // For a surface load of vector size N, the Nth operand will be the surfref
181     if (OpNo == VecSize && MO.isImm()) {
182       lowerImageHandleSymbol(MO.getImm(), MCOp);
183       return true;
184     }
185 
186     return false;
187   } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
188     // This is a surface store, so operand 0 is a surfref
189     if (OpNo == 0 && MO.isImm()) {
190       lowerImageHandleSymbol(MO.getImm(), MCOp);
191       return true;
192     }
193 
194     return false;
195   } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
196     // This is a query, so operand 1 is a surfref/texref
197     if (OpNo == 1 && MO.isImm()) {
198       lowerImageHandleSymbol(MO.getImm(), MCOp);
199       return true;
200     }
201 
202     return false;
203   }
204 
205   return false;
206 }
207 
208 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
209   // Ewwww
210   LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget());
211   NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM);
212   const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
213   const char *Sym = MFI->getImageHandleSymbol(Index);
214   StringRef SymName = nvTM.getStrPool().save(Sym);
215   MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
216 }
217 
218 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
219   OutMI.setOpcode(MI->getOpcode());
220   // Special: Do not mangle symbol operand of CALL_PROTOTYPE
221   if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
222     const MachineOperand &MO = MI->getOperand(0);
223     OutMI.addOperand(GetSymbolRef(
224       OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
225     return;
226   }
227 
228   const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
229   for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
230     const MachineOperand &MO = MI->getOperand(i);
231 
232     MCOperand MCOp;
233     if (!STI.hasImageHandles()) {
234       if (lowerImageHandleOperand(MI, i, MCOp)) {
235         OutMI.addOperand(MCOp);
236         continue;
237       }
238     }
239 
240     if (lowerOperand(MO, MCOp))
241       OutMI.addOperand(MCOp);
242   }
243 }
244 
245 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
246                                    MCOperand &MCOp) {
247   switch (MO.getType()) {
248   default: llvm_unreachable("unknown operand type");
249   case MachineOperand::MO_Register:
250     MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
251     break;
252   case MachineOperand::MO_Immediate:
253     MCOp = MCOperand::createImm(MO.getImm());
254     break;
255   case MachineOperand::MO_MachineBasicBlock:
256     MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
257         MO.getMBB()->getSymbol(), OutContext));
258     break;
259   case MachineOperand::MO_ExternalSymbol:
260     MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
261     break;
262   case MachineOperand::MO_GlobalAddress:
263     MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
264     break;
265   case MachineOperand::MO_FPImmediate: {
266     const ConstantFP *Cnt = MO.getFPImm();
267     const APFloat &Val = Cnt->getValueAPF();
268 
269     switch (Cnt->getType()->getTypeID()) {
270     default: report_fatal_error("Unsupported FP type"); break;
271     case Type::HalfTyID:
272       MCOp = MCOperand::createExpr(
273         NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
274       break;
275     case Type::BFloatTyID:
276       MCOp = MCOperand::createExpr(
277           NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
278       break;
279     case Type::FloatTyID:
280       MCOp = MCOperand::createExpr(
281         NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
282       break;
283     case Type::DoubleTyID:
284       MCOp = MCOperand::createExpr(
285         NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
286       break;
287     }
288     break;
289   }
290   }
291   return true;
292 }
293 
294 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
295   if (Register::isVirtualRegister(Reg)) {
296     const TargetRegisterClass *RC = MRI->getRegClass(Reg);
297 
298     DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
299     unsigned RegNum = RegMap[Reg];
300 
301     // Encode the register class in the upper 4 bits
302     // Must be kept in sync with NVPTXInstPrinter::printRegName
303     unsigned Ret = 0;
304     if (RC == &NVPTX::Int1RegsRegClass) {
305       Ret = (1 << 28);
306     } else if (RC == &NVPTX::Int16RegsRegClass) {
307       Ret = (2 << 28);
308     } else if (RC == &NVPTX::Int32RegsRegClass) {
309       Ret = (3 << 28);
310     } else if (RC == &NVPTX::Int64RegsRegClass) {
311       Ret = (4 << 28);
312     } else if (RC == &NVPTX::Float32RegsRegClass) {
313       Ret = (5 << 28);
314     } else if (RC == &NVPTX::Float64RegsRegClass) {
315       Ret = (6 << 28);
316     } else {
317       report_fatal_error("Bad register class");
318     }
319 
320     // Insert the vreg number
321     Ret |= (RegNum & 0x0FFFFFFF);
322     return Ret;
323   } else {
324     // Some special-use registers are actually physical registers.
325     // Encode this as the register class ID of 0 and the real register ID.
326     return Reg & 0x0FFFFFFF;
327   }
328 }
329 
330 MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
331   const MCExpr *Expr;
332   Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
333                                  OutContext);
334   return MCOperand::createExpr(Expr);
335 }
336 
337 static bool ShouldPassAsArray(Type *Ty) {
338   return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
339          Ty->isHalfTy() || Ty->isBFloatTy();
340 }
341 
342 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
343   const DataLayout &DL = getDataLayout();
344   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
345   const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
346 
347   Type *Ty = F->getReturnType();
348 
349   bool isABI = (STI.getSmVersion() >= 20);
350 
351   if (Ty->getTypeID() == Type::VoidTyID)
352     return;
353   O << " (";
354 
355   if (isABI) {
356     if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
357         !ShouldPassAsArray(Ty)) {
358       unsigned size = 0;
359       if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
360         size = ITy->getBitWidth();
361       } else {
362         assert(Ty->isFloatingPointTy() && "Floating point type expected here");
363         size = Ty->getPrimitiveSizeInBits();
364       }
365       size = promoteScalarArgumentSize(size);
366       O << ".param .b" << size << " func_retval0";
367     } else if (isa<PointerType>(Ty)) {
368       O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
369         << " func_retval0";
370     } else if (ShouldPassAsArray(Ty)) {
371       unsigned totalsz = DL.getTypeAllocSize(Ty);
372       unsigned retAlignment = 0;
373       if (!getAlign(*F, 0, retAlignment))
374         retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value();
375       O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
376         << "]";
377     } else
378       llvm_unreachable("Unknown return type");
379   } else {
380     SmallVector<EVT, 16> vtparts;
381     ComputeValueVTs(*TLI, DL, Ty, vtparts);
382     unsigned idx = 0;
383     for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
384       unsigned elems = 1;
385       EVT elemtype = vtparts[i];
386       if (vtparts[i].isVector()) {
387         elems = vtparts[i].getVectorNumElements();
388         elemtype = vtparts[i].getVectorElementType();
389       }
390 
391       for (unsigned j = 0, je = elems; j != je; ++j) {
392         unsigned sz = elemtype.getSizeInBits();
393         if (elemtype.isInteger())
394           sz = promoteScalarArgumentSize(sz);
395         O << ".reg .b" << sz << " func_retval" << idx;
396         if (j < je - 1)
397           O << ", ";
398         ++idx;
399       }
400       if (i < e - 1)
401         O << ", ";
402     }
403   }
404   O << ") ";
405 }
406 
407 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
408                                         raw_ostream &O) {
409   const Function &F = MF.getFunction();
410   printReturnValStr(&F, O);
411 }
412 
413 // Return true if MBB is the header of a loop marked with
414 // llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
415 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
416     const MachineBasicBlock &MBB) const {
417   MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>();
418   // We insert .pragma "nounroll" only to the loop header.
419   if (!LI.isLoopHeader(&MBB))
420     return false;
421 
422   // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
423   // we iterate through each back edge of the loop with header MBB, and check
424   // whether its metadata contains llvm.loop.unroll.disable.
425   for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
426     if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
427       // Edges from other loops to MBB are not back edges.
428       continue;
429     }
430     if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
431       if (MDNode *LoopID =
432               PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
433         if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
434           return true;
435         if (MDNode *UnrollCountMD =
436                 GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
437           if (mdconst::extract<ConstantInt>(UnrollCountMD->getOperand(1))
438                   ->isOne())
439             return true;
440         }
441       }
442     }
443   }
444   return false;
445 }
446 
447 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
448   AsmPrinter::emitBasicBlockStart(MBB);
449   if (isLoopHeaderOfNoUnroll(MBB))
450     OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
451 }
452 
453 void NVPTXAsmPrinter::emitFunctionEntryLabel() {
454   SmallString<128> Str;
455   raw_svector_ostream O(Str);
456 
457   if (!GlobalsEmitted) {
458     emitGlobals(*MF->getFunction().getParent());
459     GlobalsEmitted = true;
460   }
461 
462   // Set up
463   MRI = &MF->getRegInfo();
464   F = &MF->getFunction();
465   emitLinkageDirective(F, O);
466   if (isKernelFunction(*F))
467     O << ".entry ";
468   else {
469     O << ".func ";
470     printReturnValStr(*MF, O);
471   }
472 
473   CurrentFnSym->print(O, MAI);
474 
475   emitFunctionParamList(F, O);
476   O << "\n";
477 
478   if (isKernelFunction(*F))
479     emitKernelFunctionDirectives(*F, O);
480 
481   if (shouldEmitPTXNoReturn(F, TM))
482     O << ".noreturn";
483 
484   OutStreamer->emitRawText(O.str());
485 
486   VRegMapping.clear();
487   // Emit open brace for function body.
488   OutStreamer->emitRawText(StringRef("{\n"));
489   setAndEmitFunctionVirtualRegisters(*MF);
490   // Emit initial .loc debug directive for correct relocation symbol data.
491   if (MMI && MMI->hasDebugInfo())
492     emitInitialRawDwarfLocDirective(*MF);
493 }
494 
495 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
496   bool Result = AsmPrinter::runOnMachineFunction(F);
497   // Emit closing brace for the body of function F.
498   // The closing brace must be emitted here because we need to emit additional
499   // debug labels/data after the last basic block.
500   // We need to emit the closing brace here because we don't have function that
501   // finished emission of the function body.
502   OutStreamer->emitRawText(StringRef("}\n"));
503   return Result;
504 }
505 
506 void NVPTXAsmPrinter::emitFunctionBodyStart() {
507   SmallString<128> Str;
508   raw_svector_ostream O(Str);
509   emitDemotedVars(&MF->getFunction(), O);
510   OutStreamer->emitRawText(O.str());
511 }
512 
/// Drop the per-function virtual register mapping once the body has been
/// emitted.
void NVPTXAsmPrinter::emitFunctionBodyEnd() {
  VRegMapping.clear();
}
516 
517 const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
518     SmallString<128> Str;
519     raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
520     return OutContext.getOrCreateSymbol(Str);
521 }
522 
523 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
524   Register RegNo = MI->getOperand(0).getReg();
525   if (RegNo.isVirtual()) {
526     OutStreamer->AddComment(Twine("implicit-def: ") +
527                             getVirtualRegisterName(RegNo));
528   } else {
529     const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
530     OutStreamer->AddComment(Twine("implicit-def: ") +
531                             STI.getRegisterInfo()->getName(RegNo));
532   }
533   OutStreamer->addBlankLine();
534 }
535 
536 void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
537                                                    raw_ostream &O) const {
538   // If the NVVM IR has some of reqntid* specified, then output
539   // the reqntid directive, and set the unspecified ones to 1.
540   // If none of reqntid* is specified, don't output reqntid directive.
541   unsigned reqntidx, reqntidy, reqntidz;
542   bool specified = false;
543   if (!getReqNTIDx(F, reqntidx))
544     reqntidx = 1;
545   else
546     specified = true;
547   if (!getReqNTIDy(F, reqntidy))
548     reqntidy = 1;
549   else
550     specified = true;
551   if (!getReqNTIDz(F, reqntidz))
552     reqntidz = 1;
553   else
554     specified = true;
555 
556   if (specified)
557     O << ".reqntid " << reqntidx << ", " << reqntidy << ", " << reqntidz
558       << "\n";
559 
560   // If the NVVM IR has some of maxntid* specified, then output
561   // the maxntid directive, and set the unspecified ones to 1.
562   // If none of maxntid* is specified, don't output maxntid directive.
563   unsigned maxntidx, maxntidy, maxntidz;
564   specified = false;
565   if (!getMaxNTIDx(F, maxntidx))
566     maxntidx = 1;
567   else
568     specified = true;
569   if (!getMaxNTIDy(F, maxntidy))
570     maxntidy = 1;
571   else
572     specified = true;
573   if (!getMaxNTIDz(F, maxntidz))
574     maxntidz = 1;
575   else
576     specified = true;
577 
578   if (specified)
579     O << ".maxntid " << maxntidx << ", " << maxntidy << ", " << maxntidz
580       << "\n";
581 
582   unsigned mincta;
583   if (getMinCTASm(F, mincta))
584     O << ".minnctapersm " << mincta << "\n";
585 
586   unsigned maxnreg;
587   if (getMaxNReg(F, maxnreg))
588     O << ".maxnreg " << maxnreg << "\n";
589 }
590 
591 std::string
592 NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
593   const TargetRegisterClass *RC = MRI->getRegClass(Reg);
594 
595   std::string Name;
596   raw_string_ostream NameStr(Name);
597 
598   VRegRCMap::const_iterator I = VRegMapping.find(RC);
599   assert(I != VRegMapping.end() && "Bad register class");
600   const DenseMap<unsigned, unsigned> &RegMap = I->second;
601 
602   VRegMap::const_iterator VI = RegMap.find(Reg);
603   assert(VI != RegMap.end() && "Bad virtual register");
604   unsigned MappedVR = VI->second;
605 
606   NameStr << getNVPTXRegClassStr(RC) << MappedVR;
607 
608   NameStr.flush();
609   return Name;
610 }
611 
612 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
613                                           raw_ostream &O) {
614   O << getVirtualRegisterName(vr);
615 }
616 
617 void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
618   emitLinkageDirective(F, O);
619   if (isKernelFunction(*F))
620     O << ".entry ";
621   else
622     O << ".func ";
623   printReturnValStr(F, O);
624   getSymbol(F)->print(O, MAI);
625   O << "\n";
626   emitFunctionParamList(F, O);
627   O << "\n";
628   if (shouldEmitPTXNoReturn(F, TM))
629     O << ".noreturn";
630   O << ";\n";
631 }
632 
633 static bool usedInGlobalVarDef(const Constant *C) {
634   if (!C)
635     return false;
636 
637   if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
638     return GV->getName() != "llvm.used";
639   }
640 
641   for (const User *U : C->users())
642     if (const Constant *C = dyn_cast<Constant>(U))
643       if (usedInGlobalVarDef(C))
644         return true;
645 
646   return false;
647 }
648 
649 static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
650   if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
651     if (othergv->getName() == "llvm.used")
652       return true;
653   }
654 
655   if (const Instruction *instr = dyn_cast<Instruction>(U)) {
656     if (instr->getParent() && instr->getParent()->getParent()) {
657       const Function *curFunc = instr->getParent()->getParent();
658       if (oneFunc && (curFunc != oneFunc))
659         return false;
660       oneFunc = curFunc;
661       return true;
662     } else
663       return false;
664   }
665 
666   for (const User *UU : U->users())
667     if (!usedInOneFunc(UU, oneFunc))
668       return false;
669 
670   return true;
671 }
672 
673 /* Find out if a global variable can be demoted to local scope.
674  * Currently, this is valid for CUDA shared variables, which have local
675  * scope and global lifetime. So the conditions to check are :
676  * 1. Is the global variable in shared address space?
677  * 2. Does it have internal linkage?
678  * 3. Is the global variable referenced only in one function?
679  */
680 static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
681   if (!gv->hasInternalLinkage())
682     return false;
683   PointerType *Pty = gv->getType();
684   if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
685     return false;
686 
687   const Function *oneFunc = nullptr;
688 
689   bool flag = usedInOneFunc(gv, oneFunc);
690   if (!flag)
691     return false;
692   if (!oneFunc)
693     return false;
694   f = oneFunc;
695   return true;
696 }
697 
698 static bool useFuncSeen(const Constant *C,
699                         DenseMap<const Function *, bool> &seenMap) {
700   for (const User *U : C->users()) {
701     if (const Constant *cu = dyn_cast<Constant>(U)) {
702       if (useFuncSeen(cu, seenMap))
703         return true;
704     } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
705       const BasicBlock *bb = I->getParent();
706       if (!bb)
707         continue;
708       const Function *caller = bb->getParent();
709       if (!caller)
710         continue;
711       if (seenMap.contains(caller))
712         return true;
713     }
714   }
715   return false;
716 }
717 
718 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
719   DenseMap<const Function *, bool> seenMap;
720   for (const Function &F : M) {
721     if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
722       emitDeclaration(&F, O);
723       continue;
724     }
725 
726     if (F.isDeclaration()) {
727       if (F.use_empty())
728         continue;
729       if (F.getIntrinsicID())
730         continue;
731       emitDeclaration(&F, O);
732       continue;
733     }
734     for (const User *U : F.users()) {
735       if (const Constant *C = dyn_cast<Constant>(U)) {
736         if (usedInGlobalVarDef(C)) {
737           // The use is in the initialization of a global variable
738           // that is a function pointer, so print a declaration
739           // for the original function
740           emitDeclaration(&F, O);
741           break;
742         }
743         // Emit a declaration of this function if the function that
744         // uses this constant expr has already been seen.
745         if (useFuncSeen(C, seenMap)) {
746           emitDeclaration(&F, O);
747           break;
748         }
749       }
750 
751       if (!isa<Instruction>(U))
752         continue;
753       const Instruction *instr = cast<Instruction>(U);
754       const BasicBlock *bb = instr->getParent();
755       if (!bb)
756         continue;
757       const Function *caller = bb->getParent();
758       if (!caller)
759         continue;
760 
761       // If a caller has already been seen, then the caller is
762       // appearing in the module before the callee. so print out
763       // a declaration for the callee.
764       if (seenMap.contains(caller)) {
765         emitDeclaration(&F, O);
766         break;
767       }
768     }
769     seenMap[&F] = true;
770   }
771 }
772 
773 static bool isEmptyXXStructor(GlobalVariable *GV) {
774   if (!GV) return true;
775   const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
776   if (!InitList) return true;  // Not an array; we don't know how to parse.
777   return InitList->getNumOperands() == 0;
778 }
779 
780 void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
781   // Construct a default subtarget off of the TargetMachine defaults. The
782   // rest of NVPTX isn't friendly to change subtargets per function and
783   // so the default TargetMachine will have all of the options.
784   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
785   const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
786   SmallString<128> Str1;
787   raw_svector_ostream OS1(Str1);
788 
789   // Emit header before any dwarf directives are emitted below.
790   emitHeader(M, OS1, *STI);
791   OutStreamer->emitRawText(OS1.str());
792 }
793 
794 bool NVPTXAsmPrinter::doInitialization(Module &M) {
795   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
796   const NVPTXSubtarget &STI =
797       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
798   if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))
799     report_fatal_error(".alias requires PTX version >= 6.3 and sm_30");
800 
801   if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors")) &&
802       !LowerCtorDtor) {
803     report_fatal_error(
804         "Module has a nontrivial global ctor, which NVPTX does not support.");
805     return true;  // error
806   }
807   if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors")) &&
808       !LowerCtorDtor) {
809     report_fatal_error(
810         "Module has a nontrivial global dtor, which NVPTX does not support.");
811     return true;  // error
812   }
813 
814   // We need to call the parent's one explicitly.
815   bool Result = AsmPrinter::doInitialization(M);
816 
817   GlobalsEmitted = false;
818 
819   return Result;
820 }
821 
822 void NVPTXAsmPrinter::emitGlobals(const Module &M) {
823   SmallString<128> Str2;
824   raw_svector_ostream OS2(Str2);
825 
826   emitDeclarations(M, OS2);
827 
828   // As ptxas does not support forward references of globals, we need to first
829   // sort the list of module-level globals in def-use order. We visit each
830   // global variable in order, and ensure that we emit it *after* its dependent
831   // globals. We use a little extra memory maintaining both a set and a list to
832   // have fast searches while maintaining a strict ordering.
833   SmallVector<const GlobalVariable *, 8> Globals;
834   DenseSet<const GlobalVariable *> GVVisited;
835   DenseSet<const GlobalVariable *> GVVisiting;
836 
837   // Visit each global variable, in order
838   for (const GlobalVariable &I : M.globals())
839     VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);
840 
841   assert(GVVisited.size() == M.global_size() && "Missed a global variable");
842   assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
843 
844   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
845   const NVPTXSubtarget &STI =
846       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
847 
848   // Print out module-level global variables in proper order
849   for (unsigned i = 0, e = Globals.size(); i != e; ++i)
850     printModuleLevelGV(Globals[i], OS2, /*processDemoted=*/false, STI);
851 
852   OS2 << '\n';
853 
854   OutStreamer->emitRawText(OS2.str());
855 }
856 
857 void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
858   SmallString<128> Str;
859   raw_svector_ostream OS(Str);
860 
861   MCSymbol *Name = getSymbol(&GA);
862   const Function *F = dyn_cast<Function>(GA.getAliasee());
863   if (!F || isKernelFunction(*F))
864     report_fatal_error("NVPTX aliasee must be a non-kernel function");
865 
866   if (GA.hasLinkOnceLinkage() || GA.hasWeakLinkage() ||
867       GA.hasAvailableExternallyLinkage() || GA.hasCommonLinkage())
868     report_fatal_error("NVPTX aliasee must not be '.weak'");
869 
870   OS << "\n";
871   emitLinkageDirective(F, OS);
872   OS << ".func ";
873   printReturnValStr(F, OS);
874   OS << Name->getName();
875   emitFunctionParamList(F, OS);
876   if (shouldEmitPTXNoReturn(F, TM))
877     OS << "\n.noreturn";
878   OS << ";\n";
879 
880   OS << ".alias " << Name->getName() << ", " << F->getName() << ";\n";
881 
882   OutStreamer->emitRawText(OS.str());
883 }
884 
885 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
886                                  const NVPTXSubtarget &STI) {
887   O << "//\n";
888   O << "// Generated by LLVM NVPTX Back-End\n";
889   O << "//\n";
890   O << "\n";
891 
892   unsigned PTXVersion = STI.getPTXVersion();
893   O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
894 
895   O << ".target ";
896   O << STI.getTargetName();
897 
898   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
899   if (NTM.getDrvInterface() == NVPTX::NVCL)
900     O << ", texmode_independent";
901 
902   bool HasFullDebugInfo = false;
903   for (DICompileUnit *CU : M.debug_compile_units()) {
904     switch(CU->getEmissionKind()) {
905     case DICompileUnit::NoDebug:
906     case DICompileUnit::DebugDirectivesOnly:
907       break;
908     case DICompileUnit::LineTablesOnly:
909     case DICompileUnit::FullDebug:
910       HasFullDebugInfo = true;
911       break;
912     }
913     if (HasFullDebugInfo)
914       break;
915   }
916   if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)
917     O << ", debug";
918 
919   O << "\n";
920 
921   O << ".address_size ";
922   if (NTM.is64Bit())
923     O << "64";
924   else
925     O << "32";
926   O << "\n";
927 
928   O << "\n";
929 }
930 
931 bool NVPTXAsmPrinter::doFinalization(Module &M) {
932   bool HasDebugInfo = MMI && MMI->hasDebugInfo();
933 
934   // If we did not emit any functions, then the global declarations have not
935   // yet been emitted.
936   if (!GlobalsEmitted) {
937     emitGlobals(M);
938     GlobalsEmitted = true;
939   }
940 
941   // If we have any aliases we emit them at the end.
942   SmallVector<GlobalAlias *> AliasesToRemove;
943   for (GlobalAlias &Alias : M.aliases()) {
944     emitGlobalAlias(M, Alias);
945     AliasesToRemove.push_back(&Alias);
946   }
947 
948   for (GlobalAlias *A : AliasesToRemove)
949     A->eraseFromParent();
950 
951   // call doFinalization
952   bool ret = AsmPrinter::doFinalization(M);
953 
954   clearAnnotationCache(&M);
955 
956   auto *TS =
957       static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
958   // Close the last emitted section
959   if (HasDebugInfo) {
960     TS->closeLastSection();
961     // Emit empty .debug_loc section for better support of the empty files.
962     OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}");
963   }
964 
965   // Output last DWARF .file directives, if any.
966   TS->outputDwarfFileDirectives();
967 
968   return ret;
969 }
970 
971 // This function emits appropriate linkage directives for
972 // functions and global variables.
973 //
974 // extern function declaration            -> .extern
975 // extern function definition             -> .visible
976 // external global variable with init     -> .visible
977 // external without init                  -> .extern
978 // appending                              -> not allowed, assert.
979 // for any linkage other than
980 // internal, private, linker_private,
981 // linker_private_weak, linker_private_weak_def_auto,
982 // we emit                                -> .weak.
983 
984 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
985                                            raw_ostream &O) {
986   if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
987     if (V->hasExternalLinkage()) {
988       if (isa<GlobalVariable>(V)) {
989         const GlobalVariable *GVar = cast<GlobalVariable>(V);
990         if (GVar) {
991           if (GVar->hasInitializer())
992             O << ".visible ";
993           else
994             O << ".extern ";
995         }
996       } else if (V->isDeclaration())
997         O << ".extern ";
998       else
999         O << ".visible ";
1000     } else if (V->hasAppendingLinkage()) {
1001       std::string msg;
1002       msg.append("Error: ");
1003       msg.append("Symbol ");
1004       if (V->hasName())
1005         msg.append(std::string(V->getName()));
1006       msg.append("has unsupported appending linkage type");
1007       llvm_unreachable(msg.c_str());
1008     } else if (!V->hasInternalLinkage() &&
1009                !V->hasPrivateLinkage()) {
1010       O << ".weak ";
1011     }
1012   }
1013 }
1014 
1015 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
1016                                          raw_ostream &O, bool processDemoted,
1017                                          const NVPTXSubtarget &STI) {
1018   // Skip meta data
1019   if (GVar->hasSection()) {
1020     if (GVar->getSection() == "llvm.metadata")
1021       return;
1022   }
1023 
1024   // Skip LLVM intrinsic global variables
1025   if (GVar->getName().startswith("llvm.") ||
1026       GVar->getName().startswith("nvvm."))
1027     return;
1028 
1029   const DataLayout &DL = getDataLayout();
1030 
1031   // GlobalVariables are always constant pointers themselves.
1032   PointerType *PTy = GVar->getType();
1033   Type *ETy = GVar->getValueType();
1034 
1035   if (GVar->hasExternalLinkage()) {
1036     if (GVar->hasInitializer())
1037       O << ".visible ";
1038     else
1039       O << ".extern ";
1040   } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
1041              GVar->hasAvailableExternallyLinkage() ||
1042              GVar->hasCommonLinkage()) {
1043     O << ".weak ";
1044   }
1045 
1046   if (isTexture(*GVar)) {
1047     O << ".global .texref " << getTextureName(*GVar) << ";\n";
1048     return;
1049   }
1050 
1051   if (isSurface(*GVar)) {
1052     O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
1053     return;
1054   }
1055 
1056   if (GVar->isDeclaration()) {
1057     // (extern) declarations, no definition or initializer
1058     // Currently the only known declaration is for an automatic __local
1059     // (.shared) promoted to global.
1060     emitPTXGlobalVariable(GVar, O, STI);
1061     O << ";\n";
1062     return;
1063   }
1064 
1065   if (isSampler(*GVar)) {
1066     O << ".global .samplerref " << getSamplerName(*GVar);
1067 
1068     const Constant *Initializer = nullptr;
1069     if (GVar->hasInitializer())
1070       Initializer = GVar->getInitializer();
1071     const ConstantInt *CI = nullptr;
1072     if (Initializer)
1073       CI = dyn_cast<ConstantInt>(Initializer);
1074     if (CI) {
1075       unsigned sample = CI->getZExtValue();
1076 
1077       O << " = { ";
1078 
1079       for (int i = 0,
1080                addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
1081            i < 3; i++) {
1082         O << "addr_mode_" << i << " = ";
1083         switch (addr) {
1084         case 0:
1085           O << "wrap";
1086           break;
1087         case 1:
1088           O << "clamp_to_border";
1089           break;
1090         case 2:
1091           O << "clamp_to_edge";
1092           break;
1093         case 3:
1094           O << "wrap";
1095           break;
1096         case 4:
1097           O << "mirror";
1098           break;
1099         }
1100         O << ", ";
1101       }
1102       O << "filter_mode = ";
1103       switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
1104       case 0:
1105         O << "nearest";
1106         break;
1107       case 1:
1108         O << "linear";
1109         break;
1110       case 2:
1111         llvm_unreachable("Anisotropic filtering is not supported");
1112       default:
1113         O << "nearest";
1114         break;
1115       }
1116       if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
1117         O << ", force_unnormalized_coords = 1";
1118       }
1119       O << " }";
1120     }
1121 
1122     O << ";\n";
1123     return;
1124   }
1125 
1126   if (GVar->hasPrivateLinkage()) {
1127     if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
1128       return;
1129 
1130     // FIXME - need better way (e.g. Metadata) to avoid generating this global
1131     if (strncmp(GVar->getName().data(), "filename", 8) == 0)
1132       return;
1133     if (GVar->use_empty())
1134       return;
1135   }
1136 
1137   const Function *demotedFunc = nullptr;
1138   if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
1139     O << "// " << GVar->getName() << " has been demoted\n";
1140     if (localDecls.find(demotedFunc) != localDecls.end())
1141       localDecls[demotedFunc].push_back(GVar);
1142     else {
1143       std::vector<const GlobalVariable *> temp;
1144       temp.push_back(GVar);
1145       localDecls[demotedFunc] = temp;
1146     }
1147     return;
1148   }
1149 
1150   O << ".";
1151   emitPTXAddressSpace(PTy->getAddressSpace(), O);
1152 
1153   if (isManaged(*GVar)) {
1154     if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1155       report_fatal_error(
1156           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1157     }
1158     O << " .attribute(.managed)";
1159   }
1160 
1161   if (MaybeAlign A = GVar->getAlign())
1162     O << " .align " << A->value();
1163   else
1164     O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
1165 
1166   if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
1167       (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
1168     O << " .";
1169     // Special case: ABI requires that we use .u8 for predicates
1170     if (ETy->isIntegerTy(1))
1171       O << "u8";
1172     else
1173       O << getPTXFundamentalTypeStr(ETy, false);
1174     O << " ";
1175     getSymbol(GVar)->print(O, MAI);
1176 
1177     // Ptx allows variable initilization only for constant and global state
1178     // spaces.
1179     if (GVar->hasInitializer()) {
1180       if ((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1181           (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) {
1182         const Constant *Initializer = GVar->getInitializer();
1183         // 'undef' is treated as there is no value specified.
1184         if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
1185           O << " = ";
1186           printScalarConstant(Initializer, O);
1187         }
1188       } else {
1189         // The frontend adds zero-initializer to device and constant variables
1190         // that don't have an initial value, and UndefValue to shared
1191         // variables, so skip warning for this case.
1192         if (!GVar->getInitializer()->isNullValue() &&
1193             !isa<UndefValue>(GVar->getInitializer())) {
1194           report_fatal_error("initial value of '" + GVar->getName() +
1195                              "' is not allowed in addrspace(" +
1196                              Twine(PTy->getAddressSpace()) + ")");
1197         }
1198       }
1199     }
1200   } else {
1201     uint64_t ElementSize = 0;
1202 
1203     // Although PTX has direct support for struct type and array type and
1204     // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
1205     // targets that support these high level field accesses. Structs, arrays
1206     // and vectors are lowered into arrays of bytes.
1207     switch (ETy->getTypeID()) {
1208     case Type::IntegerTyID: // Integers larger than 64 bits
1209     case Type::StructTyID:
1210     case Type::ArrayTyID:
1211     case Type::FixedVectorTyID:
1212       ElementSize = DL.getTypeStoreSize(ETy);
1213       // Ptx allows variable initilization only for constant and
1214       // global state spaces.
1215       if (((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1216            (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
1217           GVar->hasInitializer()) {
1218         const Constant *Initializer = GVar->getInitializer();
1219         if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
1220           AggBuffer aggBuffer(ElementSize, *this);
1221           bufferAggregateConstant(Initializer, &aggBuffer);
1222           if (aggBuffer.numSymbols()) {
1223             unsigned int ptrSize = MAI->getCodePointerSize();
1224             if (ElementSize % ptrSize ||
1225                 !aggBuffer.allSymbolsAligned(ptrSize)) {
1226               // Print in bytes and use the mask() operator for pointers.
1227               if (!STI.hasMaskOperator())
1228                 report_fatal_error(
1229                     "initialized packed aggregate with pointers '" +
1230                     GVar->getName() +
1231                     "' requires at least PTX ISA version 7.1");
1232               O << " .u8 ";
1233               getSymbol(GVar)->print(O, MAI);
1234               O << "[" << ElementSize << "] = {";
1235               aggBuffer.printBytes(O);
1236               O << "}";
1237             } else {
1238               O << " .u" << ptrSize * 8 << " ";
1239               getSymbol(GVar)->print(O, MAI);
1240               O << "[" << ElementSize / ptrSize << "] = {";
1241               aggBuffer.printWords(O);
1242               O << "}";
1243             }
1244           } else {
1245             O << " .b8 ";
1246             getSymbol(GVar)->print(O, MAI);
1247             O << "[" << ElementSize << "] = {";
1248             aggBuffer.printBytes(O);
1249             O << "}";
1250           }
1251         } else {
1252           O << " .b8 ";
1253           getSymbol(GVar)->print(O, MAI);
1254           if (ElementSize) {
1255             O << "[";
1256             O << ElementSize;
1257             O << "]";
1258           }
1259         }
1260       } else {
1261         O << " .b8 ";
1262         getSymbol(GVar)->print(O, MAI);
1263         if (ElementSize) {
1264           O << "[";
1265           O << ElementSize;
1266           O << "]";
1267         }
1268       }
1269       break;
1270     default:
1271       llvm_unreachable("type not supported yet");
1272     }
1273   }
1274   O << ";\n";
1275 }
1276 
1277 void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1278   const Value *v = Symbols[nSym];
1279   const Value *v0 = SymbolsBeforeStripping[nSym];
1280   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1281     MCSymbol *Name = AP.getSymbol(GVar);
1282     PointerType *PTy = dyn_cast<PointerType>(v0->getType());
1283     // Is v0 a generic pointer?
1284     bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1285     if (EmitGeneric && isGenericPointer && !isa<Function>(v)) {
1286       os << "generic(";
1287       Name->print(os, AP.MAI);
1288       os << ")";
1289     } else {
1290       Name->print(os, AP.MAI);
1291     }
1292   } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
1293     const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
1294     AP.printMCExpr(*Expr, os);
1295   } else
1296     llvm_unreachable("symbol type unknown");
1297 }
1298 
1299 void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
1300   unsigned int ptrSize = AP.MAI->getCodePointerSize();
1301   symbolPosInBuffer.push_back(size);
1302   unsigned int nSym = 0;
1303   unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1304   for (unsigned int pos = 0; pos < size;) {
1305     if (pos)
1306       os << ", ";
1307     if (pos != nextSymbolPos) {
1308       os << (unsigned int)buffer[pos];
1309       ++pos;
1310       continue;
1311     }
1312     // Generate a per-byte mask() operator for the symbol, which looks like:
1313     //   .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
1314     // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
1315     std::string symText;
1316     llvm::raw_string_ostream oss(symText);
1317     printSymbol(nSym, oss);
1318     for (unsigned i = 0; i < ptrSize; ++i) {
1319       if (i)
1320         os << ", ";
1321       llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper);
1322       os << "(" << symText << ")";
1323     }
1324     pos += ptrSize;
1325     nextSymbolPos = symbolPosInBuffer[++nSym];
1326     assert(nextSymbolPos >= pos);
1327   }
1328 }
1329 
1330 void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
1331   unsigned int ptrSize = AP.MAI->getCodePointerSize();
1332   symbolPosInBuffer.push_back(size);
1333   unsigned int nSym = 0;
1334   unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1335   assert(nextSymbolPos % ptrSize == 0);
1336   for (unsigned int pos = 0; pos < size; pos += ptrSize) {
1337     if (pos)
1338       os << ", ";
1339     if (pos == nextSymbolPos) {
1340       printSymbol(nSym, os);
1341       nextSymbolPos = symbolPosInBuffer[++nSym];
1342       assert(nextSymbolPos % ptrSize == 0);
1343       assert(nextSymbolPos >= pos + ptrSize);
1344     } else if (ptrSize == 4)
1345       os << support::endian::read32le(&buffer[pos]);
1346     else
1347       os << support::endian::read64le(&buffer[pos]);
1348   }
1349 }
1350 
1351 void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
1352   if (localDecls.find(f) == localDecls.end())
1353     return;
1354 
1355   std::vector<const GlobalVariable *> &gvars = localDecls[f];
1356 
1357   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1358   const NVPTXSubtarget &STI =
1359       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
1360 
1361   for (const GlobalVariable *GV : gvars) {
1362     O << "\t// demoted variable\n\t";
1363     printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
1364   }
1365 }
1366 
1367 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1368                                           raw_ostream &O) const {
1369   switch (AddressSpace) {
1370   case ADDRESS_SPACE_LOCAL:
1371     O << "local";
1372     break;
1373   case ADDRESS_SPACE_GLOBAL:
1374     O << "global";
1375     break;
1376   case ADDRESS_SPACE_CONST:
1377     O << "const";
1378     break;
1379   case ADDRESS_SPACE_SHARED:
1380     O << "shared";
1381     break;
1382   default:
1383     report_fatal_error("Bad address space found while emitting PTX: " +
1384                        llvm::Twine(AddressSpace));
1385     break;
1386   }
1387 }
1388 
1389 std::string
1390 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1391   switch (Ty->getTypeID()) {
1392   case Type::IntegerTyID: {
1393     unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
1394     if (NumBits == 1)
1395       return "pred";
1396     else if (NumBits <= 64) {
1397       std::string name = "u";
1398       return name + utostr(NumBits);
1399     } else {
1400       llvm_unreachable("Integer too large");
1401       break;
1402     }
1403     break;
1404   }
1405   case Type::BFloatTyID:
1406   case Type::HalfTyID:
1407     // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
1408     // PTX assembly.
1409     return "b16";
1410   case Type::FloatTyID:
1411     return "f32";
1412   case Type::DoubleTyID:
1413     return "f64";
1414   case Type::PointerTyID: {
1415     unsigned PtrSize = TM.getPointerSizeInBits(Ty->getPointerAddressSpace());
1416     assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");
1417 
1418     if (PtrSize == 64)
1419       if (useB4PTR)
1420         return "b64";
1421       else
1422         return "u64";
1423     else if (useB4PTR)
1424       return "b32";
1425     else
1426       return "u32";
1427   }
1428   default:
1429     break;
1430   }
1431   llvm_unreachable("unexpected type");
1432 }
1433 
1434 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1435                                             raw_ostream &O,
1436                                             const NVPTXSubtarget &STI) {
1437   const DataLayout &DL = getDataLayout();
1438 
1439   // GlobalVariables are always constant pointers themselves.
1440   Type *ETy = GVar->getValueType();
1441 
1442   O << ".";
1443   emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
1444   if (isManaged(*GVar)) {
1445     if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1446       report_fatal_error(
1447           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1448     }
1449     O << " .attribute(.managed)";
1450   }
1451   if (MaybeAlign A = GVar->getAlign())
1452     O << " .align " << A->value();
1453   else
1454     O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
1455 
1456   // Special case for i128
1457   if (ETy->isIntegerTy(128)) {
1458     O << " .b8 ";
1459     getSymbol(GVar)->print(O, MAI);
1460     O << "[16]";
1461     return;
1462   }
1463 
1464   if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
1465     O << " .";
1466     O << getPTXFundamentalTypeStr(ETy);
1467     O << " ";
1468     getSymbol(GVar)->print(O, MAI);
1469     return;
1470   }
1471 
1472   int64_t ElementSize = 0;
1473 
1474   // Although PTX has direct support for struct type and array type and LLVM IR
1475   // is very similar to PTX, the LLVM CodeGen does not support for targets that
1476   // support these high level field accesses. Structs and arrays are lowered
1477   // into arrays of bytes.
1478   switch (ETy->getTypeID()) {
1479   case Type::StructTyID:
1480   case Type::ArrayTyID:
1481   case Type::FixedVectorTyID:
1482     ElementSize = DL.getTypeStoreSize(ETy);
1483     O << " .b8 ";
1484     getSymbol(GVar)->print(O, MAI);
1485     O << "[";
1486     if (ElementSize) {
1487       O << ElementSize;
1488     }
1489     O << "]";
1490     break;
1491   default:
1492     llvm_unreachable("type not supported yet");
1493   }
1494 }
1495 
1496 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
1497   const DataLayout &DL = getDataLayout();
1498   const AttributeList &PAL = F->getAttributes();
1499   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
1500   const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
1501 
1502   Function::const_arg_iterator I, E;
1503   unsigned paramIndex = 0;
1504   bool first = true;
1505   bool isKernelFunc = isKernelFunction(*F);
1506   bool isABI = (STI.getSmVersion() >= 20);
1507   bool hasImageHandles = STI.hasImageHandles();
1508 
1509   if (F->arg_empty() && !F->isVarArg()) {
1510     O << "()";
1511     return;
1512   }
1513 
1514   O << "(\n";
1515 
1516   for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
1517     Type *Ty = I->getType();
1518 
1519     if (!first)
1520       O << ",\n";
1521 
1522     first = false;
1523 
1524     // Handle image/sampler parameters
1525     if (isKernelFunction(*F)) {
1526       if (isSampler(*I) || isImage(*I)) {
1527         if (isImage(*I)) {
1528           std::string sname = std::string(I->getName());
1529           if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
1530             if (hasImageHandles)
1531               O << "\t.param .u64 .ptr .surfref ";
1532             else
1533               O << "\t.param .surfref ";
1534             O << TLI->getParamName(F, paramIndex);
1535           }
1536           else { // Default image is read_only
1537             if (hasImageHandles)
1538               O << "\t.param .u64 .ptr .texref ";
1539             else
1540               O << "\t.param .texref ";
1541             O << TLI->getParamName(F, paramIndex);
1542           }
1543         } else {
1544           if (hasImageHandles)
1545             O << "\t.param .u64 .ptr .samplerref ";
1546           else
1547             O << "\t.param .samplerref ";
1548           O << TLI->getParamName(F, paramIndex);
1549         }
1550         continue;
1551       }
1552     }
1553 
1554     auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
1555                                     paramIndex](Type *Ty) -> Align {
1556       Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
1557       MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
1558       return std::max(TypeAlign, ParamAlign.valueOrOne());
1559     };
1560 
1561     if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
1562       if (ShouldPassAsArray(Ty)) {
1563         // Just print .param .align <a> .b8 .param[size];
1564         // <a>  = optimal alignment for the element type; always multiple of
1565         //        PAL.getParamAlignment
1566         // size = typeallocsize of element type
1567         Align OptimalAlign = getOptimalAlignForParam(Ty);
1568 
1569         O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1570         O << TLI->getParamName(F, paramIndex);
1571         O << "[" << DL.getTypeAllocSize(Ty) << "]";
1572 
1573         continue;
1574       }
1575       // Just a scalar
1576       auto *PTy = dyn_cast<PointerType>(Ty);
1577       unsigned PTySizeInBits = 0;
1578       if (PTy) {
1579         PTySizeInBits =
1580             TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
1581         assert(PTySizeInBits && "Invalid pointer size");
1582       }
1583 
1584       if (isKernelFunc) {
1585         if (PTy) {
1586           // Special handling for pointer arguments to kernel
1587           O << "\t.param .u" << PTySizeInBits << " ";
1588 
1589           if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
1590               NVPTX::CUDA) {
1591             int addrSpace = PTy->getAddressSpace();
1592             switch (addrSpace) {
1593             default:
1594               O << ".ptr ";
1595               break;
1596             case ADDRESS_SPACE_CONST:
1597               O << ".ptr .const ";
1598               break;
1599             case ADDRESS_SPACE_SHARED:
1600               O << ".ptr .shared ";
1601               break;
1602             case ADDRESS_SPACE_GLOBAL:
1603               O << ".ptr .global ";
1604               break;
1605             }
1606             Align ParamAlign = I->getParamAlign().valueOrOne();
1607             O << ".align " << ParamAlign.value() << " ";
1608           }
1609           O << TLI->getParamName(F, paramIndex);
1610           continue;
1611         }
1612 
1613         // non-pointer scalar to kernel func
1614         O << "\t.param .";
1615         // Special case: predicate operands become .u8 types
1616         if (Ty->isIntegerTy(1))
1617           O << "u8";
1618         else
1619           O << getPTXFundamentalTypeStr(Ty);
1620         O << " ";
1621         O << TLI->getParamName(F, paramIndex);
1622         continue;
1623       }
1624       // Non-kernel function, just print .param .b<size> for ABI
1625       // and .reg .b<size> for non-ABI
1626       unsigned sz = 0;
1627       if (isa<IntegerType>(Ty)) {
1628         sz = cast<IntegerType>(Ty)->getBitWidth();
1629         sz = promoteScalarArgumentSize(sz);
1630       } else if (PTy) {
1631         assert(PTySizeInBits && "Invalid pointer size");
1632         sz = PTySizeInBits;
1633       } else
1634         sz = Ty->getPrimitiveSizeInBits();
1635       if (isABI)
1636         O << "\t.param .b" << sz << " ";
1637       else
1638         O << "\t.reg .b" << sz << " ";
1639       O << TLI->getParamName(F, paramIndex);
1640       continue;
1641     }
1642 
1643     // param has byVal attribute.
1644     Type *ETy = PAL.getParamByValType(paramIndex);
1645     assert(ETy && "Param should have byval type");
1646 
1647     if (isABI || isKernelFunc) {
1648       // Just print .param .align <a> .b8 .param[size];
1649       // <a>  = optimal alignment for the element type; always multiple of
1650       //        PAL.getParamAlignment
1651       // size = typeallocsize of element type
1652       Align OptimalAlign =
1653           isKernelFunc
1654               ? getOptimalAlignForParam(ETy)
1655               : TLI->getFunctionByValParamAlign(
1656                     F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);
1657 
1658       unsigned sz = DL.getTypeAllocSize(ETy);
1659       O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1660       O << TLI->getParamName(F, paramIndex);
1661       O << "[" << sz << "]";
1662       continue;
1663     } else {
1664       // Split the ETy into constituent parts and
1665       // print .param .b<size> <name> for each part.
1666       // Further, if a part is vector, print the above for
1667       // each vector element.
1668       SmallVector<EVT, 16> vtparts;
1669       ComputeValueVTs(*TLI, DL, ETy, vtparts);
1670       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
1671         unsigned elems = 1;
1672         EVT elemtype = vtparts[i];
1673         if (vtparts[i].isVector()) {
1674           elems = vtparts[i].getVectorNumElements();
1675           elemtype = vtparts[i].getVectorElementType();
1676         }
1677 
1678         for (unsigned j = 0, je = elems; j != je; ++j) {
1679           unsigned sz = elemtype.getSizeInBits();
1680           if (elemtype.isInteger())
1681             sz = promoteScalarArgumentSize(sz);
1682           O << "\t.reg .b" << sz << " ";
1683           O << TLI->getParamName(F, paramIndex);
1684           if (j < je - 1)
1685             O << ",\n";
1686           ++paramIndex;
1687         }
1688         if (i < e - 1)
1689           O << ",\n";
1690       }
1691       --paramIndex;
1692       continue;
1693     }
1694   }
1695 
1696   if (F->isVarArg()) {
1697     if (!first)
1698       O << ",\n";
1699     O << "\t.param .align " << STI.getMaxRequiredAlignment();
1700     O << " .b8 ";
1701     O << TLI->getParamName(F, /* vararg */ -1) << "[]";
1702   }
1703 
1704   O << "\n)";
1705 }
1706 
1707 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1708     const MachineFunction &MF) {
1709   SmallString<128> Str;
1710   raw_svector_ostream O(Str);
1711 
1712   // Map the global virtual register number to a register class specific
1713   // virtual register number starting from 1 with that class.
1714   const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1715   //unsigned numRegClasses = TRI->getNumRegClasses();
1716 
1717   // Emit the Fake Stack Object
1718   const MachineFrameInfo &MFI = MF.getFrameInfo();
1719   int NumBytes = (int) MFI.getStackSize();
1720   if (NumBytes) {
1721     O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
1722       << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
1723     if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1724       O << "\t.reg .b64 \t%SP;\n";
1725       O << "\t.reg .b64 \t%SPL;\n";
1726     } else {
1727       O << "\t.reg .b32 \t%SP;\n";
1728       O << "\t.reg .b32 \t%SPL;\n";
1729     }
1730   }
1731 
1732   // Go through all virtual registers to establish the mapping between the
1733   // global virtual
1734   // register number and the per class virtual register number.
1735   // We use the per class virtual register number in the ptx output.
1736   unsigned int numVRs = MRI->getNumVirtRegs();
1737   for (unsigned i = 0; i < numVRs; i++) {
1738     Register vr = Register::index2VirtReg(i);
1739     const TargetRegisterClass *RC = MRI->getRegClass(vr);
1740     DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1741     int n = regmap.size();
1742     regmap.insert(std::make_pair(vr, n + 1));
1743   }
1744 
1745   // Emit register declarations
1746   // @TODO: Extract out the real register usage
1747   // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1748   // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1749   // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1750   // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1751   // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1752   // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1753   // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1754 
1755   // Emit declaration of the virtual registers or 'physical' registers for
1756   // each register class
1757   for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
1758     const TargetRegisterClass *RC = TRI->getRegClass(i);
1759     DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1760     std::string rcname = getNVPTXRegClassName(RC);
1761     std::string rcStr = getNVPTXRegClassStr(RC);
1762     int n = regmap.size();
1763 
1764     // Only declare those registers that may be used.
1765     if (n) {
1766        O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
1767          << ">;\n";
1768     }
1769   }
1770 
1771   OutStreamer->emitRawText(O.str());
1772 }
1773 
1774 void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
1775   APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1776   bool ignored;
1777   unsigned int numHex;
1778   const char *lead;
1779 
1780   if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1781     numHex = 8;
1782     lead = "0f";
1783     APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
1784   } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1785     numHex = 16;
1786     lead = "0d";
1787     APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
1788   } else
1789     llvm_unreachable("unsupported fp type");
1790 
1791   APInt API = APF.bitcastToAPInt();
1792   O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
1793 }
1794 
1795 void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1796   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1797     O << CI->getValue();
1798     return;
1799   }
1800   if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1801     printFPConstant(CFP, O);
1802     return;
1803   }
1804   if (isa<ConstantPointerNull>(CPV)) {
1805     O << "0";
1806     return;
1807   }
1808   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1809     bool IsNonGenericPointer = false;
1810     if (GVar->getType()->getAddressSpace() != 0) {
1811       IsNonGenericPointer = true;
1812     }
1813     if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
1814       O << "generic(";
1815       getSymbol(GVar)->print(O, MAI);
1816       O << ")";
1817     } else {
1818       getSymbol(GVar)->print(O, MAI);
1819     }
1820     return;
1821   }
1822   if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1823     const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false);
1824     printMCExpr(*E, O);
1825     return;
1826   }
1827   llvm_unreachable("Not scalar type found in printScalarConstant()");
1828 }
1829 
1830 void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1831                                    AggBuffer *AggBuffer) {
1832   const DataLayout &DL = getDataLayout();
1833   int AllocSize = DL.getTypeAllocSize(CPV->getType());
1834   if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
1835     // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
1836     // only the space allocated by CPV.
1837     AggBuffer->addZeros(Bytes ? Bytes : AllocSize);
1838     return;
1839   }
1840 
1841   // Helper for filling AggBuffer with APInts.
1842   auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
1843     size_t NumBytes = (Val.getBitWidth() + 7) / 8;
1844     SmallVector<unsigned char, 16> Buf(NumBytes);
1845     for (unsigned I = 0; I < NumBytes; ++I) {
1846       Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);
1847     }
1848     AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);
1849   };
1850 
1851   switch (CPV->getType()->getTypeID()) {
1852   case Type::IntegerTyID:
1853     if (const auto CI = dyn_cast<ConstantInt>(CPV)) {
1854       AddIntToBuffer(CI->getValue());
1855       break;
1856     }
1857     if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1858       if (const auto *CI =
1859               dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {
1860         AddIntToBuffer(CI->getValue());
1861         break;
1862       }
1863       if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1864         Value *V = Cexpr->getOperand(0)->stripPointerCasts();
1865         AggBuffer->addSymbol(V, Cexpr->getOperand(0));
1866         AggBuffer->addZeros(AllocSize);
1867         break;
1868       }
1869     }
1870     llvm_unreachable("unsupported integer const type");
1871     break;
1872 
1873   case Type::HalfTyID:
1874   case Type::BFloatTyID:
1875   case Type::FloatTyID:
1876   case Type::DoubleTyID:
1877     AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
1878     break;
1879 
1880   case Type::PointerTyID: {
1881     if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1882       AggBuffer->addSymbol(GVar, GVar);
1883     } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1884       const Value *v = Cexpr->stripPointerCasts();
1885       AggBuffer->addSymbol(v, Cexpr);
1886     }
1887     AggBuffer->addZeros(AllocSize);
1888     break;
1889   }
1890 
1891   case Type::ArrayTyID:
1892   case Type::FixedVectorTyID:
1893   case Type::StructTyID: {
1894     if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
1895       bufferAggregateConstant(CPV, AggBuffer);
1896       if (Bytes > AllocSize)
1897         AggBuffer->addZeros(Bytes - AllocSize);
1898     } else if (isa<ConstantAggregateZero>(CPV))
1899       AggBuffer->addZeros(Bytes);
1900     else
1901       llvm_unreachable("Unexpected Constant type");
1902     break;
1903   }
1904 
1905   default:
1906     llvm_unreachable("unsupported type");
1907   }
1908 }
1909 
1910 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
1911                                               AggBuffer *aggBuffer) {
1912   const DataLayout &DL = getDataLayout();
1913   int Bytes;
1914 
1915   // Integers of arbitrary width
1916   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1917     APInt Val = CI->getValue();
1918     for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
1919       uint8_t Byte = Val.getLoBits(8).getZExtValue();
1920       aggBuffer->addBytes(&Byte, 1, 1);
1921       Val.lshrInPlace(8);
1922     }
1923     return;
1924   }
1925 
1926   // Old constants
1927   if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
1928     if (CPV->getNumOperands())
1929       for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
1930         bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
1931     return;
1932   }
1933 
1934   if (const ConstantDataSequential *CDS =
1935           dyn_cast<ConstantDataSequential>(CPV)) {
1936     if (CDS->getNumElements())
1937       for (unsigned i = 0; i < CDS->getNumElements(); ++i)
1938         bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
1939                      aggBuffer);
1940     return;
1941   }
1942 
1943   if (isa<ConstantStruct>(CPV)) {
1944     if (CPV->getNumOperands()) {
1945       StructType *ST = cast<StructType>(CPV->getType());
1946       for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
1947         if (i == (e - 1))
1948           Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
1949                   DL.getTypeAllocSize(ST) -
1950                   DL.getStructLayout(ST)->getElementOffset(i);
1951         else
1952           Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
1953                   DL.getStructLayout(ST)->getElementOffset(i);
1954         bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
1955       }
1956     }
1957     return;
1958   }
1959   llvm_unreachable("unsupported constant type in printAggregateConstant()");
1960 }
1961 
1962 /// lowerConstantForGV - Return an MCExpr for the given Constant.  This is mostly
1963 /// a copy from AsmPrinter::lowerConstant, except customized to only handle
1964 /// expressions that are representable in PTX and create
1965 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
1966 const MCExpr *
1967 NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
1968   MCContext &Ctx = OutContext;
1969 
1970   if (CV->isNullValue() || isa<UndefValue>(CV))
1971     return MCConstantExpr::create(0, Ctx);
1972 
1973   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
1974     return MCConstantExpr::create(CI->getZExtValue(), Ctx);
1975 
1976   if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
1977     const MCSymbolRefExpr *Expr =
1978       MCSymbolRefExpr::create(getSymbol(GV), Ctx);
1979     if (ProcessingGeneric) {
1980       return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
1981     } else {
1982       return Expr;
1983     }
1984   }
1985 
1986   const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
1987   if (!CE) {
1988     llvm_unreachable("Unknown constant value to lower!");
1989   }
1990 
1991   switch (CE->getOpcode()) {
1992   default: {
1993     // If the code isn't optimized, there may be outstanding folding
1994     // opportunities. Attempt to fold the expression using DataLayout as a
1995     // last resort before giving up.
1996     Constant *C = ConstantFoldConstant(CE, getDataLayout());
1997     if (C != CE)
1998       return lowerConstantForGV(C, ProcessingGeneric);
1999 
2000     // Otherwise report the problem to the user.
2001     std::string S;
2002     raw_string_ostream OS(S);
2003     OS << "Unsupported expression in static initializer: ";
2004     CE->printAsOperand(OS, /*PrintType=*/false,
2005                    !MF ? nullptr : MF->getFunction().getParent());
2006     report_fatal_error(Twine(OS.str()));
2007   }
2008 
2009   case Instruction::AddrSpaceCast: {
2010     // Strip the addrspacecast and pass along the operand
2011     PointerType *DstTy = cast<PointerType>(CE->getType());
2012     if (DstTy->getAddressSpace() == 0) {
2013       return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
2014     }
2015     std::string S;
2016     raw_string_ostream OS(S);
2017     OS << "Unsupported expression in static initializer: ";
2018     CE->printAsOperand(OS, /*PrintType=*/ false,
2019                        !MF ? nullptr : MF->getFunction().getParent());
2020     report_fatal_error(Twine(OS.str()));
2021   }
2022 
2023   case Instruction::GetElementPtr: {
2024     const DataLayout &DL = getDataLayout();
2025 
2026     // Generate a symbolic expression for the byte address
2027     APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
2028     cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);
2029 
2030     const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
2031                                             ProcessingGeneric);
2032     if (!OffsetAI)
2033       return Base;
2034 
2035     int64_t Offset = OffsetAI.getSExtValue();
2036     return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
2037                                    Ctx);
2038   }
2039 
2040   case Instruction::Trunc:
2041     // We emit the value and depend on the assembler to truncate the generated
2042     // expression properly.  This is important for differences between
2043     // blockaddress labels.  Since the two labels are in the same function, it
2044     // is reasonable to treat their delta as a 32-bit value.
2045     [[fallthrough]];
2046   case Instruction::BitCast:
2047     return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2048 
2049   case Instruction::IntToPtr: {
2050     const DataLayout &DL = getDataLayout();
2051 
2052     // Handle casts to pointers by changing them into casts to the appropriate
2053     // integer type.  This promotes constant folding and simplifies this code.
2054     Constant *Op = CE->getOperand(0);
2055     Op = ConstantExpr::getIntegerCast(Op, DL.getIntPtrType(CV->getType()),
2056                                       false/*ZExt*/);
2057     return lowerConstantForGV(Op, ProcessingGeneric);
2058   }
2059 
2060   case Instruction::PtrToInt: {
2061     const DataLayout &DL = getDataLayout();
2062 
2063     // Support only foldable casts to/from pointers that can be eliminated by
2064     // changing the pointer to the appropriately sized integer type.
2065     Constant *Op = CE->getOperand(0);
2066     Type *Ty = CE->getType();
2067 
2068     const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);
2069 
2070     // We can emit the pointer value into this slot if the slot is an
2071     // integer slot equal to the size of the pointer.
2072     if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
2073       return OpExpr;
2074 
2075     // Otherwise the pointer is smaller than the resultant integer, mask off
2076     // the high bits so we are sure to get a proper truncation if the input is
2077     // a constant expr.
2078     unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
2079     const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
2080     return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
2081   }
2082 
2083   // The MC library also has a right-shift operator, but it isn't consistently
2084   // signed or unsigned between different targets.
2085   case Instruction::Add: {
2086     const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2087     const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
2088     switch (CE->getOpcode()) {
2089     default: llvm_unreachable("Unknown binary operator constant cast expr");
2090     case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
2091     }
2092   }
2093   }
2094 }
2095 
2096 // Copy of MCExpr::print customized for NVPTX
2097 void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
2098   switch (Expr.getKind()) {
2099   case MCExpr::Target:
2100     return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
2101   case MCExpr::Constant:
2102     OS << cast<MCConstantExpr>(Expr).getValue();
2103     return;
2104 
2105   case MCExpr::SymbolRef: {
2106     const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
2107     const MCSymbol &Sym = SRE.getSymbol();
2108     Sym.print(OS, MAI);
2109     return;
2110   }
2111 
2112   case MCExpr::Unary: {
2113     const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
2114     switch (UE.getOpcode()) {
2115     case MCUnaryExpr::LNot:  OS << '!'; break;
2116     case MCUnaryExpr::Minus: OS << '-'; break;
2117     case MCUnaryExpr::Not:   OS << '~'; break;
2118     case MCUnaryExpr::Plus:  OS << '+'; break;
2119     }
2120     printMCExpr(*UE.getSubExpr(), OS);
2121     return;
2122   }
2123 
2124   case MCExpr::Binary: {
2125     const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);
2126 
2127     // Only print parens around the LHS if it is non-trivial.
2128     if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
2129         isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
2130       printMCExpr(*BE.getLHS(), OS);
2131     } else {
2132       OS << '(';
2133       printMCExpr(*BE.getLHS(), OS);
2134       OS<< ')';
2135     }
2136 
2137     switch (BE.getOpcode()) {
2138     case MCBinaryExpr::Add:
2139       // Print "X-42" instead of "X+-42".
2140       if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
2141         if (RHSC->getValue() < 0) {
2142           OS << RHSC->getValue();
2143           return;
2144         }
2145       }
2146 
2147       OS <<  '+';
2148       break;
2149     default: llvm_unreachable("Unhandled binary operator");
2150     }
2151 
2152     // Only print parens around the LHS if it is non-trivial.
2153     if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
2154       printMCExpr(*BE.getRHS(), OS);
2155     } else {
2156       OS << '(';
2157       printMCExpr(*BE.getRHS(), OS);
2158       OS << ')';
2159     }
2160     return;
2161   }
2162   }
2163 
2164   llvm_unreachable("Invalid expression kind!");
2165 }
2166 
2167 /// PrintAsmOperand - Print out an operand for an inline asm expression.
2168 ///
2169 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2170                                       const char *ExtraCode, raw_ostream &O) {
2171   if (ExtraCode && ExtraCode[0]) {
2172     if (ExtraCode[1] != 0)
2173       return true; // Unknown modifier.
2174 
2175     switch (ExtraCode[0]) {
2176     default:
2177       // See if this is a generic print operand
2178       return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
2179     case 'r':
2180       break;
2181     }
2182   }
2183 
2184   printOperand(MI, OpNo, O);
2185 
2186   return false;
2187 }
2188 
2189 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2190                                             unsigned OpNo,
2191                                             const char *ExtraCode,
2192                                             raw_ostream &O) {
2193   if (ExtraCode && ExtraCode[0])
2194     return true; // Unknown modifier
2195 
2196   O << '[';
2197   printMemOperand(MI, OpNo, O);
2198   O << ']';
2199 
2200   return false;
2201 }
2202 
2203 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
2204                                    raw_ostream &O) {
2205   const MachineOperand &MO = MI->getOperand(opNum);
2206   switch (MO.getType()) {
2207   case MachineOperand::MO_Register:
2208     if (MO.getReg().isPhysical()) {
2209       if (MO.getReg() == NVPTX::VRDepot)
2210         O << DEPOTNAME << getFunctionNumber();
2211       else
2212         O << NVPTXInstPrinter::getRegisterName(MO.getReg());
2213     } else {
2214       emitVirtualRegister(MO.getReg(), O);
2215     }
2216     break;
2217 
2218   case MachineOperand::MO_Immediate:
2219     O << MO.getImm();
2220     break;
2221 
2222   case MachineOperand::MO_FPImmediate:
2223     printFPConstant(MO.getFPImm(), O);
2224     break;
2225 
2226   case MachineOperand::MO_GlobalAddress:
2227     PrintSymbolOperand(MO, O);
2228     break;
2229 
2230   case MachineOperand::MO_MachineBasicBlock:
2231     MO.getMBB()->getSymbol()->print(O, MAI);
2232     break;
2233 
2234   default:
2235     llvm_unreachable("Operand type not supported.");
2236   }
2237 }
2238 
2239 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
2240                                       raw_ostream &O, const char *Modifier) {
2241   printOperand(MI, opNum, O);
2242 
2243   if (Modifier && strcmp(Modifier, "add") == 0) {
2244     O << ", ";
2245     printOperand(MI, opNum + 1, O);
2246   } else {
2247     if (MI->getOperand(opNum + 1).isImm() &&
2248         MI->getOperand(opNum + 1).getImm() == 0)
2249       return; // don't print ',0' or '+0'
2250     O << "+";
2251     printOperand(MI, opNum + 1, O);
2252   }
2253 }
2254 
2255 // Force static initialization.
2256 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
2257   RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2258   RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
2259 }
2260