1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "NVPTXAsmPrinter.h"
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "MCTargetDesc/NVPTXInstPrinter.h"
#include "MCTargetDesc/NVPTXMCAsmInfo.h"
#include "MCTargetDesc/NVPTXTargetStreamer.h"
#include "NVPTX.h"
#include "NVPTXMCExpr.h"
#include "NVPTXMachineFunctionInfo.h"
#include "NVPTXRegisterInfo.h"
#include "NVPTXSubtarget.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXUtilities.h"
#include "TargetInfo/NVPTXTargetInfo.h"
#include "cl_common_defines.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/MachineValueType.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
#include "llvm/MC/MCExpr.h"
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCInstrDesc.h"
#include "llvm/MC/MCStreamer.h"
#include "llvm/MC/MCSymbol.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/NativeFormatting.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetLoweringObjectFile.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"
#include <cassert>
#include <cstdint>
#include <cstring>
#include <new>
#include <string>
#include <utility>
#include <vector>
92 
93 using namespace llvm;
94 
95 static cl::opt<bool>
96     LowerCtorDtor("nvptx-lower-global-ctor-dtor",
97                   cl::desc("Lower GPU ctor / dtors to globals on the device."),
98                   cl::init(false), cl::Hidden);
99 
100 #define DEPOTNAME "__local_depot"
101 
102 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
103 /// depends.
104 static void
105 DiscoverDependentGlobals(const Value *V,
106                          DenseSet<const GlobalVariable *> &Globals) {
107   if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
108     Globals.insert(GV);
109   else {
110     if (const User *U = dyn_cast<User>(V)) {
111       for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
112         DiscoverDependentGlobals(U->getOperand(i), Globals);
113       }
114     }
115   }
116 }
117 
118 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
119 /// instances to be emitted, but only after any dependents have been added
120 /// first.s
121 static void
122 VisitGlobalVariableForEmission(const GlobalVariable *GV,
123                                SmallVectorImpl<const GlobalVariable *> &Order,
124                                DenseSet<const GlobalVariable *> &Visited,
125                                DenseSet<const GlobalVariable *> &Visiting) {
126   // Have we already visited this one?
127   if (Visited.count(GV))
128     return;
129 
130   // Do we have a circular dependency?
131   if (!Visiting.insert(GV).second)
132     report_fatal_error("Circular dependency found in global variable set");
133 
134   // Make sure we visit all dependents first
135   DenseSet<const GlobalVariable *> Others;
136   for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
137     DiscoverDependentGlobals(GV->getOperand(i), Others);
138 
139   for (const GlobalVariable *GV : Others)
140     VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
141 
142   // Now we can visit ourself
143   Order.push_back(GV);
144   Visited.insert(GV);
145   Visiting.erase(GV);
146 }
147 
148 void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
149   NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(),
150                                         getSubtargetInfo().getFeatureBits());
151 
152   MCInst Inst;
153   lowerToMCInst(MI, Inst);
154   EmitToStreamer(*OutStreamer, Inst);
155 }
156 
157 // Handle symbol backtracking for targets that do not support image handles
158 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
159                                            unsigned OpNo, MCOperand &MCOp) {
160   const MachineOperand &MO = MI->getOperand(OpNo);
161   const MCInstrDesc &MCID = MI->getDesc();
162 
163   if (MCID.TSFlags & NVPTXII::IsTexFlag) {
164     // This is a texture fetch, so operand 4 is a texref and operand 5 is
165     // a samplerref
166     if (OpNo == 4 && MO.isImm()) {
167       lowerImageHandleSymbol(MO.getImm(), MCOp);
168       return true;
169     }
170     if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
171       lowerImageHandleSymbol(MO.getImm(), MCOp);
172       return true;
173     }
174 
175     return false;
176   } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
177     unsigned VecSize =
178       1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
179 
180     // For a surface load of vector size N, the Nth operand will be the surfref
181     if (OpNo == VecSize && MO.isImm()) {
182       lowerImageHandleSymbol(MO.getImm(), MCOp);
183       return true;
184     }
185 
186     return false;
187   } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
188     // This is a surface store, so operand 0 is a surfref
189     if (OpNo == 0 && MO.isImm()) {
190       lowerImageHandleSymbol(MO.getImm(), MCOp);
191       return true;
192     }
193 
194     return false;
195   } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
196     // This is a query, so operand 1 is a surfref/texref
197     if (OpNo == 1 && MO.isImm()) {
198       lowerImageHandleSymbol(MO.getImm(), MCOp);
199       return true;
200     }
201 
202     return false;
203   }
204 
205   return false;
206 }
207 
208 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
209   // Ewwww
210   LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget());
211   NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM);
212   const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
213   const char *Sym = MFI->getImageHandleSymbol(Index);
214   StringRef SymName = nvTM.getStrPool().save(Sym);
215   MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
216 }
217 
218 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
219   OutMI.setOpcode(MI->getOpcode());
220   // Special: Do not mangle symbol operand of CALL_PROTOTYPE
221   if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
222     const MachineOperand &MO = MI->getOperand(0);
223     OutMI.addOperand(GetSymbolRef(
224       OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
225     return;
226   }
227 
228   const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
229   for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
230     const MachineOperand &MO = MI->getOperand(i);
231 
232     MCOperand MCOp;
233     if (!STI.hasImageHandles()) {
234       if (lowerImageHandleOperand(MI, i, MCOp)) {
235         OutMI.addOperand(MCOp);
236         continue;
237       }
238     }
239 
240     if (lowerOperand(MO, MCOp))
241       OutMI.addOperand(MCOp);
242   }
243 }
244 
245 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
246                                    MCOperand &MCOp) {
247   switch (MO.getType()) {
248   default: llvm_unreachable("unknown operand type");
249   case MachineOperand::MO_Register:
250     MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
251     break;
252   case MachineOperand::MO_Immediate:
253     MCOp = MCOperand::createImm(MO.getImm());
254     break;
255   case MachineOperand::MO_MachineBasicBlock:
256     MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
257         MO.getMBB()->getSymbol(), OutContext));
258     break;
259   case MachineOperand::MO_ExternalSymbol:
260     MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
261     break;
262   case MachineOperand::MO_GlobalAddress:
263     MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
264     break;
265   case MachineOperand::MO_FPImmediate: {
266     const ConstantFP *Cnt = MO.getFPImm();
267     const APFloat &Val = Cnt->getValueAPF();
268 
269     switch (Cnt->getType()->getTypeID()) {
270     default: report_fatal_error("Unsupported FP type"); break;
271     case Type::HalfTyID:
272       MCOp = MCOperand::createExpr(
273         NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
274       break;
275     case Type::BFloatTyID:
276       MCOp = MCOperand::createExpr(
277           NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
278       break;
279     case Type::FloatTyID:
280       MCOp = MCOperand::createExpr(
281         NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
282       break;
283     case Type::DoubleTyID:
284       MCOp = MCOperand::createExpr(
285         NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
286       break;
287     }
288     break;
289   }
290   }
291   return true;
292 }
293 
294 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
295   if (Register::isVirtualRegister(Reg)) {
296     const TargetRegisterClass *RC = MRI->getRegClass(Reg);
297 
298     DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
299     unsigned RegNum = RegMap[Reg];
300 
301     // Encode the register class in the upper 4 bits
302     // Must be kept in sync with NVPTXInstPrinter::printRegName
303     unsigned Ret = 0;
304     if (RC == &NVPTX::Int1RegsRegClass) {
305       Ret = (1 << 28);
306     } else if (RC == &NVPTX::Int16RegsRegClass) {
307       Ret = (2 << 28);
308     } else if (RC == &NVPTX::Int32RegsRegClass) {
309       Ret = (3 << 28);
310     } else if (RC == &NVPTX::Int64RegsRegClass) {
311       Ret = (4 << 28);
312     } else if (RC == &NVPTX::Float32RegsRegClass) {
313       Ret = (5 << 28);
314     } else if (RC == &NVPTX::Float64RegsRegClass) {
315       Ret = (6 << 28);
316     } else {
317       report_fatal_error("Bad register class");
318     }
319 
320     // Insert the vreg number
321     Ret |= (RegNum & 0x0FFFFFFF);
322     return Ret;
323   } else {
324     // Some special-use registers are actually physical registers.
325     // Encode this as the register class ID of 0 and the real register ID.
326     return Reg & 0x0FFFFFFF;
327   }
328 }
329 
330 MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
331   const MCExpr *Expr;
332   Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
333                                  OutContext);
334   return MCOperand::createExpr(Expr);
335 }
336 
337 static bool ShouldPassAsArray(Type *Ty) {
338   return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
339          Ty->isHalfTy() || Ty->isBFloatTy();
340 }
341 
342 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
343   const DataLayout &DL = getDataLayout();
344   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
345   const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
346 
347   Type *Ty = F->getReturnType();
348 
349   bool isABI = (STI.getSmVersion() >= 20);
350 
351   if (Ty->getTypeID() == Type::VoidTyID)
352     return;
353   O << " (";
354 
355   if (isABI) {
356     if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
357         !ShouldPassAsArray(Ty)) {
358       unsigned size = 0;
359       if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
360         size = ITy->getBitWidth();
361       } else {
362         assert(Ty->isFloatingPointTy() && "Floating point type expected here");
363         size = Ty->getPrimitiveSizeInBits();
364       }
365       size = promoteScalarArgumentSize(size);
366       O << ".param .b" << size << " func_retval0";
367     } else if (isa<PointerType>(Ty)) {
368       O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
369         << " func_retval0";
370     } else if (ShouldPassAsArray(Ty)) {
371       unsigned totalsz = DL.getTypeAllocSize(Ty);
372       unsigned retAlignment = 0;
373       if (!getAlign(*F, 0, retAlignment))
374         retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value();
375       O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
376         << "]";
377     } else
378       llvm_unreachable("Unknown return type");
379   } else {
380     SmallVector<EVT, 16> vtparts;
381     ComputeValueVTs(*TLI, DL, Ty, vtparts);
382     unsigned idx = 0;
383     for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
384       unsigned elems = 1;
385       EVT elemtype = vtparts[i];
386       if (vtparts[i].isVector()) {
387         elems = vtparts[i].getVectorNumElements();
388         elemtype = vtparts[i].getVectorElementType();
389       }
390 
391       for (unsigned j = 0, je = elems; j != je; ++j) {
392         unsigned sz = elemtype.getSizeInBits();
393         if (elemtype.isInteger())
394           sz = promoteScalarArgumentSize(sz);
395         O << ".reg .b" << sz << " func_retval" << idx;
396         if (j < je - 1)
397           O << ", ";
398         ++idx;
399       }
400       if (i < e - 1)
401         O << ", ";
402     }
403   }
404   O << ") ";
405 }
406 
407 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
408                                         raw_ostream &O) {
409   const Function &F = MF.getFunction();
410   printReturnValStr(&F, O);
411 }
412 
413 // Return true if MBB is the header of a loop marked with
414 // llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
415 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
416     const MachineBasicBlock &MBB) const {
417   MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>();
418   // We insert .pragma "nounroll" only to the loop header.
419   if (!LI.isLoopHeader(&MBB))
420     return false;
421 
422   // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
423   // we iterate through each back edge of the loop with header MBB, and check
424   // whether its metadata contains llvm.loop.unroll.disable.
425   for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
426     if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
427       // Edges from other loops to MBB are not back edges.
428       continue;
429     }
430     if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
431       if (MDNode *LoopID =
432               PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
433         if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
434           return true;
435         if (MDNode *UnrollCountMD =
436                 GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
437           if (mdconst::extract<ConstantInt>(UnrollCountMD->getOperand(1))
438                   ->isOne())
439             return true;
440         }
441       }
442     }
443   }
444   return false;
445 }
446 
447 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
448   AsmPrinter::emitBasicBlockStart(MBB);
449   if (isLoopHeaderOfNoUnroll(MBB))
450     OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
451 }
452 
453 void NVPTXAsmPrinter::emitFunctionEntryLabel() {
454   SmallString<128> Str;
455   raw_svector_ostream O(Str);
456 
457   if (!GlobalsEmitted) {
458     emitGlobals(*MF->getFunction().getParent());
459     GlobalsEmitted = true;
460   }
461 
462   // Set up
463   MRI = &MF->getRegInfo();
464   F = &MF->getFunction();
465   emitLinkageDirective(F, O);
466   if (isKernelFunction(*F))
467     O << ".entry ";
468   else {
469     O << ".func ";
470     printReturnValStr(*MF, O);
471   }
472 
473   CurrentFnSym->print(O, MAI);
474 
475   emitFunctionParamList(F, O);
476   O << "\n";
477 
478   if (isKernelFunction(*F))
479     emitKernelFunctionDirectives(*F, O);
480 
481   if (shouldEmitPTXNoReturn(F, TM))
482     O << ".noreturn";
483 
484   OutStreamer->emitRawText(O.str());
485 
486   VRegMapping.clear();
487   // Emit open brace for function body.
488   OutStreamer->emitRawText(StringRef("{\n"));
489   setAndEmitFunctionVirtualRegisters(*MF);
490   // Emit initial .loc debug directive for correct relocation symbol data.
491   if (MMI && MMI->hasDebugInfo())
492     emitInitialRawDwarfLocDirective(*MF);
493 }
494 
495 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
496   bool Result = AsmPrinter::runOnMachineFunction(F);
497   // Emit closing brace for the body of function F.
498   // The closing brace must be emitted here because we need to emit additional
499   // debug labels/data after the last basic block.
500   // We need to emit the closing brace here because we don't have function that
501   // finished emission of the function body.
502   OutStreamer->emitRawText(StringRef("}\n"));
503   return Result;
504 }
505 
506 void NVPTXAsmPrinter::emitFunctionBodyStart() {
507   SmallString<128> Str;
508   raw_svector_ostream O(Str);
509   emitDemotedVars(&MF->getFunction(), O);
510   OutStreamer->emitRawText(O.str());
511 }
512 
513 void NVPTXAsmPrinter::emitFunctionBodyEnd() {
514   VRegMapping.clear();
515 }
516 
517 const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
518     SmallString<128> Str;
519     raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
520     return OutContext.getOrCreateSymbol(Str);
521 }
522 
523 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
524   Register RegNo = MI->getOperand(0).getReg();
525   if (RegNo.isVirtual()) {
526     OutStreamer->AddComment(Twine("implicit-def: ") +
527                             getVirtualRegisterName(RegNo));
528   } else {
529     const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
530     OutStreamer->AddComment(Twine("implicit-def: ") +
531                             STI.getRegisterInfo()->getName(RegNo));
532   }
533   OutStreamer->addBlankLine();
534 }
535 
536 void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
537                                                    raw_ostream &O) const {
538   // If the NVVM IR has some of reqntid* specified, then output
539   // the reqntid directive, and set the unspecified ones to 1.
540   // If none of Reqntid* is specified, don't output reqntid directive.
541   unsigned Reqntidx, Reqntidy, Reqntidz;
542   Reqntidx = Reqntidy = Reqntidz = 1;
543   bool ReqSpecified = false;
544   ReqSpecified |= getReqNTIDx(F, Reqntidx);
545   ReqSpecified |= getReqNTIDy(F, Reqntidy);
546   ReqSpecified |= getReqNTIDz(F, Reqntidz);
547 
548   if (ReqSpecified)
549     O << ".reqntid " << Reqntidx << ", " << Reqntidy << ", " << Reqntidz
550       << "\n";
551 
552   // If the NVVM IR has some of maxntid* specified, then output
553   // the maxntid directive, and set the unspecified ones to 1.
554   // If none of maxntid* is specified, don't output maxntid directive.
555   unsigned Maxntidx, Maxntidy, Maxntidz;
556   Maxntidx = Maxntidy = Maxntidz = 1;
557   bool MaxSpecified = false;
558   MaxSpecified |= getMaxNTIDx(F, Maxntidx);
559   MaxSpecified |= getMaxNTIDy(F, Maxntidy);
560   MaxSpecified |= getMaxNTIDz(F, Maxntidz);
561 
562   if (MaxSpecified)
563     O << ".maxntid " << Maxntidx << ", " << Maxntidy << ", " << Maxntidz
564       << "\n";
565 
566   unsigned Mincta = 0;
567   if (getMinCTASm(F, Mincta))
568     O << ".minnctapersm " << Mincta << "\n";
569 
570   unsigned Maxnreg = 0;
571   if (getMaxNReg(F, Maxnreg))
572     O << ".maxnreg " << Maxnreg << "\n";
573 
574   // .maxclusterrank directive requires SM_90 or higher, make sure that we
575   // filter it out for lower SM versions, as it causes a hard ptxas crash.
576   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
577   const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
578   unsigned Maxclusterrank = 0;
579   if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)
580     O << ".maxclusterrank " << Maxclusterrank << "\n";
581 }
582 
583 std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
584   const TargetRegisterClass *RC = MRI->getRegClass(Reg);
585 
586   std::string Name;
587   raw_string_ostream NameStr(Name);
588 
589   VRegRCMap::const_iterator I = VRegMapping.find(RC);
590   assert(I != VRegMapping.end() && "Bad register class");
591   const DenseMap<unsigned, unsigned> &RegMap = I->second;
592 
593   VRegMap::const_iterator VI = RegMap.find(Reg);
594   assert(VI != RegMap.end() && "Bad virtual register");
595   unsigned MappedVR = VI->second;
596 
597   NameStr << getNVPTXRegClassStr(RC) << MappedVR;
598 
599   NameStr.flush();
600   return Name;
601 }
602 
603 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
604                                           raw_ostream &O) {
605   O << getVirtualRegisterName(vr);
606 }
607 
608 void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
609   emitLinkageDirective(F, O);
610   if (isKernelFunction(*F))
611     O << ".entry ";
612   else
613     O << ".func ";
614   printReturnValStr(F, O);
615   getSymbol(F)->print(O, MAI);
616   O << "\n";
617   emitFunctionParamList(F, O);
618   O << "\n";
619   if (shouldEmitPTXNoReturn(F, TM))
620     O << ".noreturn";
621   O << ";\n";
622 }
623 
624 static bool usedInGlobalVarDef(const Constant *C) {
625   if (!C)
626     return false;
627 
628   if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
629     return GV->getName() != "llvm.used";
630   }
631 
632   for (const User *U : C->users())
633     if (const Constant *C = dyn_cast<Constant>(U))
634       if (usedInGlobalVarDef(C))
635         return true;
636 
637   return false;
638 }
639 
640 static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
641   if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
642     if (othergv->getName() == "llvm.used")
643       return true;
644   }
645 
646   if (const Instruction *instr = dyn_cast<Instruction>(U)) {
647     if (instr->getParent() && instr->getParent()->getParent()) {
648       const Function *curFunc = instr->getParent()->getParent();
649       if (oneFunc && (curFunc != oneFunc))
650         return false;
651       oneFunc = curFunc;
652       return true;
653     } else
654       return false;
655   }
656 
657   for (const User *UU : U->users())
658     if (!usedInOneFunc(UU, oneFunc))
659       return false;
660 
661   return true;
662 }
663 
664 /* Find out if a global variable can be demoted to local scope.
665  * Currently, this is valid for CUDA shared variables, which have local
666  * scope and global lifetime. So the conditions to check are :
667  * 1. Is the global variable in shared address space?
668  * 2. Does it have local linkage?
669  * 3. Is the global variable referenced only in one function?
670  */
671 static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
672   if (!gv->hasLocalLinkage())
673     return false;
674   PointerType *Pty = gv->getType();
675   if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
676     return false;
677 
678   const Function *oneFunc = nullptr;
679 
680   bool flag = usedInOneFunc(gv, oneFunc);
681   if (!flag)
682     return false;
683   if (!oneFunc)
684     return false;
685   f = oneFunc;
686   return true;
687 }
688 
689 static bool useFuncSeen(const Constant *C,
690                         DenseMap<const Function *, bool> &seenMap) {
691   for (const User *U : C->users()) {
692     if (const Constant *cu = dyn_cast<Constant>(U)) {
693       if (useFuncSeen(cu, seenMap))
694         return true;
695     } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
696       const BasicBlock *bb = I->getParent();
697       if (!bb)
698         continue;
699       const Function *caller = bb->getParent();
700       if (!caller)
701         continue;
702       if (seenMap.contains(caller))
703         return true;
704     }
705   }
706   return false;
707 }
708 
709 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
710   DenseMap<const Function *, bool> seenMap;
711   for (const Function &F : M) {
712     if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
713       emitDeclaration(&F, O);
714       continue;
715     }
716 
717     if (F.isDeclaration()) {
718       if (F.use_empty())
719         continue;
720       if (F.getIntrinsicID())
721         continue;
722       emitDeclaration(&F, O);
723       continue;
724     }
725     for (const User *U : F.users()) {
726       if (const Constant *C = dyn_cast<Constant>(U)) {
727         if (usedInGlobalVarDef(C)) {
728           // The use is in the initialization of a global variable
729           // that is a function pointer, so print a declaration
730           // for the original function
731           emitDeclaration(&F, O);
732           break;
733         }
734         // Emit a declaration of this function if the function that
735         // uses this constant expr has already been seen.
736         if (useFuncSeen(C, seenMap)) {
737           emitDeclaration(&F, O);
738           break;
739         }
740       }
741 
742       if (!isa<Instruction>(U))
743         continue;
744       const Instruction *instr = cast<Instruction>(U);
745       const BasicBlock *bb = instr->getParent();
746       if (!bb)
747         continue;
748       const Function *caller = bb->getParent();
749       if (!caller)
750         continue;
751 
752       // If a caller has already been seen, then the caller is
753       // appearing in the module before the callee. so print out
754       // a declaration for the callee.
755       if (seenMap.contains(caller)) {
756         emitDeclaration(&F, O);
757         break;
758       }
759     }
760     seenMap[&F] = true;
761   }
762 }
763 
764 static bool isEmptyXXStructor(GlobalVariable *GV) {
765   if (!GV) return true;
766   const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
767   if (!InitList) return true;  // Not an array; we don't know how to parse.
768   return InitList->getNumOperands() == 0;
769 }
770 
771 void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
772   // Construct a default subtarget off of the TargetMachine defaults. The
773   // rest of NVPTX isn't friendly to change subtargets per function and
774   // so the default TargetMachine will have all of the options.
775   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
776   const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
777   SmallString<128> Str1;
778   raw_svector_ostream OS1(Str1);
779 
780   // Emit header before any dwarf directives are emitted below.
781   emitHeader(M, OS1, *STI);
782   OutStreamer->emitRawText(OS1.str());
783 }
784 
785 bool NVPTXAsmPrinter::doInitialization(Module &M) {
786   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
787   const NVPTXSubtarget &STI =
788       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
789   if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))
790     report_fatal_error(".alias requires PTX version >= 6.3 and sm_30");
791 
792   // OpenMP supports NVPTX global constructors and destructors.
793   bool IsOpenMP = M.getModuleFlag("openmp") != nullptr;
794 
795   if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors")) &&
796       !LowerCtorDtor && !IsOpenMP) {
797     report_fatal_error(
798         "Module has a nontrivial global ctor, which NVPTX does not support.");
799     return true;  // error
800   }
801   if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors")) &&
802       !LowerCtorDtor && !IsOpenMP) {
803     report_fatal_error(
804         "Module has a nontrivial global dtor, which NVPTX does not support.");
805     return true;  // error
806   }
807 
808   // We need to call the parent's one explicitly.
809   bool Result = AsmPrinter::doInitialization(M);
810 
811   GlobalsEmitted = false;
812 
813   return Result;
814 }
815 
816 void NVPTXAsmPrinter::emitGlobals(const Module &M) {
817   SmallString<128> Str2;
818   raw_svector_ostream OS2(Str2);
819 
820   emitDeclarations(M, OS2);
821 
822   // As ptxas does not support forward references of globals, we need to first
823   // sort the list of module-level globals in def-use order. We visit each
824   // global variable in order, and ensure that we emit it *after* its dependent
825   // globals. We use a little extra memory maintaining both a set and a list to
826   // have fast searches while maintaining a strict ordering.
827   SmallVector<const GlobalVariable *, 8> Globals;
828   DenseSet<const GlobalVariable *> GVVisited;
829   DenseSet<const GlobalVariable *> GVVisiting;
830 
831   // Visit each global variable, in order
832   for (const GlobalVariable &I : M.globals())
833     VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);
834 
835   assert(GVVisited.size() == M.global_size() && "Missed a global variable");
836   assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
837 
838   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
839   const NVPTXSubtarget &STI =
840       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
841 
842   // Print out module-level global variables in proper order
843   for (unsigned i = 0, e = Globals.size(); i != e; ++i)
844     printModuleLevelGV(Globals[i], OS2, /*processDemoted=*/false, STI);
845 
846   OS2 << '\n';
847 
848   OutStreamer->emitRawText(OS2.str());
849 }
850 
851 void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
852   SmallString<128> Str;
853   raw_svector_ostream OS(Str);
854 
855   MCSymbol *Name = getSymbol(&GA);
856   const Function *F = dyn_cast<Function>(GA.getAliasee());
857   if (!F || isKernelFunction(*F))
858     report_fatal_error("NVPTX aliasee must be a non-kernel function");
859 
860   if (GA.hasLinkOnceLinkage() || GA.hasWeakLinkage() ||
861       GA.hasAvailableExternallyLinkage() || GA.hasCommonLinkage())
862     report_fatal_error("NVPTX aliasee must not be '.weak'");
863 
864   OS << "\n";
865   emitLinkageDirective(F, OS);
866   OS << ".func ";
867   printReturnValStr(F, OS);
868   OS << Name->getName();
869   emitFunctionParamList(F, OS);
870   if (shouldEmitPTXNoReturn(F, TM))
871     OS << "\n.noreturn";
872   OS << ";\n";
873 
874   OS << ".alias " << Name->getName() << ", " << F->getName() << ";\n";
875 
876   OutStreamer->emitRawText(OS.str());
877 }
878 
879 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
880                                  const NVPTXSubtarget &STI) {
881   O << "//\n";
882   O << "// Generated by LLVM NVPTX Back-End\n";
883   O << "//\n";
884   O << "\n";
885 
886   unsigned PTXVersion = STI.getPTXVersion();
887   O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
888 
889   O << ".target ";
890   O << STI.getTargetName();
891 
892   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
893   if (NTM.getDrvInterface() == NVPTX::NVCL)
894     O << ", texmode_independent";
895 
896   bool HasFullDebugInfo = false;
897   for (DICompileUnit *CU : M.debug_compile_units()) {
898     switch(CU->getEmissionKind()) {
899     case DICompileUnit::NoDebug:
900     case DICompileUnit::DebugDirectivesOnly:
901       break;
902     case DICompileUnit::LineTablesOnly:
903     case DICompileUnit::FullDebug:
904       HasFullDebugInfo = true;
905       break;
906     }
907     if (HasFullDebugInfo)
908       break;
909   }
910   if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)
911     O << ", debug";
912 
913   O << "\n";
914 
915   O << ".address_size ";
916   if (NTM.is64Bit())
917     O << "64";
918   else
919     O << "32";
920   O << "\n";
921 
922   O << "\n";
923 }
924 
925 bool NVPTXAsmPrinter::doFinalization(Module &M) {
926   bool HasDebugInfo = MMI && MMI->hasDebugInfo();
927 
928   // If we did not emit any functions, then the global declarations have not
929   // yet been emitted.
930   if (!GlobalsEmitted) {
931     emitGlobals(M);
932     GlobalsEmitted = true;
933   }
934 
935   // If we have any aliases we emit them at the end.
936   SmallVector<GlobalAlias *> AliasesToRemove;
937   for (GlobalAlias &Alias : M.aliases()) {
938     emitGlobalAlias(M, Alias);
939     AliasesToRemove.push_back(&Alias);
940   }
941 
942   for (GlobalAlias *A : AliasesToRemove)
943     A->eraseFromParent();
944 
945   // call doFinalization
946   bool ret = AsmPrinter::doFinalization(M);
947 
948   clearAnnotationCache(&M);
949 
950   auto *TS =
951       static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
952   // Close the last emitted section
953   if (HasDebugInfo) {
954     TS->closeLastSection();
955     // Emit empty .debug_loc section for better support of the empty files.
956     OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}");
957   }
958 
959   // Output last DWARF .file directives, if any.
960   TS->outputDwarfFileDirectives();
961 
962   return ret;
963 }
964 
965 // This function emits appropriate linkage directives for
966 // functions and global variables.
967 //
968 // extern function declaration            -> .extern
969 // extern function definition             -> .visible
970 // external global variable with init     -> .visible
971 // external without init                  -> .extern
972 // appending                              -> not allowed, assert.
973 // for any linkage other than
974 // internal, private, linker_private,
975 // linker_private_weak, linker_private_weak_def_auto,
976 // we emit                                -> .weak.
977 
978 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
979                                            raw_ostream &O) {
980   if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
981     if (V->hasExternalLinkage()) {
982       if (isa<GlobalVariable>(V)) {
983         const GlobalVariable *GVar = cast<GlobalVariable>(V);
984         if (GVar) {
985           if (GVar->hasInitializer())
986             O << ".visible ";
987           else
988             O << ".extern ";
989         }
990       } else if (V->isDeclaration())
991         O << ".extern ";
992       else
993         O << ".visible ";
994     } else if (V->hasAppendingLinkage()) {
995       std::string msg;
996       msg.append("Error: ");
997       msg.append("Symbol ");
998       if (V->hasName())
999         msg.append(std::string(V->getName()));
1000       msg.append("has unsupported appending linkage type");
1001       llvm_unreachable(msg.c_str());
1002     } else if (!V->hasInternalLinkage() &&
1003                !V->hasPrivateLinkage()) {
1004       O << ".weak ";
1005     }
1006   }
1007 }
1008 
1009 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
1010                                          raw_ostream &O, bool processDemoted,
1011                                          const NVPTXSubtarget &STI) {
1012   // Skip meta data
1013   if (GVar->hasSection()) {
1014     if (GVar->getSection() == "llvm.metadata")
1015       return;
1016   }
1017 
1018   // Skip LLVM intrinsic global variables
1019   if (GVar->getName().starts_with("llvm.") ||
1020       GVar->getName().starts_with("nvvm."))
1021     return;
1022 
1023   const DataLayout &DL = getDataLayout();
1024 
1025   // GlobalVariables are always constant pointers themselves.
1026   PointerType *PTy = GVar->getType();
1027   Type *ETy = GVar->getValueType();
1028 
1029   if (GVar->hasExternalLinkage()) {
1030     if (GVar->hasInitializer())
1031       O << ".visible ";
1032     else
1033       O << ".extern ";
1034   } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
1035              GVar->hasAvailableExternallyLinkage() ||
1036              GVar->hasCommonLinkage()) {
1037     O << ".weak ";
1038   }
1039 
1040   if (isTexture(*GVar)) {
1041     O << ".global .texref " << getTextureName(*GVar) << ";\n";
1042     return;
1043   }
1044 
1045   if (isSurface(*GVar)) {
1046     O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
1047     return;
1048   }
1049 
1050   if (GVar->isDeclaration()) {
1051     // (extern) declarations, no definition or initializer
1052     // Currently the only known declaration is for an automatic __local
1053     // (.shared) promoted to global.
1054     emitPTXGlobalVariable(GVar, O, STI);
1055     O << ";\n";
1056     return;
1057   }
1058 
1059   if (isSampler(*GVar)) {
1060     O << ".global .samplerref " << getSamplerName(*GVar);
1061 
1062     const Constant *Initializer = nullptr;
1063     if (GVar->hasInitializer())
1064       Initializer = GVar->getInitializer();
1065     const ConstantInt *CI = nullptr;
1066     if (Initializer)
1067       CI = dyn_cast<ConstantInt>(Initializer);
1068     if (CI) {
1069       unsigned sample = CI->getZExtValue();
1070 
1071       O << " = { ";
1072 
1073       for (int i = 0,
1074                addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
1075            i < 3; i++) {
1076         O << "addr_mode_" << i << " = ";
1077         switch (addr) {
1078         case 0:
1079           O << "wrap";
1080           break;
1081         case 1:
1082           O << "clamp_to_border";
1083           break;
1084         case 2:
1085           O << "clamp_to_edge";
1086           break;
1087         case 3:
1088           O << "wrap";
1089           break;
1090         case 4:
1091           O << "mirror";
1092           break;
1093         }
1094         O << ", ";
1095       }
1096       O << "filter_mode = ";
1097       switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
1098       case 0:
1099         O << "nearest";
1100         break;
1101       case 1:
1102         O << "linear";
1103         break;
1104       case 2:
1105         llvm_unreachable("Anisotropic filtering is not supported");
1106       default:
1107         O << "nearest";
1108         break;
1109       }
1110       if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
1111         O << ", force_unnormalized_coords = 1";
1112       }
1113       O << " }";
1114     }
1115 
1116     O << ";\n";
1117     return;
1118   }
1119 
1120   if (GVar->hasPrivateLinkage()) {
1121     if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
1122       return;
1123 
1124     // FIXME - need better way (e.g. Metadata) to avoid generating this global
1125     if (strncmp(GVar->getName().data(), "filename", 8) == 0)
1126       return;
1127     if (GVar->use_empty())
1128       return;
1129   }
1130 
1131   const Function *demotedFunc = nullptr;
1132   if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
1133     O << "// " << GVar->getName() << " has been demoted\n";
1134     if (localDecls.find(demotedFunc) != localDecls.end())
1135       localDecls[demotedFunc].push_back(GVar);
1136     else {
1137       std::vector<const GlobalVariable *> temp;
1138       temp.push_back(GVar);
1139       localDecls[demotedFunc] = temp;
1140     }
1141     return;
1142   }
1143 
1144   O << ".";
1145   emitPTXAddressSpace(PTy->getAddressSpace(), O);
1146 
1147   if (isManaged(*GVar)) {
1148     if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1149       report_fatal_error(
1150           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1151     }
1152     O << " .attribute(.managed)";
1153   }
1154 
1155   if (MaybeAlign A = GVar->getAlign())
1156     O << " .align " << A->value();
1157   else
1158     O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
1159 
1160   if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
1161       (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
1162     O << " .";
1163     // Special case: ABI requires that we use .u8 for predicates
1164     if (ETy->isIntegerTy(1))
1165       O << "u8";
1166     else
1167       O << getPTXFundamentalTypeStr(ETy, false);
1168     O << " ";
1169     getSymbol(GVar)->print(O, MAI);
1170 
1171     // Ptx allows variable initilization only for constant and global state
1172     // spaces.
1173     if (GVar->hasInitializer()) {
1174       if ((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1175           (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) {
1176         const Constant *Initializer = GVar->getInitializer();
1177         // 'undef' is treated as there is no value specified.
1178         if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
1179           O << " = ";
1180           printScalarConstant(Initializer, O);
1181         }
1182       } else {
1183         // The frontend adds zero-initializer to device and constant variables
1184         // that don't have an initial value, and UndefValue to shared
1185         // variables, so skip warning for this case.
1186         if (!GVar->getInitializer()->isNullValue() &&
1187             !isa<UndefValue>(GVar->getInitializer())) {
1188           report_fatal_error("initial value of '" + GVar->getName() +
1189                              "' is not allowed in addrspace(" +
1190                              Twine(PTy->getAddressSpace()) + ")");
1191         }
1192       }
1193     }
1194   } else {
1195     uint64_t ElementSize = 0;
1196 
1197     // Although PTX has direct support for struct type and array type and
1198     // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
1199     // targets that support these high level field accesses. Structs, arrays
1200     // and vectors are lowered into arrays of bytes.
1201     switch (ETy->getTypeID()) {
1202     case Type::IntegerTyID: // Integers larger than 64 bits
1203     case Type::StructTyID:
1204     case Type::ArrayTyID:
1205     case Type::FixedVectorTyID:
1206       ElementSize = DL.getTypeStoreSize(ETy);
1207       // Ptx allows variable initilization only for constant and
1208       // global state spaces.
1209       if (((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1210            (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
1211           GVar->hasInitializer()) {
1212         const Constant *Initializer = GVar->getInitializer();
1213         if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
1214           AggBuffer aggBuffer(ElementSize, *this);
1215           bufferAggregateConstant(Initializer, &aggBuffer);
1216           if (aggBuffer.numSymbols()) {
1217             unsigned int ptrSize = MAI->getCodePointerSize();
1218             if (ElementSize % ptrSize ||
1219                 !aggBuffer.allSymbolsAligned(ptrSize)) {
1220               // Print in bytes and use the mask() operator for pointers.
1221               if (!STI.hasMaskOperator())
1222                 report_fatal_error(
1223                     "initialized packed aggregate with pointers '" +
1224                     GVar->getName() +
1225                     "' requires at least PTX ISA version 7.1");
1226               O << " .u8 ";
1227               getSymbol(GVar)->print(O, MAI);
1228               O << "[" << ElementSize << "] = {";
1229               aggBuffer.printBytes(O);
1230               O << "}";
1231             } else {
1232               O << " .u" << ptrSize * 8 << " ";
1233               getSymbol(GVar)->print(O, MAI);
1234               O << "[" << ElementSize / ptrSize << "] = {";
1235               aggBuffer.printWords(O);
1236               O << "}";
1237             }
1238           } else {
1239             O << " .b8 ";
1240             getSymbol(GVar)->print(O, MAI);
1241             O << "[" << ElementSize << "] = {";
1242             aggBuffer.printBytes(O);
1243             O << "}";
1244           }
1245         } else {
1246           O << " .b8 ";
1247           getSymbol(GVar)->print(O, MAI);
1248           if (ElementSize) {
1249             O << "[";
1250             O << ElementSize;
1251             O << "]";
1252           }
1253         }
1254       } else {
1255         O << " .b8 ";
1256         getSymbol(GVar)->print(O, MAI);
1257         if (ElementSize) {
1258           O << "[";
1259           O << ElementSize;
1260           O << "]";
1261         }
1262       }
1263       break;
1264     default:
1265       llvm_unreachable("type not supported yet");
1266     }
1267   }
1268   O << ";\n";
1269 }
1270 
1271 void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1272   const Value *v = Symbols[nSym];
1273   const Value *v0 = SymbolsBeforeStripping[nSym];
1274   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1275     MCSymbol *Name = AP.getSymbol(GVar);
1276     PointerType *PTy = dyn_cast<PointerType>(v0->getType());
1277     // Is v0 a generic pointer?
1278     bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1279     if (EmitGeneric && isGenericPointer && !isa<Function>(v)) {
1280       os << "generic(";
1281       Name->print(os, AP.MAI);
1282       os << ")";
1283     } else {
1284       Name->print(os, AP.MAI);
1285     }
1286   } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
1287     const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
1288     AP.printMCExpr(*Expr, os);
1289   } else
1290     llvm_unreachable("symbol type unknown");
1291 }
1292 
1293 void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
1294   unsigned int ptrSize = AP.MAI->getCodePointerSize();
1295   symbolPosInBuffer.push_back(size);
1296   unsigned int nSym = 0;
1297   unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1298   for (unsigned int pos = 0; pos < size;) {
1299     if (pos)
1300       os << ", ";
1301     if (pos != nextSymbolPos) {
1302       os << (unsigned int)buffer[pos];
1303       ++pos;
1304       continue;
1305     }
1306     // Generate a per-byte mask() operator for the symbol, which looks like:
1307     //   .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
1308     // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
1309     std::string symText;
1310     llvm::raw_string_ostream oss(symText);
1311     printSymbol(nSym, oss);
1312     for (unsigned i = 0; i < ptrSize; ++i) {
1313       if (i)
1314         os << ", ";
1315       llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper);
1316       os << "(" << symText << ")";
1317     }
1318     pos += ptrSize;
1319     nextSymbolPos = symbolPosInBuffer[++nSym];
1320     assert(nextSymbolPos >= pos);
1321   }
1322 }
1323 
1324 void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
1325   unsigned int ptrSize = AP.MAI->getCodePointerSize();
1326   symbolPosInBuffer.push_back(size);
1327   unsigned int nSym = 0;
1328   unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1329   assert(nextSymbolPos % ptrSize == 0);
1330   for (unsigned int pos = 0; pos < size; pos += ptrSize) {
1331     if (pos)
1332       os << ", ";
1333     if (pos == nextSymbolPos) {
1334       printSymbol(nSym, os);
1335       nextSymbolPos = symbolPosInBuffer[++nSym];
1336       assert(nextSymbolPos % ptrSize == 0);
1337       assert(nextSymbolPos >= pos + ptrSize);
1338     } else if (ptrSize == 4)
1339       os << support::endian::read32le(&buffer[pos]);
1340     else
1341       os << support::endian::read64le(&buffer[pos]);
1342   }
1343 }
1344 
1345 void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
1346   if (localDecls.find(f) == localDecls.end())
1347     return;
1348 
1349   std::vector<const GlobalVariable *> &gvars = localDecls[f];
1350 
1351   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1352   const NVPTXSubtarget &STI =
1353       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
1354 
1355   for (const GlobalVariable *GV : gvars) {
1356     O << "\t// demoted variable\n\t";
1357     printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
1358   }
1359 }
1360 
1361 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1362                                           raw_ostream &O) const {
1363   switch (AddressSpace) {
1364   case ADDRESS_SPACE_LOCAL:
1365     O << "local";
1366     break;
1367   case ADDRESS_SPACE_GLOBAL:
1368     O << "global";
1369     break;
1370   case ADDRESS_SPACE_CONST:
1371     O << "const";
1372     break;
1373   case ADDRESS_SPACE_SHARED:
1374     O << "shared";
1375     break;
1376   default:
1377     report_fatal_error("Bad address space found while emitting PTX: " +
1378                        llvm::Twine(AddressSpace));
1379     break;
1380   }
1381 }
1382 
1383 std::string
1384 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1385   switch (Ty->getTypeID()) {
1386   case Type::IntegerTyID: {
1387     unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
1388     if (NumBits == 1)
1389       return "pred";
1390     else if (NumBits <= 64) {
1391       std::string name = "u";
1392       return name + utostr(NumBits);
1393     } else {
1394       llvm_unreachable("Integer too large");
1395       break;
1396     }
1397     break;
1398   }
1399   case Type::BFloatTyID:
1400   case Type::HalfTyID:
1401     // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
1402     // PTX assembly.
1403     return "b16";
1404   case Type::FloatTyID:
1405     return "f32";
1406   case Type::DoubleTyID:
1407     return "f64";
1408   case Type::PointerTyID: {
1409     unsigned PtrSize = TM.getPointerSizeInBits(Ty->getPointerAddressSpace());
1410     assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");
1411 
1412     if (PtrSize == 64)
1413       if (useB4PTR)
1414         return "b64";
1415       else
1416         return "u64";
1417     else if (useB4PTR)
1418       return "b32";
1419     else
1420       return "u32";
1421   }
1422   default:
1423     break;
1424   }
1425   llvm_unreachable("unexpected type");
1426 }
1427 
1428 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1429                                             raw_ostream &O,
1430                                             const NVPTXSubtarget &STI) {
1431   const DataLayout &DL = getDataLayout();
1432 
1433   // GlobalVariables are always constant pointers themselves.
1434   Type *ETy = GVar->getValueType();
1435 
1436   O << ".";
1437   emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
1438   if (isManaged(*GVar)) {
1439     if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1440       report_fatal_error(
1441           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1442     }
1443     O << " .attribute(.managed)";
1444   }
1445   if (MaybeAlign A = GVar->getAlign())
1446     O << " .align " << A->value();
1447   else
1448     O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
1449 
1450   // Special case for i128
1451   if (ETy->isIntegerTy(128)) {
1452     O << " .b8 ";
1453     getSymbol(GVar)->print(O, MAI);
1454     O << "[16]";
1455     return;
1456   }
1457 
1458   if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
1459     O << " .";
1460     O << getPTXFundamentalTypeStr(ETy);
1461     O << " ";
1462     getSymbol(GVar)->print(O, MAI);
1463     return;
1464   }
1465 
1466   int64_t ElementSize = 0;
1467 
1468   // Although PTX has direct support for struct type and array type and LLVM IR
1469   // is very similar to PTX, the LLVM CodeGen does not support for targets that
1470   // support these high level field accesses. Structs and arrays are lowered
1471   // into arrays of bytes.
1472   switch (ETy->getTypeID()) {
1473   case Type::StructTyID:
1474   case Type::ArrayTyID:
1475   case Type::FixedVectorTyID:
1476     ElementSize = DL.getTypeStoreSize(ETy);
1477     O << " .b8 ";
1478     getSymbol(GVar)->print(O, MAI);
1479     O << "[";
1480     if (ElementSize) {
1481       O << ElementSize;
1482     }
1483     O << "]";
1484     break;
1485   default:
1486     llvm_unreachable("type not supported yet");
1487   }
1488 }
1489 
1490 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
1491   const DataLayout &DL = getDataLayout();
1492   const AttributeList &PAL = F->getAttributes();
1493   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
1494   const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
1495 
1496   Function::const_arg_iterator I, E;
1497   unsigned paramIndex = 0;
1498   bool first = true;
1499   bool isKernelFunc = isKernelFunction(*F);
1500   bool isABI = (STI.getSmVersion() >= 20);
1501   bool hasImageHandles = STI.hasImageHandles();
1502 
1503   if (F->arg_empty() && !F->isVarArg()) {
1504     O << "()";
1505     return;
1506   }
1507 
1508   O << "(\n";
1509 
1510   for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
1511     Type *Ty = I->getType();
1512 
1513     if (!first)
1514       O << ",\n";
1515 
1516     first = false;
1517 
1518     // Handle image/sampler parameters
1519     if (isKernelFunction(*F)) {
1520       if (isSampler(*I) || isImage(*I)) {
1521         if (isImage(*I)) {
1522           std::string sname = std::string(I->getName());
1523           if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
1524             if (hasImageHandles)
1525               O << "\t.param .u64 .ptr .surfref ";
1526             else
1527               O << "\t.param .surfref ";
1528             O << TLI->getParamName(F, paramIndex);
1529           }
1530           else { // Default image is read_only
1531             if (hasImageHandles)
1532               O << "\t.param .u64 .ptr .texref ";
1533             else
1534               O << "\t.param .texref ";
1535             O << TLI->getParamName(F, paramIndex);
1536           }
1537         } else {
1538           if (hasImageHandles)
1539             O << "\t.param .u64 .ptr .samplerref ";
1540           else
1541             O << "\t.param .samplerref ";
1542           O << TLI->getParamName(F, paramIndex);
1543         }
1544         continue;
1545       }
1546     }
1547 
1548     auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
1549                                     paramIndex](Type *Ty) -> Align {
1550       Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
1551       MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
1552       return std::max(TypeAlign, ParamAlign.valueOrOne());
1553     };
1554 
1555     if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
1556       if (ShouldPassAsArray(Ty)) {
1557         // Just print .param .align <a> .b8 .param[size];
1558         // <a>  = optimal alignment for the element type; always multiple of
1559         //        PAL.getParamAlignment
1560         // size = typeallocsize of element type
1561         Align OptimalAlign = getOptimalAlignForParam(Ty);
1562 
1563         O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1564         O << TLI->getParamName(F, paramIndex);
1565         O << "[" << DL.getTypeAllocSize(Ty) << "]";
1566 
1567         continue;
1568       }
1569       // Just a scalar
1570       auto *PTy = dyn_cast<PointerType>(Ty);
1571       unsigned PTySizeInBits = 0;
1572       if (PTy) {
1573         PTySizeInBits =
1574             TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
1575         assert(PTySizeInBits && "Invalid pointer size");
1576       }
1577 
1578       if (isKernelFunc) {
1579         if (PTy) {
1580           // Special handling for pointer arguments to kernel
1581           O << "\t.param .u" << PTySizeInBits << " ";
1582 
1583           if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
1584               NVPTX::CUDA) {
1585             int addrSpace = PTy->getAddressSpace();
1586             switch (addrSpace) {
1587             default:
1588               O << ".ptr ";
1589               break;
1590             case ADDRESS_SPACE_CONST:
1591               O << ".ptr .const ";
1592               break;
1593             case ADDRESS_SPACE_SHARED:
1594               O << ".ptr .shared ";
1595               break;
1596             case ADDRESS_SPACE_GLOBAL:
1597               O << ".ptr .global ";
1598               break;
1599             }
1600             Align ParamAlign = I->getParamAlign().valueOrOne();
1601             O << ".align " << ParamAlign.value() << " ";
1602           }
1603           O << TLI->getParamName(F, paramIndex);
1604           continue;
1605         }
1606 
1607         // non-pointer scalar to kernel func
1608         O << "\t.param .";
1609         // Special case: predicate operands become .u8 types
1610         if (Ty->isIntegerTy(1))
1611           O << "u8";
1612         else
1613           O << getPTXFundamentalTypeStr(Ty);
1614         O << " ";
1615         O << TLI->getParamName(F, paramIndex);
1616         continue;
1617       }
1618       // Non-kernel function, just print .param .b<size> for ABI
1619       // and .reg .b<size> for non-ABI
1620       unsigned sz = 0;
1621       if (isa<IntegerType>(Ty)) {
1622         sz = cast<IntegerType>(Ty)->getBitWidth();
1623         sz = promoteScalarArgumentSize(sz);
1624       } else if (PTy) {
1625         assert(PTySizeInBits && "Invalid pointer size");
1626         sz = PTySizeInBits;
1627       } else
1628         sz = Ty->getPrimitiveSizeInBits();
1629       if (isABI)
1630         O << "\t.param .b" << sz << " ";
1631       else
1632         O << "\t.reg .b" << sz << " ";
1633       O << TLI->getParamName(F, paramIndex);
1634       continue;
1635     }
1636 
1637     // param has byVal attribute.
1638     Type *ETy = PAL.getParamByValType(paramIndex);
1639     assert(ETy && "Param should have byval type");
1640 
1641     if (isABI || isKernelFunc) {
1642       // Just print .param .align <a> .b8 .param[size];
1643       // <a>  = optimal alignment for the element type; always multiple of
1644       //        PAL.getParamAlignment
1645       // size = typeallocsize of element type
1646       Align OptimalAlign =
1647           isKernelFunc
1648               ? getOptimalAlignForParam(ETy)
1649               : TLI->getFunctionByValParamAlign(
1650                     F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);
1651 
1652       unsigned sz = DL.getTypeAllocSize(ETy);
1653       O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1654       O << TLI->getParamName(F, paramIndex);
1655       O << "[" << sz << "]";
1656       continue;
1657     } else {
1658       // Split the ETy into constituent parts and
1659       // print .param .b<size> <name> for each part.
1660       // Further, if a part is vector, print the above for
1661       // each vector element.
1662       SmallVector<EVT, 16> vtparts;
1663       ComputeValueVTs(*TLI, DL, ETy, vtparts);
1664       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
1665         unsigned elems = 1;
1666         EVT elemtype = vtparts[i];
1667         if (vtparts[i].isVector()) {
1668           elems = vtparts[i].getVectorNumElements();
1669           elemtype = vtparts[i].getVectorElementType();
1670         }
1671 
1672         for (unsigned j = 0, je = elems; j != je; ++j) {
1673           unsigned sz = elemtype.getSizeInBits();
1674           if (elemtype.isInteger())
1675             sz = promoteScalarArgumentSize(sz);
1676           O << "\t.reg .b" << sz << " ";
1677           O << TLI->getParamName(F, paramIndex);
1678           if (j < je - 1)
1679             O << ",\n";
1680           ++paramIndex;
1681         }
1682         if (i < e - 1)
1683           O << ",\n";
1684       }
1685       --paramIndex;
1686       continue;
1687     }
1688   }
1689 
1690   if (F->isVarArg()) {
1691     if (!first)
1692       O << ",\n";
1693     O << "\t.param .align " << STI.getMaxRequiredAlignment();
1694     O << " .b8 ";
1695     O << TLI->getParamName(F, /* vararg */ -1) << "[]";
1696   }
1697 
1698   O << "\n)";
1699 }
1700 
1701 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1702     const MachineFunction &MF) {
1703   SmallString<128> Str;
1704   raw_svector_ostream O(Str);
1705 
1706   // Map the global virtual register number to a register class specific
1707   // virtual register number starting from 1 with that class.
1708   const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1709   //unsigned numRegClasses = TRI->getNumRegClasses();
1710 
1711   // Emit the Fake Stack Object
1712   const MachineFrameInfo &MFI = MF.getFrameInfo();
1713   int NumBytes = (int) MFI.getStackSize();
1714   if (NumBytes) {
1715     O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
1716       << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
1717     if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1718       O << "\t.reg .b64 \t%SP;\n";
1719       O << "\t.reg .b64 \t%SPL;\n";
1720     } else {
1721       O << "\t.reg .b32 \t%SP;\n";
1722       O << "\t.reg .b32 \t%SPL;\n";
1723     }
1724   }
1725 
1726   // Go through all virtual registers to establish the mapping between the
1727   // global virtual
1728   // register number and the per class virtual register number.
1729   // We use the per class virtual register number in the ptx output.
1730   unsigned int numVRs = MRI->getNumVirtRegs();
1731   for (unsigned i = 0; i < numVRs; i++) {
1732     Register vr = Register::index2VirtReg(i);
1733     const TargetRegisterClass *RC = MRI->getRegClass(vr);
1734     DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1735     int n = regmap.size();
1736     regmap.insert(std::make_pair(vr, n + 1));
1737   }
1738 
1739   // Emit register declarations
1740   // @TODO: Extract out the real register usage
1741   // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1742   // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1743   // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1744   // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1745   // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1746   // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1747   // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1748 
1749   // Emit declaration of the virtual registers or 'physical' registers for
1750   // each register class
1751   for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
1752     const TargetRegisterClass *RC = TRI->getRegClass(i);
1753     DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1754     std::string rcname = getNVPTXRegClassName(RC);
1755     std::string rcStr = getNVPTXRegClassStr(RC);
1756     int n = regmap.size();
1757 
1758     // Only declare those registers that may be used.
1759     if (n) {
1760        O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
1761          << ">;\n";
1762     }
1763   }
1764 
1765   OutStreamer->emitRawText(O.str());
1766 }
1767 
1768 void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
1769   APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1770   bool ignored;
1771   unsigned int numHex;
1772   const char *lead;
1773 
1774   if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1775     numHex = 8;
1776     lead = "0f";
1777     APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
1778   } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1779     numHex = 16;
1780     lead = "0d";
1781     APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
1782   } else
1783     llvm_unreachable("unsupported fp type");
1784 
1785   APInt API = APF.bitcastToAPInt();
1786   O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
1787 }
1788 
1789 void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1790   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1791     O << CI->getValue();
1792     return;
1793   }
1794   if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1795     printFPConstant(CFP, O);
1796     return;
1797   }
1798   if (isa<ConstantPointerNull>(CPV)) {
1799     O << "0";
1800     return;
1801   }
1802   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1803     bool IsNonGenericPointer = false;
1804     if (GVar->getType()->getAddressSpace() != 0) {
1805       IsNonGenericPointer = true;
1806     }
1807     if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
1808       O << "generic(";
1809       getSymbol(GVar)->print(O, MAI);
1810       O << ")";
1811     } else {
1812       getSymbol(GVar)->print(O, MAI);
1813     }
1814     return;
1815   }
1816   if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1817     const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false);
1818     printMCExpr(*E, O);
1819     return;
1820   }
1821   llvm_unreachable("Not scalar type found in printScalarConstant()");
1822 }
1823 
1824 void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1825                                    AggBuffer *AggBuffer) {
1826   const DataLayout &DL = getDataLayout();
1827   int AllocSize = DL.getTypeAllocSize(CPV->getType());
1828   if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
1829     // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
1830     // only the space allocated by CPV.
1831     AggBuffer->addZeros(Bytes ? Bytes : AllocSize);
1832     return;
1833   }
1834 
1835   // Helper for filling AggBuffer with APInts.
1836   auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
1837     size_t NumBytes = (Val.getBitWidth() + 7) / 8;
1838     SmallVector<unsigned char, 16> Buf(NumBytes);
1839     for (unsigned I = 0; I < NumBytes; ++I) {
1840       Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);
1841     }
1842     AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);
1843   };
1844 
1845   switch (CPV->getType()->getTypeID()) {
1846   case Type::IntegerTyID:
1847     if (const auto CI = dyn_cast<ConstantInt>(CPV)) {
1848       AddIntToBuffer(CI->getValue());
1849       break;
1850     }
1851     if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1852       if (const auto *CI =
1853               dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {
1854         AddIntToBuffer(CI->getValue());
1855         break;
1856       }
1857       if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1858         Value *V = Cexpr->getOperand(0)->stripPointerCasts();
1859         AggBuffer->addSymbol(V, Cexpr->getOperand(0));
1860         AggBuffer->addZeros(AllocSize);
1861         break;
1862       }
1863     }
1864     llvm_unreachable("unsupported integer const type");
1865     break;
1866 
1867   case Type::HalfTyID:
1868   case Type::BFloatTyID:
1869   case Type::FloatTyID:
1870   case Type::DoubleTyID:
1871     AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
1872     break;
1873 
1874   case Type::PointerTyID: {
1875     if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1876       AggBuffer->addSymbol(GVar, GVar);
1877     } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1878       const Value *v = Cexpr->stripPointerCasts();
1879       AggBuffer->addSymbol(v, Cexpr);
1880     }
1881     AggBuffer->addZeros(AllocSize);
1882     break;
1883   }
1884 
1885   case Type::ArrayTyID:
1886   case Type::FixedVectorTyID:
1887   case Type::StructTyID: {
1888     if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
1889       bufferAggregateConstant(CPV, AggBuffer);
1890       if (Bytes > AllocSize)
1891         AggBuffer->addZeros(Bytes - AllocSize);
1892     } else if (isa<ConstantAggregateZero>(CPV))
1893       AggBuffer->addZeros(Bytes);
1894     else
1895       llvm_unreachable("Unexpected Constant type");
1896     break;
1897   }
1898 
1899   default:
1900     llvm_unreachable("unsupported type");
1901   }
1902 }
1903 
1904 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
1905                                               AggBuffer *aggBuffer) {
1906   const DataLayout &DL = getDataLayout();
1907   int Bytes;
1908 
1909   // Integers of arbitrary width
1910   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1911     APInt Val = CI->getValue();
1912     for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
1913       uint8_t Byte = Val.getLoBits(8).getZExtValue();
1914       aggBuffer->addBytes(&Byte, 1, 1);
1915       Val.lshrInPlace(8);
1916     }
1917     return;
1918   }
1919 
1920   // Old constants
1921   if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
1922     if (CPV->getNumOperands())
1923       for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
1924         bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
1925     return;
1926   }
1927 
1928   if (const ConstantDataSequential *CDS =
1929           dyn_cast<ConstantDataSequential>(CPV)) {
1930     if (CDS->getNumElements())
1931       for (unsigned i = 0; i < CDS->getNumElements(); ++i)
1932         bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
1933                      aggBuffer);
1934     return;
1935   }
1936 
1937   if (isa<ConstantStruct>(CPV)) {
1938     if (CPV->getNumOperands()) {
1939       StructType *ST = cast<StructType>(CPV->getType());
1940       for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
1941         if (i == (e - 1))
1942           Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
1943                   DL.getTypeAllocSize(ST) -
1944                   DL.getStructLayout(ST)->getElementOffset(i);
1945         else
1946           Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
1947                   DL.getStructLayout(ST)->getElementOffset(i);
1948         bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
1949       }
1950     }
1951     return;
1952   }
1953   llvm_unreachable("unsupported constant type in printAggregateConstant()");
1954 }
1955 
1956 /// lowerConstantForGV - Return an MCExpr for the given Constant.  This is mostly
1957 /// a copy from AsmPrinter::lowerConstant, except customized to only handle
1958 /// expressions that are representable in PTX and create
1959 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
1960 const MCExpr *
1961 NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
1962   MCContext &Ctx = OutContext;
1963 
1964   if (CV->isNullValue() || isa<UndefValue>(CV))
1965     return MCConstantExpr::create(0, Ctx);
1966 
1967   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
1968     return MCConstantExpr::create(CI->getZExtValue(), Ctx);
1969 
1970   if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
1971     const MCSymbolRefExpr *Expr =
1972       MCSymbolRefExpr::create(getSymbol(GV), Ctx);
1973     if (ProcessingGeneric) {
1974       return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
1975     } else {
1976       return Expr;
1977     }
1978   }
1979 
1980   const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
1981   if (!CE) {
1982     llvm_unreachable("Unknown constant value to lower!");
1983   }
1984 
1985   switch (CE->getOpcode()) {
1986   default:
1987     break; // Error
1988 
1989   case Instruction::AddrSpaceCast: {
1990     // Strip the addrspacecast and pass along the operand
1991     PointerType *DstTy = cast<PointerType>(CE->getType());
1992     if (DstTy->getAddressSpace() == 0)
1993       return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
1994 
1995     break; // Error
1996   }
1997 
1998   case Instruction::GetElementPtr: {
1999     const DataLayout &DL = getDataLayout();
2000 
2001     // Generate a symbolic expression for the byte address
2002     APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
2003     cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);
2004 
2005     const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
2006                                             ProcessingGeneric);
2007     if (!OffsetAI)
2008       return Base;
2009 
2010     int64_t Offset = OffsetAI.getSExtValue();
2011     return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
2012                                    Ctx);
2013   }
2014 
2015   case Instruction::Trunc:
2016     // We emit the value and depend on the assembler to truncate the generated
2017     // expression properly.  This is important for differences between
2018     // blockaddress labels.  Since the two labels are in the same function, it
2019     // is reasonable to treat their delta as a 32-bit value.
2020     [[fallthrough]];
2021   case Instruction::BitCast:
2022     return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2023 
2024   case Instruction::IntToPtr: {
2025     const DataLayout &DL = getDataLayout();
2026 
2027     // Handle casts to pointers by changing them into casts to the appropriate
2028     // integer type.  This promotes constant folding and simplifies this code.
2029     Constant *Op = CE->getOperand(0);
2030     Op = ConstantFoldIntegerCast(Op, DL.getIntPtrType(CV->getType()),
2031                                  /*IsSigned*/ false, DL);
2032     if (Op)
2033       return lowerConstantForGV(Op, ProcessingGeneric);
2034 
2035     break; // Error
2036   }
2037 
2038   case Instruction::PtrToInt: {
2039     const DataLayout &DL = getDataLayout();
2040 
2041     // Support only foldable casts to/from pointers that can be eliminated by
2042     // changing the pointer to the appropriately sized integer type.
2043     Constant *Op = CE->getOperand(0);
2044     Type *Ty = CE->getType();
2045 
2046     const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);
2047 
2048     // We can emit the pointer value into this slot if the slot is an
2049     // integer slot equal to the size of the pointer.
2050     if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
2051       return OpExpr;
2052 
2053     // Otherwise the pointer is smaller than the resultant integer, mask off
2054     // the high bits so we are sure to get a proper truncation if the input is
2055     // a constant expr.
2056     unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
2057     const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
2058     return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
2059   }
2060 
2061   // The MC library also has a right-shift operator, but it isn't consistently
2062   // signed or unsigned between different targets.
2063   case Instruction::Add: {
2064     const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2065     const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
2066     switch (CE->getOpcode()) {
2067     default: llvm_unreachable("Unknown binary operator constant cast expr");
2068     case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
2069     }
2070   }
2071   }
2072 
2073   // If the code isn't optimized, there may be outstanding folding
2074   // opportunities. Attempt to fold the expression using DataLayout as a
2075   // last resort before giving up.
2076   Constant *C = ConstantFoldConstant(CE, getDataLayout());
2077   if (C != CE)
2078     return lowerConstantForGV(C, ProcessingGeneric);
2079 
2080   // Otherwise report the problem to the user.
2081   std::string S;
2082   raw_string_ostream OS(S);
2083   OS << "Unsupported expression in static initializer: ";
2084   CE->printAsOperand(OS, /*PrintType=*/false,
2085                  !MF ? nullptr : MF->getFunction().getParent());
2086   report_fatal_error(Twine(OS.str()));
2087 }
2088 
2089 // Copy of MCExpr::print customized for NVPTX
2090 void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
2091   switch (Expr.getKind()) {
2092   case MCExpr::Target:
2093     return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
2094   case MCExpr::Constant:
2095     OS << cast<MCConstantExpr>(Expr).getValue();
2096     return;
2097 
2098   case MCExpr::SymbolRef: {
2099     const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
2100     const MCSymbol &Sym = SRE.getSymbol();
2101     Sym.print(OS, MAI);
2102     return;
2103   }
2104 
2105   case MCExpr::Unary: {
2106     const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
2107     switch (UE.getOpcode()) {
2108     case MCUnaryExpr::LNot:  OS << '!'; break;
2109     case MCUnaryExpr::Minus: OS << '-'; break;
2110     case MCUnaryExpr::Not:   OS << '~'; break;
2111     case MCUnaryExpr::Plus:  OS << '+'; break;
2112     }
2113     printMCExpr(*UE.getSubExpr(), OS);
2114     return;
2115   }
2116 
2117   case MCExpr::Binary: {
2118     const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);
2119 
2120     // Only print parens around the LHS if it is non-trivial.
2121     if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
2122         isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
2123       printMCExpr(*BE.getLHS(), OS);
2124     } else {
2125       OS << '(';
2126       printMCExpr(*BE.getLHS(), OS);
2127       OS<< ')';
2128     }
2129 
2130     switch (BE.getOpcode()) {
2131     case MCBinaryExpr::Add:
2132       // Print "X-42" instead of "X+-42".
2133       if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
2134         if (RHSC->getValue() < 0) {
2135           OS << RHSC->getValue();
2136           return;
2137         }
2138       }
2139 
2140       OS <<  '+';
2141       break;
2142     default: llvm_unreachable("Unhandled binary operator");
2143     }
2144 
2145     // Only print parens around the LHS if it is non-trivial.
2146     if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
2147       printMCExpr(*BE.getRHS(), OS);
2148     } else {
2149       OS << '(';
2150       printMCExpr(*BE.getRHS(), OS);
2151       OS << ')';
2152     }
2153     return;
2154   }
2155   }
2156 
2157   llvm_unreachable("Invalid expression kind!");
2158 }
2159 
2160 /// PrintAsmOperand - Print out an operand for an inline asm expression.
2161 ///
2162 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2163                                       const char *ExtraCode, raw_ostream &O) {
2164   if (ExtraCode && ExtraCode[0]) {
2165     if (ExtraCode[1] != 0)
2166       return true; // Unknown modifier.
2167 
2168     switch (ExtraCode[0]) {
2169     default:
2170       // See if this is a generic print operand
2171       return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
2172     case 'r':
2173       break;
2174     }
2175   }
2176 
2177   printOperand(MI, OpNo, O);
2178 
2179   return false;
2180 }
2181 
2182 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2183                                             unsigned OpNo,
2184                                             const char *ExtraCode,
2185                                             raw_ostream &O) {
2186   if (ExtraCode && ExtraCode[0])
2187     return true; // Unknown modifier
2188 
2189   O << '[';
2190   printMemOperand(MI, OpNo, O);
2191   O << ']';
2192 
2193   return false;
2194 }
2195 
2196 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, unsigned OpNum,
2197                                    raw_ostream &O) {
2198   const MachineOperand &MO = MI->getOperand(OpNum);
2199   switch (MO.getType()) {
2200   case MachineOperand::MO_Register:
2201     if (MO.getReg().isPhysical()) {
2202       if (MO.getReg() == NVPTX::VRDepot)
2203         O << DEPOTNAME << getFunctionNumber();
2204       else
2205         O << NVPTXInstPrinter::getRegisterName(MO.getReg());
2206     } else {
2207       emitVirtualRegister(MO.getReg(), O);
2208     }
2209     break;
2210 
2211   case MachineOperand::MO_Immediate:
2212     O << MO.getImm();
2213     break;
2214 
2215   case MachineOperand::MO_FPImmediate:
2216     printFPConstant(MO.getFPImm(), O);
2217     break;
2218 
2219   case MachineOperand::MO_GlobalAddress:
2220     PrintSymbolOperand(MO, O);
2221     break;
2222 
2223   case MachineOperand::MO_MachineBasicBlock:
2224     MO.getMBB()->getSymbol()->print(O, MAI);
2225     break;
2226 
2227   default:
2228     llvm_unreachable("Operand type not supported.");
2229   }
2230 }
2231 
2232 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, unsigned OpNum,
2233                                       raw_ostream &O, const char *Modifier) {
2234   printOperand(MI, OpNum, O);
2235 
2236   if (Modifier && strcmp(Modifier, "add") == 0) {
2237     O << ", ";
2238     printOperand(MI, OpNum + 1, O);
2239   } else {
2240     if (MI->getOperand(OpNum + 1).isImm() &&
2241         MI->getOperand(OpNum + 1).getImm() == 0)
2242       return; // don't print ',0' or '+0'
2243     O << "+";
2244     printOperand(MI, OpNum + 1, O);
2245   }
2246 }
2247 
2248 // Force static initialization.
2249 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
2250   RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2251   RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
2252 }
2253