1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "NVPTXAsmPrinter.h"
15 #include "MCTargetDesc/NVPTXBaseInfo.h"
16 #include "MCTargetDesc/NVPTXInstPrinter.h"
17 #include "MCTargetDesc/NVPTXMCAsmInfo.h"
18 #include "MCTargetDesc/NVPTXTargetStreamer.h"
19 #include "NVPTX.h"
20 #include "NVPTXMCExpr.h"
21 #include "NVPTXMachineFunctionInfo.h"
22 #include "NVPTXRegisterInfo.h"
23 #include "NVPTXSubtarget.h"
24 #include "NVPTXTargetMachine.h"
25 #include "NVPTXUtilities.h"
26 #include "TargetInfo/NVPTXTargetInfo.h"
27 #include "cl_common_defines.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/Triple.h"
37 #include "llvm/ADT/Twine.h"
38 #include "llvm/Analysis/ConstantFolding.h"
39 #include "llvm/CodeGen/Analysis.h"
40 #include "llvm/CodeGen/MachineBasicBlock.h"
41 #include "llvm/CodeGen/MachineFrameInfo.h"
42 #include "llvm/CodeGen/MachineFunction.h"
43 #include "llvm/CodeGen/MachineInstr.h"
44 #include "llvm/CodeGen/MachineLoopInfo.h"
45 #include "llvm/CodeGen/MachineModuleInfo.h"
46 #include "llvm/CodeGen/MachineOperand.h"
47 #include "llvm/CodeGen/MachineRegisterInfo.h"
48 #include "llvm/CodeGen/TargetLowering.h"
49 #include "llvm/CodeGen/TargetRegisterInfo.h"
50 #include "llvm/CodeGen/ValueTypes.h"
51 #include "llvm/IR/Attributes.h"
52 #include "llvm/IR/BasicBlock.h"
53 #include "llvm/IR/Constant.h"
54 #include "llvm/IR/Constants.h"
55 #include "llvm/IR/DataLayout.h"
56 #include "llvm/IR/DebugInfo.h"
57 #include "llvm/IR/DebugInfoMetadata.h"
58 #include "llvm/IR/DebugLoc.h"
59 #include "llvm/IR/DerivedTypes.h"
60 #include "llvm/IR/Function.h"
61 #include "llvm/IR/GlobalValue.h"
62 #include "llvm/IR/GlobalVariable.h"
63 #include "llvm/IR/Instruction.h"
64 #include "llvm/IR/LLVMContext.h"
65 #include "llvm/IR/Module.h"
66 #include "llvm/IR/Operator.h"
67 #include "llvm/IR/Type.h"
68 #include "llvm/IR/User.h"
69 #include "llvm/MC/MCExpr.h"
70 #include "llvm/MC/MCInst.h"
71 #include "llvm/MC/MCInstrDesc.h"
72 #include "llvm/MC/MCStreamer.h"
73 #include "llvm/MC/MCSymbol.h"
74 #include "llvm/MC/TargetRegistry.h"
75 #include "llvm/Support/Casting.h"
76 #include "llvm/Support/CommandLine.h"
77 #include "llvm/Support/ErrorHandling.h"
78 #include "llvm/Support/MachineValueType.h"
79 #include "llvm/Support/Path.h"
80 #include "llvm/Support/raw_ostream.h"
81 #include "llvm/Target/TargetLoweringObjectFile.h"
82 #include "llvm/Target/TargetMachine.h"
83 #include "llvm/Transforms/Utils/UnrollLoop.h"
84 #include <cassert>
85 #include <cstdint>
86 #include <cstring>
87 #include <new>
88 #include <string>
89 #include <utility>
90 #include <vector>
91 
92 using namespace llvm;
93 
94 #define DEPOTNAME "__local_depot"
95 
96 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
97 /// depends.
98 static void
99 DiscoverDependentGlobals(const Value *V,
100                          DenseSet<const GlobalVariable *> &Globals) {
101   if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
102     Globals.insert(GV);
103   else {
104     if (const User *U = dyn_cast<User>(V)) {
105       for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
106         DiscoverDependentGlobals(U->getOperand(i), Globals);
107       }
108     }
109   }
110 }
111 
112 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
113 /// instances to be emitted, but only after any dependents have been added
114 /// first.s
115 static void
116 VisitGlobalVariableForEmission(const GlobalVariable *GV,
117                                SmallVectorImpl<const GlobalVariable *> &Order,
118                                DenseSet<const GlobalVariable *> &Visited,
119                                DenseSet<const GlobalVariable *> &Visiting) {
120   // Have we already visited this one?
121   if (Visited.count(GV))
122     return;
123 
124   // Do we have a circular dependency?
125   if (!Visiting.insert(GV).second)
126     report_fatal_error("Circular dependency found in global variable set");
127 
128   // Make sure we visit all dependents first
129   DenseSet<const GlobalVariable *> Others;
130   for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
131     DiscoverDependentGlobals(GV->getOperand(i), Others);
132 
133   for (const GlobalVariable *GV : Others)
134     VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
135 
136   // Now we can visit ourself
137   Order.push_back(GV);
138   Visited.insert(GV);
139   Visiting.erase(GV);
140 }
141 
142 void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
143   MCInst Inst;
144   lowerToMCInst(MI, Inst);
145   EmitToStreamer(*OutStreamer, Inst);
146 }
147 
148 // Handle symbol backtracking for targets that do not support image handles
149 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
150                                            unsigned OpNo, MCOperand &MCOp) {
151   const MachineOperand &MO = MI->getOperand(OpNo);
152   const MCInstrDesc &MCID = MI->getDesc();
153 
154   if (MCID.TSFlags & NVPTXII::IsTexFlag) {
155     // This is a texture fetch, so operand 4 is a texref and operand 5 is
156     // a samplerref
157     if (OpNo == 4 && MO.isImm()) {
158       lowerImageHandleSymbol(MO.getImm(), MCOp);
159       return true;
160     }
161     if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
162       lowerImageHandleSymbol(MO.getImm(), MCOp);
163       return true;
164     }
165 
166     return false;
167   } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
168     unsigned VecSize =
169       1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
170 
171     // For a surface load of vector size N, the Nth operand will be the surfref
172     if (OpNo == VecSize && MO.isImm()) {
173       lowerImageHandleSymbol(MO.getImm(), MCOp);
174       return true;
175     }
176 
177     return false;
178   } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
179     // This is a surface store, so operand 0 is a surfref
180     if (OpNo == 0 && MO.isImm()) {
181       lowerImageHandleSymbol(MO.getImm(), MCOp);
182       return true;
183     }
184 
185     return false;
186   } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
187     // This is a query, so operand 1 is a surfref/texref
188     if (OpNo == 1 && MO.isImm()) {
189       lowerImageHandleSymbol(MO.getImm(), MCOp);
190       return true;
191     }
192 
193     return false;
194   }
195 
196   return false;
197 }
198 
199 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
200   // Ewwww
201   LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget());
202   NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM);
203   const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
204   const char *Sym = MFI->getImageHandleSymbol(Index);
205   std::string *SymNamePtr =
206     nvTM.getManagedStrPool()->getManagedString(Sym);
207   MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(StringRef(*SymNamePtr)));
208 }
209 
210 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
211   OutMI.setOpcode(MI->getOpcode());
212   // Special: Do not mangle symbol operand of CALL_PROTOTYPE
213   if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
214     const MachineOperand &MO = MI->getOperand(0);
215     OutMI.addOperand(GetSymbolRef(
216       OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
217     return;
218   }
219 
220   const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
221   for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
222     const MachineOperand &MO = MI->getOperand(i);
223 
224     MCOperand MCOp;
225     if (!STI.hasImageHandles()) {
226       if (lowerImageHandleOperand(MI, i, MCOp)) {
227         OutMI.addOperand(MCOp);
228         continue;
229       }
230     }
231 
232     if (lowerOperand(MO, MCOp))
233       OutMI.addOperand(MCOp);
234   }
235 }
236 
237 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
238                                    MCOperand &MCOp) {
239   switch (MO.getType()) {
240   default: llvm_unreachable("unknown operand type");
241   case MachineOperand::MO_Register:
242     MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
243     break;
244   case MachineOperand::MO_Immediate:
245     MCOp = MCOperand::createImm(MO.getImm());
246     break;
247   case MachineOperand::MO_MachineBasicBlock:
248     MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
249         MO.getMBB()->getSymbol(), OutContext));
250     break;
251   case MachineOperand::MO_ExternalSymbol:
252     MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
253     break;
254   case MachineOperand::MO_GlobalAddress:
255     MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
256     break;
257   case MachineOperand::MO_FPImmediate: {
258     const ConstantFP *Cnt = MO.getFPImm();
259     const APFloat &Val = Cnt->getValueAPF();
260 
261     switch (Cnt->getType()->getTypeID()) {
262     default: report_fatal_error("Unsupported FP type"); break;
263     case Type::HalfTyID:
264       MCOp = MCOperand::createExpr(
265         NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
266       break;
267     case Type::FloatTyID:
268       MCOp = MCOperand::createExpr(
269         NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
270       break;
271     case Type::DoubleTyID:
272       MCOp = MCOperand::createExpr(
273         NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
274       break;
275     }
276     break;
277   }
278   }
279   return true;
280 }
281 
282 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
283   if (Register::isVirtualRegister(Reg)) {
284     const TargetRegisterClass *RC = MRI->getRegClass(Reg);
285 
286     DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
287     unsigned RegNum = RegMap[Reg];
288 
289     // Encode the register class in the upper 4 bits
290     // Must be kept in sync with NVPTXInstPrinter::printRegName
291     unsigned Ret = 0;
292     if (RC == &NVPTX::Int1RegsRegClass) {
293       Ret = (1 << 28);
294     } else if (RC == &NVPTX::Int16RegsRegClass) {
295       Ret = (2 << 28);
296     } else if (RC == &NVPTX::Int32RegsRegClass) {
297       Ret = (3 << 28);
298     } else if (RC == &NVPTX::Int64RegsRegClass) {
299       Ret = (4 << 28);
300     } else if (RC == &NVPTX::Float32RegsRegClass) {
301       Ret = (5 << 28);
302     } else if (RC == &NVPTX::Float64RegsRegClass) {
303       Ret = (6 << 28);
304     } else if (RC == &NVPTX::Float16RegsRegClass) {
305       Ret = (7 << 28);
306     } else if (RC == &NVPTX::Float16x2RegsRegClass) {
307       Ret = (8 << 28);
308     } else {
309       report_fatal_error("Bad register class");
310     }
311 
312     // Insert the vreg number
313     Ret |= (RegNum & 0x0FFFFFFF);
314     return Ret;
315   } else {
316     // Some special-use registers are actually physical registers.
317     // Encode this as the register class ID of 0 and the real register ID.
318     return Reg & 0x0FFFFFFF;
319   }
320 }
321 
322 MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
323   const MCExpr *Expr;
324   Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
325                                  OutContext);
326   return MCOperand::createExpr(Expr);
327 }
328 
/// Print the PTX return-value declaration of \p F to \p O, wrapped as
/// " (<decl>) ". Nothing is printed when \p F returns void.
///
/// When the subtarget's SM version is >= 20 (the "ABI" path), the return
/// value is a single `.param` named func_retval0:
///   - integer (except i128) and FP scalars: `.param .b<N>`, N widened to >= 32
///   - pointers: `.param .b<pointer width>`
///   - aggregates, vectors and i128: `.param .align A .b8 func_retval0[size]`
/// Otherwise, each scalar element of the lowered value types is declared as
/// a separate `.reg .b<N> func_retval<idx>`, with integers widened to >= 32.
void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
  const TargetLowering *TLI = STI.getTargetLowering();

  Type *Ty = F->getReturnType();

  // SM 2.0+ returns values through .param space; older targets use .reg.
  bool isABI = (STI.getSmVersion() >= 20);

  if (Ty->getTypeID() == Type::VoidTyID)
    return;

  O << " (";

  if (isABI) {
    if (Ty->isFloatingPointTy() || (Ty->isIntegerTy() && !Ty->isIntegerTy(128))) {
      unsigned size = 0;
      if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
        size = ITy->getBitWidth();
      } else {
        assert(Ty->isFloatingPointTy() && "Floating point type expected here");
        size = Ty->getPrimitiveSizeInBits();
      }
      // PTX ABI requires all scalar return values to be at least 32
      // bits in size.  fp16 normally uses .b16 as its storage type in
      // PTX, so its size must be adjusted here, too.
      if (size < 32)
        size = 32;

      O << ".param .b" << size << " func_retval0";
    } else if (isa<PointerType>(Ty)) {
      // Pointer width comes from the target's pointer type for this layout.
      O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
        << " func_retval0";
    } else if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
      // Returned as an aligned byte array of the type's allocation size.
      unsigned totalsz = DL.getTypeAllocSize(Ty);
      unsigned retAlignment = 0;
      // Prefer an explicit alignment annotation for index 0 (presumably the
      // return value -- confirm against getAlign); otherwise use the ABI
      // alignment of the type.
      if (!getAlign(*F, 0, retAlignment))
        retAlignment = DL.getABITypeAlignment(Ty);
      O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
        << "]";
    } else
      llvm_unreachable("Unknown return type");
  } else {
    // Pre-ABI path: split the return type into legal value types and declare
    // one register per scalar element, numbered consecutively via idx.
    SmallVector<EVT, 16> vtparts;
    ComputeValueVTs(*TLI, DL, Ty, vtparts);
    unsigned idx = 0;
    for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
      unsigned elems = 1;
      EVT elemtype = vtparts[i];
      // Vector parts contribute one register per element.
      if (vtparts[i].isVector()) {
        elems = vtparts[i].getVectorNumElements();
        elemtype = vtparts[i].getVectorElementType();
      }

      for (unsigned j = 0, je = elems; j != je; ++j) {
        unsigned sz = elemtype.getSizeInBits();
        // Integer registers are widened to at least 32 bits.
        if (elemtype.isInteger() && (sz < 32))
          sz = 32;
        O << ".reg .b" << sz << " func_retval" << idx;
        if (j < je - 1)
          O << ", ";
        ++idx;
      }
      if (i < e - 1)
        O << ", ";
    }
  }
  O << ") ";
}
398 
399 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
400                                         raw_ostream &O) {
401   const Function &F = MF.getFunction();
402   printReturnValStr(&F, O);
403 }
404 
405 // Return true if MBB is the header of a loop marked with
406 // llvm.loop.unroll.disable.
407 // TODO: consider "#pragma unroll 1" which is equivalent to "#pragma nounroll".
408 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
409     const MachineBasicBlock &MBB) const {
410   MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>();
411   // We insert .pragma "nounroll" only to the loop header.
412   if (!LI.isLoopHeader(&MBB))
413     return false;
414 
415   // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
416   // we iterate through each back edge of the loop with header MBB, and check
417   // whether its metadata contains llvm.loop.unroll.disable.
418   for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
419     if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
420       // Edges from other loops to MBB are not back edges.
421       continue;
422     }
423     if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
424       if (MDNode *LoopID =
425               PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
426         if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
427           return true;
428       }
429     }
430   }
431   return false;
432 }
433 
434 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
435   AsmPrinter::emitBasicBlockStart(MBB);
436   if (isLoopHeaderOfNoUnroll(MBB))
437     OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
438 }
439 
440 void NVPTXAsmPrinter::emitFunctionEntryLabel() {
441   SmallString<128> Str;
442   raw_svector_ostream O(Str);
443 
444   if (!GlobalsEmitted) {
445     emitGlobals(*MF->getFunction().getParent());
446     GlobalsEmitted = true;
447   }
448 
449   // Set up
450   MRI = &MF->getRegInfo();
451   F = &MF->getFunction();
452   emitLinkageDirective(F, O);
453   if (isKernelFunction(*F))
454     O << ".entry ";
455   else {
456     O << ".func ";
457     printReturnValStr(*MF, O);
458   }
459 
460   CurrentFnSym->print(O, MAI);
461 
462   emitFunctionParamList(*MF, O);
463 
464   if (isKernelFunction(*F))
465     emitKernelFunctionDirectives(*F, O);
466 
467   OutStreamer->emitRawText(O.str());
468 
469   VRegMapping.clear();
470   // Emit open brace for function body.
471   OutStreamer->emitRawText(StringRef("{\n"));
472   setAndEmitFunctionVirtualRegisters(*MF);
473   // Emit initial .loc debug directive for correct relocation symbol data.
474   if (MMI && MMI->hasDebugInfo())
475     emitInitialRawDwarfLocDirective(*MF);
476 }
477 
478 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
479   bool Result = AsmPrinter::runOnMachineFunction(F);
480   // Emit closing brace for the body of function F.
481   // The closing brace must be emitted here because we need to emit additional
482   // debug labels/data after the last basic block.
483   // We need to emit the closing brace here because we don't have function that
484   // finished emission of the function body.
485   OutStreamer->emitRawText(StringRef("}\n"));
486   return Result;
487 }
488 
489 void NVPTXAsmPrinter::emitFunctionBodyStart() {
490   SmallString<128> Str;
491   raw_svector_ostream O(Str);
492   emitDemotedVars(&MF->getFunction(), O);
493   OutStreamer->emitRawText(O.str());
494 }
495 
/// Discard the per-function virtual register numbering (VRegMapping) once
/// the function body has been emitted.
void NVPTXAsmPrinter::emitFunctionBodyEnd() {
  VRegMapping.clear();
}
499 
500 const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
501     SmallString<128> Str;
502     raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
503     return OutContext.getOrCreateSymbol(Str);
504 }
505 
506 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
507   Register RegNo = MI->getOperand(0).getReg();
508   if (Register::isVirtualRegister(RegNo)) {
509     OutStreamer->AddComment(Twine("implicit-def: ") +
510                             getVirtualRegisterName(RegNo));
511   } else {
512     const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
513     OutStreamer->AddComment(Twine("implicit-def: ") +
514                             STI.getRegisterInfo()->getName(RegNo));
515   }
516   OutStreamer->AddBlankLine();
517 }
518 
519 void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
520                                                    raw_ostream &O) const {
521   // If the NVVM IR has some of reqntid* specified, then output
522   // the reqntid directive, and set the unspecified ones to 1.
523   // If none of reqntid* is specified, don't output reqntid directive.
524   unsigned reqntidx, reqntidy, reqntidz;
525   bool specified = false;
526   if (!getReqNTIDx(F, reqntidx))
527     reqntidx = 1;
528   else
529     specified = true;
530   if (!getReqNTIDy(F, reqntidy))
531     reqntidy = 1;
532   else
533     specified = true;
534   if (!getReqNTIDz(F, reqntidz))
535     reqntidz = 1;
536   else
537     specified = true;
538 
539   if (specified)
540     O << ".reqntid " << reqntidx << ", " << reqntidy << ", " << reqntidz
541       << "\n";
542 
543   // If the NVVM IR has some of maxntid* specified, then output
544   // the maxntid directive, and set the unspecified ones to 1.
545   // If none of maxntid* is specified, don't output maxntid directive.
546   unsigned maxntidx, maxntidy, maxntidz;
547   specified = false;
548   if (!getMaxNTIDx(F, maxntidx))
549     maxntidx = 1;
550   else
551     specified = true;
552   if (!getMaxNTIDy(F, maxntidy))
553     maxntidy = 1;
554   else
555     specified = true;
556   if (!getMaxNTIDz(F, maxntidz))
557     maxntidz = 1;
558   else
559     specified = true;
560 
561   if (specified)
562     O << ".maxntid " << maxntidx << ", " << maxntidy << ", " << maxntidz
563       << "\n";
564 
565   unsigned mincta;
566   if (getMinCTASm(F, mincta))
567     O << ".minnctapersm " << mincta << "\n";
568 
569   unsigned maxnreg;
570   if (getMaxNReg(F, maxnreg))
571     O << ".maxnreg " << maxnreg << "\n";
572 }
573 
574 std::string
575 NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
576   const TargetRegisterClass *RC = MRI->getRegClass(Reg);
577 
578   std::string Name;
579   raw_string_ostream NameStr(Name);
580 
581   VRegRCMap::const_iterator I = VRegMapping.find(RC);
582   assert(I != VRegMapping.end() && "Bad register class");
583   const DenseMap<unsigned, unsigned> &RegMap = I->second;
584 
585   VRegMap::const_iterator VI = RegMap.find(Reg);
586   assert(VI != RegMap.end() && "Bad virtual register");
587   unsigned MappedVR = VI->second;
588 
589   NameStr << getNVPTXRegClassStr(RC) << MappedVR;
590 
591   NameStr.flush();
592   return Name;
593 }
594 
/// Print the name of virtual register \p vr (as returned by
/// getVirtualRegisterName) to \p O.
void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
                                          raw_ostream &O) {
  O << getVirtualRegisterName(vr);
}
599 
600 void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
601   emitLinkageDirective(F, O);
602   if (isKernelFunction(*F))
603     O << ".entry ";
604   else
605     O << ".func ";
606   printReturnValStr(F, O);
607   getSymbol(F)->print(O, MAI);
608   O << "\n";
609   emitFunctionParamList(F, O);
610   O << ";\n";
611 }
612 
613 static bool usedInGlobalVarDef(const Constant *C) {
614   if (!C)
615     return false;
616 
617   if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
618     return GV->getName() != "llvm.used";
619   }
620 
621   for (const User *U : C->users())
622     if (const Constant *C = dyn_cast<Constant>(U))
623       if (usedInGlobalVarDef(C))
624         return true;
625 
626   return false;
627 }
628 
629 static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
630   if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
631     if (othergv->getName() == "llvm.used")
632       return true;
633   }
634 
635   if (const Instruction *instr = dyn_cast<Instruction>(U)) {
636     if (instr->getParent() && instr->getParent()->getParent()) {
637       const Function *curFunc = instr->getParent()->getParent();
638       if (oneFunc && (curFunc != oneFunc))
639         return false;
640       oneFunc = curFunc;
641       return true;
642     } else
643       return false;
644   }
645 
646   for (const User *UU : U->users())
647     if (!usedInOneFunc(UU, oneFunc))
648       return false;
649 
650   return true;
651 }
652 
653 /* Find out if a global variable can be demoted to local scope.
654  * Currently, this is valid for CUDA shared variables, which have local
655  * scope and global lifetime. So the conditions to check are :
656  * 1. Is the global variable in shared address space?
657  * 2. Does it have internal linkage?
658  * 3. Is the global variable referenced only in one function?
659  */
660 static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
661   if (!gv->hasInternalLinkage())
662     return false;
663   PointerType *Pty = gv->getType();
664   if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
665     return false;
666 
667   const Function *oneFunc = nullptr;
668 
669   bool flag = usedInOneFunc(gv, oneFunc);
670   if (!flag)
671     return false;
672   if (!oneFunc)
673     return false;
674   f = oneFunc;
675   return true;
676 }
677 
678 static bool useFuncSeen(const Constant *C,
679                         DenseMap<const Function *, bool> &seenMap) {
680   for (const User *U : C->users()) {
681     if (const Constant *cu = dyn_cast<Constant>(U)) {
682       if (useFuncSeen(cu, seenMap))
683         return true;
684     } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
685       const BasicBlock *bb = I->getParent();
686       if (!bb)
687         continue;
688       const Function *caller = bb->getParent();
689       if (!caller)
690         continue;
691       if (seenMap.find(caller) != seenMap.end())
692         return true;
693     }
694   }
695   return false;
696 }
697 
698 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
699   DenseMap<const Function *, bool> seenMap;
700   for (const Function &F : M) {
701     if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
702       emitDeclaration(&F, O);
703       continue;
704     }
705 
706     if (F.isDeclaration()) {
707       if (F.use_empty())
708         continue;
709       if (F.getIntrinsicID())
710         continue;
711       emitDeclaration(&F, O);
712       continue;
713     }
714     for (const User *U : F.users()) {
715       if (const Constant *C = dyn_cast<Constant>(U)) {
716         if (usedInGlobalVarDef(C)) {
717           // The use is in the initialization of a global variable
718           // that is a function pointer, so print a declaration
719           // for the original function
720           emitDeclaration(&F, O);
721           break;
722         }
723         // Emit a declaration of this function if the function that
724         // uses this constant expr has already been seen.
725         if (useFuncSeen(C, seenMap)) {
726           emitDeclaration(&F, O);
727           break;
728         }
729       }
730 
731       if (!isa<Instruction>(U))
732         continue;
733       const Instruction *instr = cast<Instruction>(U);
734       const BasicBlock *bb = instr->getParent();
735       if (!bb)
736         continue;
737       const Function *caller = bb->getParent();
738       if (!caller)
739         continue;
740 
741       // If a caller has already been seen, then the caller is
742       // appearing in the module before the callee. so print out
743       // a declaration for the callee.
744       if (seenMap.find(caller) != seenMap.end()) {
745         emitDeclaration(&F, O);
746         break;
747       }
748     }
749     seenMap[&F] = true;
750   }
751 }
752 
753 static bool isEmptyXXStructor(GlobalVariable *GV) {
754   if (!GV) return true;
755   const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
756   if (!InitList) return true;  // Not an array; we don't know how to parse.
757   return InitList->getNumOperands() == 0;
758 }
759 
760 void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
761   // Construct a default subtarget off of the TargetMachine defaults. The
762   // rest of NVPTX isn't friendly to change subtargets per function and
763   // so the default TargetMachine will have all of the options.
764   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
765   const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
766   SmallString<128> Str1;
767   raw_svector_ostream OS1(Str1);
768 
769   // Emit header before any dwarf directives are emitted below.
770   emitHeader(M, OS1, *STI);
771   OutStreamer->emitRawText(OS1.str());
772 }
773 
774 bool NVPTXAsmPrinter::doInitialization(Module &M) {
775   if (M.alias_size()) {
776     report_fatal_error("Module has aliases, which NVPTX does not support.");
777     return true; // error
778   }
779   if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors"))) {
780     report_fatal_error(
781         "Module has a nontrivial global ctor, which NVPTX does not support.");
782     return true;  // error
783   }
784   if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors"))) {
785     report_fatal_error(
786         "Module has a nontrivial global dtor, which NVPTX does not support.");
787     return true;  // error
788   }
789 
790   // We need to call the parent's one explicitly.
791   bool Result = AsmPrinter::doInitialization(M);
792 
793   GlobalsEmitted = false;
794 
795   return Result;
796 }
797 
798 void NVPTXAsmPrinter::emitGlobals(const Module &M) {
799   SmallString<128> Str2;
800   raw_svector_ostream OS2(Str2);
801 
802   emitDeclarations(M, OS2);
803 
804   // As ptxas does not support forward references of globals, we need to first
805   // sort the list of module-level globals in def-use order. We visit each
806   // global variable in order, and ensure that we emit it *after* its dependent
807   // globals. We use a little extra memory maintaining both a set and a list to
808   // have fast searches while maintaining a strict ordering.
809   SmallVector<const GlobalVariable *, 8> Globals;
810   DenseSet<const GlobalVariable *> GVVisited;
811   DenseSet<const GlobalVariable *> GVVisiting;
812 
813   // Visit each global variable, in order
814   for (const GlobalVariable &I : M.globals())
815     VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);
816 
817   assert(GVVisited.size() == M.getGlobalList().size() &&
818          "Missed a global variable");
819   assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
820 
821   // Print out module-level global variables in proper order
822   for (unsigned i = 0, e = Globals.size(); i != e; ++i)
823     printModuleLevelGV(Globals[i], OS2);
824 
825   OS2 << '\n';
826 
827   OutStreamer->emitRawText(OS2.str());
828 }
829 
830 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
831                                  const NVPTXSubtarget &STI) {
832   O << "//\n";
833   O << "// Generated by LLVM NVPTX Back-End\n";
834   O << "//\n";
835   O << "\n";
836 
837   unsigned PTXVersion = STI.getPTXVersion();
838   O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
839 
840   O << ".target ";
841   O << STI.getTargetName();
842 
843   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
844   if (NTM.getDrvInterface() == NVPTX::NVCL)
845     O << ", texmode_independent";
846 
847   bool HasFullDebugInfo = false;
848   for (DICompileUnit *CU : M.debug_compile_units()) {
849     switch(CU->getEmissionKind()) {
850     case DICompileUnit::NoDebug:
851     case DICompileUnit::DebugDirectivesOnly:
852       break;
853     case DICompileUnit::LineTablesOnly:
854     case DICompileUnit::FullDebug:
855       HasFullDebugInfo = true;
856       break;
857     }
858     if (HasFullDebugInfo)
859       break;
860   }
861   if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)
862     O << ", debug";
863 
864   O << "\n";
865 
866   O << ".address_size ";
867   if (NTM.is64Bit())
868     O << "64";
869   else
870     O << "32";
871   O << "\n";
872 
873   O << "\n";
874 }
875 
876 bool NVPTXAsmPrinter::doFinalization(Module &M) {
877   bool HasDebugInfo = MMI && MMI->hasDebugInfo();
878 
879   // If we did not emit any functions, then the global declarations have not
880   // yet been emitted.
881   if (!GlobalsEmitted) {
882     emitGlobals(M);
883     GlobalsEmitted = true;
884   }
885 
886   // call doFinalization
887   bool ret = AsmPrinter::doFinalization(M);
888 
889   clearAnnotationCache(&M);
890 
891   // Close the last emitted section
892   if (HasDebugInfo) {
893     static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer())
894         ->closeLastSection();
895     // Emit empty .debug_loc section for better support of the empty files.
896     OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}");
897   }
898 
899   // Output last DWARF .file directives, if any.
900   static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer())
901       ->outputDwarfFileDirectives();
902 
903   return ret;
904 
905   //bool Result = AsmPrinter::doFinalization(M);
906   // Instead of calling the parents doFinalization, we may
907   // clone parents doFinalization and customize here.
908   // Currently, we if NVISA out the EmitGlobals() in
909   // parent's doFinalization, which is too intrusive.
910   //
911   // Same for the doInitialization.
912   //return Result;
913 }
914 
915 // This function emits appropriate linkage directives for
916 // functions and global variables.
917 //
918 // extern function declaration            -> .extern
919 // extern function definition             -> .visible
920 // external global variable with init     -> .visible
921 // external without init                  -> .extern
922 // appending                              -> not allowed, assert.
923 // for any linkage other than
924 // internal, private, linker_private,
925 // linker_private_weak, linker_private_weak_def_auto,
926 // we emit                                -> .weak.
927 
928 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
929                                            raw_ostream &O) {
930   if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
931     if (V->hasExternalLinkage()) {
932       if (isa<GlobalVariable>(V)) {
933         const GlobalVariable *GVar = cast<GlobalVariable>(V);
934         if (GVar) {
935           if (GVar->hasInitializer())
936             O << ".visible ";
937           else
938             O << ".extern ";
939         }
940       } else if (V->isDeclaration())
941         O << ".extern ";
942       else
943         O << ".visible ";
944     } else if (V->hasAppendingLinkage()) {
945       std::string msg;
946       msg.append("Error: ");
947       msg.append("Symbol ");
948       if (V->hasName())
949         msg.append(std::string(V->getName()));
950       msg.append("has unsupported appending linkage type");
951       llvm_unreachable(msg.c_str());
952     } else if (!V->hasInternalLinkage() &&
953                !V->hasPrivateLinkage()) {
954       O << ".weak ";
955     }
956   }
957 }
958 
959 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
960                                          raw_ostream &O,
961                                          bool processDemoted) {
962   // Skip meta data
963   if (GVar->hasSection()) {
964     if (GVar->getSection() == "llvm.metadata")
965       return;
966   }
967 
968   // Skip LLVM intrinsic global variables
969   if (GVar->getName().startswith("llvm.") ||
970       GVar->getName().startswith("nvvm."))
971     return;
972 
973   const DataLayout &DL = getDataLayout();
974 
975   // GlobalVariables are always constant pointers themselves.
976   PointerType *PTy = GVar->getType();
977   Type *ETy = GVar->getValueType();
978 
979   if (GVar->hasExternalLinkage()) {
980     if (GVar->hasInitializer())
981       O << ".visible ";
982     else
983       O << ".extern ";
984   } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
985              GVar->hasAvailableExternallyLinkage() ||
986              GVar->hasCommonLinkage()) {
987     O << ".weak ";
988   }
989 
990   if (isTexture(*GVar)) {
991     O << ".global .texref " << getTextureName(*GVar) << ";\n";
992     return;
993   }
994 
995   if (isSurface(*GVar)) {
996     O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
997     return;
998   }
999 
1000   if (GVar->isDeclaration()) {
1001     // (extern) declarations, no definition or initializer
1002     // Currently the only known declaration is for an automatic __local
1003     // (.shared) promoted to global.
1004     emitPTXGlobalVariable(GVar, O);
1005     O << ";\n";
1006     return;
1007   }
1008 
1009   if (isSampler(*GVar)) {
1010     O << ".global .samplerref " << getSamplerName(*GVar);
1011 
1012     const Constant *Initializer = nullptr;
1013     if (GVar->hasInitializer())
1014       Initializer = GVar->getInitializer();
1015     const ConstantInt *CI = nullptr;
1016     if (Initializer)
1017       CI = dyn_cast<ConstantInt>(Initializer);
1018     if (CI) {
1019       unsigned sample = CI->getZExtValue();
1020 
1021       O << " = { ";
1022 
1023       for (int i = 0,
1024                addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
1025            i < 3; i++) {
1026         O << "addr_mode_" << i << " = ";
1027         switch (addr) {
1028         case 0:
1029           O << "wrap";
1030           break;
1031         case 1:
1032           O << "clamp_to_border";
1033           break;
1034         case 2:
1035           O << "clamp_to_edge";
1036           break;
1037         case 3:
1038           O << "wrap";
1039           break;
1040         case 4:
1041           O << "mirror";
1042           break;
1043         }
1044         O << ", ";
1045       }
1046       O << "filter_mode = ";
1047       switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
1048       case 0:
1049         O << "nearest";
1050         break;
1051       case 1:
1052         O << "linear";
1053         break;
1054       case 2:
1055         llvm_unreachable("Anisotropic filtering is not supported");
1056       default:
1057         O << "nearest";
1058         break;
1059       }
1060       if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
1061         O << ", force_unnormalized_coords = 1";
1062       }
1063       O << " }";
1064     }
1065 
1066     O << ";\n";
1067     return;
1068   }
1069 
1070   if (GVar->hasPrivateLinkage()) {
1071     if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
1072       return;
1073 
1074     // FIXME - need better way (e.g. Metadata) to avoid generating this global
1075     if (strncmp(GVar->getName().data(), "filename", 8) == 0)
1076       return;
1077     if (GVar->use_empty())
1078       return;
1079   }
1080 
1081   const Function *demotedFunc = nullptr;
1082   if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
1083     O << "// " << GVar->getName() << " has been demoted\n";
1084     if (localDecls.find(demotedFunc) != localDecls.end())
1085       localDecls[demotedFunc].push_back(GVar);
1086     else {
1087       std::vector<const GlobalVariable *> temp;
1088       temp.push_back(GVar);
1089       localDecls[demotedFunc] = temp;
1090     }
1091     return;
1092   }
1093 
1094   O << ".";
1095   emitPTXAddressSpace(PTy->getAddressSpace(), O);
1096 
1097   if (isManaged(*GVar)) {
1098     O << " .attribute(.managed)";
1099   }
1100 
1101   if (MaybeAlign A = GVar->getAlign())
1102     O << " .align " << A->value();
1103   else
1104     O << " .align " << (int)DL.getPrefTypeAlignment(ETy);
1105 
1106   if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
1107       (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
1108     O << " .";
1109     // Special case: ABI requires that we use .u8 for predicates
1110     if (ETy->isIntegerTy(1))
1111       O << "u8";
1112     else
1113       O << getPTXFundamentalTypeStr(ETy, false);
1114     O << " ";
1115     getSymbol(GVar)->print(O, MAI);
1116 
1117     // Ptx allows variable initilization only for constant and global state
1118     // spaces.
1119     if (GVar->hasInitializer()) {
1120       if ((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1121           (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) {
1122         const Constant *Initializer = GVar->getInitializer();
1123         // 'undef' is treated as there is no value specified.
1124         if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
1125           O << " = ";
1126           printScalarConstant(Initializer, O);
1127         }
1128       } else {
1129         // The frontend adds zero-initializer to device and constant variables
1130         // that don't have an initial value, and UndefValue to shared
1131         // variables, so skip warning for this case.
1132         if (!GVar->getInitializer()->isNullValue() &&
1133             !isa<UndefValue>(GVar->getInitializer())) {
1134           report_fatal_error("initial value of '" + GVar->getName() +
1135                              "' is not allowed in addrspace(" +
1136                              Twine(PTy->getAddressSpace()) + ")");
1137         }
1138       }
1139     }
1140   } else {
1141     unsigned int ElementSize = 0;
1142 
1143     // Although PTX has direct support for struct type and array type and
1144     // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
1145     // targets that support these high level field accesses. Structs, arrays
1146     // and vectors are lowered into arrays of bytes.
1147     switch (ETy->getTypeID()) {
1148     case Type::IntegerTyID: // Integers larger than 64 bits
1149     case Type::StructTyID:
1150     case Type::ArrayTyID:
1151     case Type::FixedVectorTyID:
1152       ElementSize = DL.getTypeStoreSize(ETy);
1153       // Ptx allows variable initilization only for constant and
1154       // global state spaces.
1155       if (((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1156            (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
1157           GVar->hasInitializer()) {
1158         const Constant *Initializer = GVar->getInitializer();
1159         if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
1160           AggBuffer aggBuffer(ElementSize, O, *this);
1161           bufferAggregateConstant(Initializer, &aggBuffer);
1162           if (aggBuffer.numSymbols) {
1163             if (static_cast<const NVPTXTargetMachine &>(TM).is64Bit()) {
1164               O << " .u64 ";
1165               getSymbol(GVar)->print(O, MAI);
1166               O << "[";
1167               O << ElementSize / 8;
1168             } else {
1169               O << " .u32 ";
1170               getSymbol(GVar)->print(O, MAI);
1171               O << "[";
1172               O << ElementSize / 4;
1173             }
1174             O << "]";
1175           } else {
1176             O << " .b8 ";
1177             getSymbol(GVar)->print(O, MAI);
1178             O << "[";
1179             O << ElementSize;
1180             O << "]";
1181           }
1182           O << " = {";
1183           aggBuffer.print();
1184           O << "}";
1185         } else {
1186           O << " .b8 ";
1187           getSymbol(GVar)->print(O, MAI);
1188           if (ElementSize) {
1189             O << "[";
1190             O << ElementSize;
1191             O << "]";
1192           }
1193         }
1194       } else {
1195         O << " .b8 ";
1196         getSymbol(GVar)->print(O, MAI);
1197         if (ElementSize) {
1198           O << "[";
1199           O << ElementSize;
1200           O << "]";
1201         }
1202       }
1203       break;
1204     default:
1205       llvm_unreachable("type not supported yet");
1206     }
1207   }
1208   O << ";\n";
1209 }
1210 
1211 void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
1212   if (localDecls.find(f) == localDecls.end())
1213     return;
1214 
1215   std::vector<const GlobalVariable *> &gvars = localDecls[f];
1216 
1217   for (const GlobalVariable *GV : gvars) {
1218     O << "\t// demoted variable\n\t";
1219     printModuleLevelGV(GV, O, true);
1220   }
1221 }
1222 
1223 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1224                                           raw_ostream &O) const {
1225   switch (AddressSpace) {
1226   case ADDRESS_SPACE_LOCAL:
1227     O << "local";
1228     break;
1229   case ADDRESS_SPACE_GLOBAL:
1230     O << "global";
1231     break;
1232   case ADDRESS_SPACE_CONST:
1233     O << "const";
1234     break;
1235   case ADDRESS_SPACE_SHARED:
1236     O << "shared";
1237     break;
1238   default:
1239     report_fatal_error("Bad address space found while emitting PTX: " +
1240                        llvm::Twine(AddressSpace));
1241     break;
1242   }
1243 }
1244 
1245 std::string
1246 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1247   switch (Ty->getTypeID()) {
1248   case Type::IntegerTyID: {
1249     unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
1250     if (NumBits == 1)
1251       return "pred";
1252     else if (NumBits <= 64) {
1253       std::string name = "u";
1254       return name + utostr(NumBits);
1255     } else {
1256       llvm_unreachable("Integer too large");
1257       break;
1258     }
1259     break;
1260   }
1261   case Type::HalfTyID:
1262     // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
1263     return "b16";
1264   case Type::FloatTyID:
1265     return "f32";
1266   case Type::DoubleTyID:
1267     return "f64";
1268   case Type::PointerTyID:
1269     if (static_cast<const NVPTXTargetMachine &>(TM).is64Bit())
1270       if (useB4PTR)
1271         return "b64";
1272       else
1273         return "u64";
1274     else if (useB4PTR)
1275       return "b32";
1276     else
1277       return "u32";
1278   default:
1279     break;
1280   }
1281   llvm_unreachable("unexpected type");
1282 }
1283 
1284 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1285                                             raw_ostream &O) {
1286   const DataLayout &DL = getDataLayout();
1287 
1288   // GlobalVariables are always constant pointers themselves.
1289   Type *ETy = GVar->getValueType();
1290 
1291   O << ".";
1292   emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
1293   if (MaybeAlign A = GVar->getAlign())
1294     O << " .align " << A->value();
1295   else
1296     O << " .align " << (int)DL.getPrefTypeAlignment(ETy);
1297 
1298   // Special case for i128
1299   if (ETy->isIntegerTy(128)) {
1300     O << " .b8 ";
1301     getSymbol(GVar)->print(O, MAI);
1302     O << "[16]";
1303     return;
1304   }
1305 
1306   if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
1307     O << " .";
1308     O << getPTXFundamentalTypeStr(ETy);
1309     O << " ";
1310     getSymbol(GVar)->print(O, MAI);
1311     return;
1312   }
1313 
1314   int64_t ElementSize = 0;
1315 
1316   // Although PTX has direct support for struct type and array type and LLVM IR
1317   // is very similar to PTX, the LLVM CodeGen does not support for targets that
1318   // support these high level field accesses. Structs and arrays are lowered
1319   // into arrays of bytes.
1320   switch (ETy->getTypeID()) {
1321   case Type::StructTyID:
1322   case Type::ArrayTyID:
1323   case Type::FixedVectorTyID:
1324     ElementSize = DL.getTypeStoreSize(ETy);
1325     O << " .b8 ";
1326     getSymbol(GVar)->print(O, MAI);
1327     O << "[";
1328     if (ElementSize) {
1329       O << ElementSize;
1330     }
1331     O << "]";
1332     break;
1333   default:
1334     llvm_unreachable("type not supported yet");
1335   }
1336 }
1337 
1338 static unsigned int getOpenCLAlignment(const DataLayout &DL, Type *Ty) {
1339   if (Ty->isSingleValueType())
1340     return DL.getPrefTypeAlignment(Ty);
1341 
1342   auto *ATy = dyn_cast<ArrayType>(Ty);
1343   if (ATy)
1344     return getOpenCLAlignment(DL, ATy->getElementType());
1345 
1346   auto *STy = dyn_cast<StructType>(Ty);
1347   if (STy) {
1348     unsigned int alignStruct = 1;
1349     // Go through each element of the struct and find the
1350     // largest alignment.
1351     for (unsigned i = 0, e = STy->getNumElements(); i != e; i++) {
1352       Type *ETy = STy->getElementType(i);
1353       unsigned int align = getOpenCLAlignment(DL, ETy);
1354       if (align > alignStruct)
1355         alignStruct = align;
1356     }
1357     return alignStruct;
1358   }
1359 
1360   auto *FTy = dyn_cast<FunctionType>(Ty);
1361   if (FTy)
1362     return DL.getPointerPrefAlignment().value();
1363   return DL.getPrefTypeAlignment(Ty);
1364 }
1365 
1366 void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I,
1367                                      int paramIndex, raw_ostream &O) {
1368   getSymbol(I->getParent())->print(O, MAI);
1369   O << "_param_" << paramIndex;
1370 }
1371 
1372 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
1373   const DataLayout &DL = getDataLayout();
1374   const AttributeList &PAL = F->getAttributes();
1375   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
1376   const TargetLowering *TLI = STI.getTargetLowering();
1377   Function::const_arg_iterator I, E;
1378   unsigned paramIndex = 0;
1379   bool first = true;
1380   bool isKernelFunc = isKernelFunction(*F);
1381   bool isABI = (STI.getSmVersion() >= 20);
1382   bool hasImageHandles = STI.hasImageHandles();
1383   MVT thePointerTy = TLI->getPointerTy(DL);
1384 
1385   if (F->arg_empty()) {
1386     O << "()\n";
1387     return;
1388   }
1389 
1390   O << "(\n";
1391 
1392   for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
1393     Type *Ty = I->getType();
1394 
1395     if (!first)
1396       O << ",\n";
1397 
1398     first = false;
1399 
1400     // Handle image/sampler parameters
1401     if (isKernelFunction(*F)) {
1402       if (isSampler(*I) || isImage(*I)) {
1403         if (isImage(*I)) {
1404           std::string sname = std::string(I->getName());
1405           if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
1406             if (hasImageHandles)
1407               O << "\t.param .u64 .ptr .surfref ";
1408             else
1409               O << "\t.param .surfref ";
1410             CurrentFnSym->print(O, MAI);
1411             O << "_param_" << paramIndex;
1412           }
1413           else { // Default image is read_only
1414             if (hasImageHandles)
1415               O << "\t.param .u64 .ptr .texref ";
1416             else
1417               O << "\t.param .texref ";
1418             CurrentFnSym->print(O, MAI);
1419             O << "_param_" << paramIndex;
1420           }
1421         } else {
1422           if (hasImageHandles)
1423             O << "\t.param .u64 .ptr .samplerref ";
1424           else
1425             O << "\t.param .samplerref ";
1426           CurrentFnSym->print(O, MAI);
1427           O << "_param_" << paramIndex;
1428         }
1429         continue;
1430       }
1431     }
1432 
1433     if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
1434       if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
1435         // Just print .param .align <a> .b8 .param[size];
1436         // <a> = PAL.getparamalignment
1437         // size = typeallocsize of element type
1438         const Align align = DL.getValueOrABITypeAlignment(
1439             PAL.getParamAlignment(paramIndex), Ty);
1440 
1441         unsigned sz = DL.getTypeAllocSize(Ty);
1442         O << "\t.param .align " << align.value() << " .b8 ";
1443         printParamName(I, paramIndex, O);
1444         O << "[" << sz << "]";
1445 
1446         continue;
1447       }
1448       // Just a scalar
1449       auto *PTy = dyn_cast<PointerType>(Ty);
1450       if (isKernelFunc) {
1451         if (PTy) {
1452           // Special handling for pointer arguments to kernel
1453           O << "\t.param .u" << thePointerTy.getSizeInBits() << " ";
1454 
1455           if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
1456               NVPTX::CUDA) {
1457             Type *ETy = PTy->getPointerElementType();
1458             int addrSpace = PTy->getAddressSpace();
1459             switch (addrSpace) {
1460             default:
1461               O << ".ptr ";
1462               break;
1463             case ADDRESS_SPACE_CONST:
1464               O << ".ptr .const ";
1465               break;
1466             case ADDRESS_SPACE_SHARED:
1467               O << ".ptr .shared ";
1468               break;
1469             case ADDRESS_SPACE_GLOBAL:
1470               O << ".ptr .global ";
1471               break;
1472             }
1473             O << ".align " << (int)getOpenCLAlignment(DL, ETy) << " ";
1474           }
1475           printParamName(I, paramIndex, O);
1476           continue;
1477         }
1478 
1479         // non-pointer scalar to kernel func
1480         O << "\t.param .";
1481         // Special case: predicate operands become .u8 types
1482         if (Ty->isIntegerTy(1))
1483           O << "u8";
1484         else
1485           O << getPTXFundamentalTypeStr(Ty);
1486         O << " ";
1487         printParamName(I, paramIndex, O);
1488         continue;
1489       }
1490       // Non-kernel function, just print .param .b<size> for ABI
1491       // and .reg .b<size> for non-ABI
1492       unsigned sz = 0;
1493       if (isa<IntegerType>(Ty)) {
1494         sz = cast<IntegerType>(Ty)->getBitWidth();
1495         if (sz < 32)
1496           sz = 32;
1497       } else if (isa<PointerType>(Ty))
1498         sz = thePointerTy.getSizeInBits();
1499       else if (Ty->isHalfTy())
1500         // PTX ABI requires all scalar parameters to be at least 32
1501         // bits in size.  fp16 normally uses .b16 as its storage type
1502         // in PTX, so its size must be adjusted here, too.
1503         sz = 32;
1504       else
1505         sz = Ty->getPrimitiveSizeInBits();
1506       if (isABI)
1507         O << "\t.param .b" << sz << " ";
1508       else
1509         O << "\t.reg .b" << sz << " ";
1510       printParamName(I, paramIndex, O);
1511       continue;
1512     }
1513 
1514     // param has byVal attribute. So should be a pointer
1515     auto *PTy = dyn_cast<PointerType>(Ty);
1516     assert(PTy && "Param with byval attribute should be a pointer type");
1517     Type *ETy = PTy->getPointerElementType();
1518 
1519     if (isABI || isKernelFunc) {
1520       // Just print .param .align <a> .b8 .param[size];
1521       // <a> = PAL.getparamalignment
1522       // size = typeallocsize of element type
1523       Align align =
1524           DL.getValueOrABITypeAlignment(PAL.getParamAlignment(paramIndex), ETy);
1525       // Work around a bug in ptxas. When PTX code takes address of
1526       // byval parameter with alignment < 4, ptxas generates code to
1527       // spill argument into memory. Alas on sm_50+ ptxas generates
1528       // SASS code that fails with misaligned access. To work around
1529       // the problem, make sure that we align byval parameters by at
1530       // least 4. Matching change must be made in LowerCall() where we
1531       // prepare parameters for the call.
1532       //
1533       // TODO: this will need to be undone when we get to support multi-TU
1534       // device-side compilation as it breaks ABI compatibility with nvcc.
1535       // Hopefully ptxas bug is fixed by then.
1536       if (!isKernelFunc && align < Align(4))
1537         align = Align(4);
1538       unsigned sz = DL.getTypeAllocSize(ETy);
1539       O << "\t.param .align " << align.value() << " .b8 ";
1540       printParamName(I, paramIndex, O);
1541       O << "[" << sz << "]";
1542       continue;
1543     } else {
1544       // Split the ETy into constituent parts and
1545       // print .param .b<size> <name> for each part.
1546       // Further, if a part is vector, print the above for
1547       // each vector element.
1548       SmallVector<EVT, 16> vtparts;
1549       ComputeValueVTs(*TLI, DL, ETy, vtparts);
1550       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
1551         unsigned elems = 1;
1552         EVT elemtype = vtparts[i];
1553         if (vtparts[i].isVector()) {
1554           elems = vtparts[i].getVectorNumElements();
1555           elemtype = vtparts[i].getVectorElementType();
1556         }
1557 
1558         for (unsigned j = 0, je = elems; j != je; ++j) {
1559           unsigned sz = elemtype.getSizeInBits();
1560           if (elemtype.isInteger() && (sz < 32))
1561             sz = 32;
1562           O << "\t.reg .b" << sz << " ";
1563           printParamName(I, paramIndex, O);
1564           if (j < je - 1)
1565             O << ",\n";
1566           ++paramIndex;
1567         }
1568         if (i < e - 1)
1569           O << ",\n";
1570       }
1571       --paramIndex;
1572       continue;
1573     }
1574   }
1575 
1576   O << "\n)\n";
1577 }
1578 
1579 void NVPTXAsmPrinter::emitFunctionParamList(const MachineFunction &MF,
1580                                             raw_ostream &O) {
1581   const Function &F = MF.getFunction();
1582   emitFunctionParamList(&F, O);
1583 }
1584 
1585 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1586     const MachineFunction &MF) {
1587   SmallString<128> Str;
1588   raw_svector_ostream O(Str);
1589 
1590   // Map the global virtual register number to a register class specific
1591   // virtual register number starting from 1 with that class.
1592   const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1593   //unsigned numRegClasses = TRI->getNumRegClasses();
1594 
1595   // Emit the Fake Stack Object
1596   const MachineFrameInfo &MFI = MF.getFrameInfo();
1597   int NumBytes = (int) MFI.getStackSize();
1598   if (NumBytes) {
1599     O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
1600       << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
1601     if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1602       O << "\t.reg .b64 \t%SP;\n";
1603       O << "\t.reg .b64 \t%SPL;\n";
1604     } else {
1605       O << "\t.reg .b32 \t%SP;\n";
1606       O << "\t.reg .b32 \t%SPL;\n";
1607     }
1608   }
1609 
1610   // Go through all virtual registers to establish the mapping between the
1611   // global virtual
1612   // register number and the per class virtual register number.
1613   // We use the per class virtual register number in the ptx output.
1614   unsigned int numVRs = MRI->getNumVirtRegs();
1615   for (unsigned i = 0; i < numVRs; i++) {
1616     Register vr = Register::index2VirtReg(i);
1617     const TargetRegisterClass *RC = MRI->getRegClass(vr);
1618     DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1619     int n = regmap.size();
1620     regmap.insert(std::make_pair(vr, n + 1));
1621   }
1622 
1623   // Emit register declarations
1624   // @TODO: Extract out the real register usage
1625   // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1626   // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1627   // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1628   // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1629   // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1630   // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1631   // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1632 
1633   // Emit declaration of the virtual registers or 'physical' registers for
1634   // each register class
1635   for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
1636     const TargetRegisterClass *RC = TRI->getRegClass(i);
1637     DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1638     std::string rcname = getNVPTXRegClassName(RC);
1639     std::string rcStr = getNVPTXRegClassStr(RC);
1640     int n = regmap.size();
1641 
1642     // Only declare those registers that may be used.
1643     if (n) {
1644        O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
1645          << ">;\n";
1646     }
1647   }
1648 
1649   OutStreamer->emitRawText(O.str());
1650 }
1651 
1652 void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
1653   APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1654   bool ignored;
1655   unsigned int numHex;
1656   const char *lead;
1657 
1658   if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1659     numHex = 8;
1660     lead = "0f";
1661     APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
1662   } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1663     numHex = 16;
1664     lead = "0d";
1665     APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
1666   } else
1667     llvm_unreachable("unsupported fp type");
1668 
1669   APInt API = APF.bitcastToAPInt();
1670   O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
1671 }
1672 
1673 void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1674   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1675     O << CI->getValue();
1676     return;
1677   }
1678   if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1679     printFPConstant(CFP, O);
1680     return;
1681   }
1682   if (isa<ConstantPointerNull>(CPV)) {
1683     O << "0";
1684     return;
1685   }
1686   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1687     bool IsNonGenericPointer = false;
1688     if (GVar->getType()->getAddressSpace() != 0) {
1689       IsNonGenericPointer = true;
1690     }
1691     if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
1692       O << "generic(";
1693       getSymbol(GVar)->print(O, MAI);
1694       O << ")";
1695     } else {
1696       getSymbol(GVar)->print(O, MAI);
1697     }
1698     return;
1699   }
1700   if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1701     const Value *v = Cexpr->stripPointerCasts();
1702     PointerType *PTy = dyn_cast<PointerType>(Cexpr->getType());
1703     bool IsNonGenericPointer = false;
1704     if (PTy && PTy->getAddressSpace() != 0) {
1705       IsNonGenericPointer = true;
1706     }
1707     if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1708       if (EmitGeneric && !isa<Function>(v) && !IsNonGenericPointer) {
1709         O << "generic(";
1710         getSymbol(GVar)->print(O, MAI);
1711         O << ")";
1712       } else {
1713         getSymbol(GVar)->print(O, MAI);
1714       }
1715       return;
1716     } else {
1717       lowerConstant(CPV)->print(O, MAI);
1718       return;
1719     }
1720   }
1721   llvm_unreachable("Not scalar type found in printScalarConstant()");
1722 }
1723 
// Appends the little-endian byte image of constant CPV to AggBuffer.
//
// Bytes: when non-zero, the amount of space this element should occupy
//   (zero/undef constants are zero-filled to exactly Bytes; aggregates are
//   padded with zeros up to Bytes). When zero, the DataLayout alloc size of
//   CPV's type is used for zero/undef constants.
//   NOTE(review): how AggBuffer::addBytes uses its third argument (padding vs.
//   exact size) is defined outside this chunk — confirm there.
//
// Pointer-valued entries (globals, pointer constant expressions, ptrtoint of
// a global) are recorded with AggBuffer->addSymbol, followed by AllocSize
// zero bytes as the placeholder for the address.
void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
                                   AggBuffer *AggBuffer) {
  const DataLayout &DL = getDataLayout();
  int AllocSize = DL.getTypeAllocSize(CPV->getType());
  if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
    // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
    // only the space allocated by CPV.
    AggBuffer->addZeros(Bytes ? Bytes : AllocSize);
    return;
  }

  // Helper for filling AggBuffer with APInts.
  // Splits Val into bytes, least-significant byte first (little-endian).
  auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
    size_t NumBytes = (Val.getBitWidth() + 7) / 8;
    SmallVector<unsigned char, 16> Buf(NumBytes);
    for (unsigned I = 0; I < NumBytes; ++I) {
      Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);
    }
    AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);
  };

  switch (CPV->getType()->getTypeID()) {
  case Type::IntegerTyID:
    if (const auto CI = dyn_cast<ConstantInt>(CPV)) {
      AddIntToBuffer(CI->getValue());
      break;
    }
    if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
      // First try to fold the expression to a plain integer.
      if (const auto *CI =
              dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {
        AddIntToBuffer(CI->getValue());
        break;
      }
      // Otherwise a ptrtoint is emitted as a symbol reference plus a
      // zero-filled placeholder of the integer's alloc size.
      if (Cexpr->getOpcode() == Instruction::PtrToInt) {
        Value *V = Cexpr->getOperand(0)->stripPointerCasts();
        AggBuffer->addSymbol(V, Cexpr->getOperand(0));
        AggBuffer->addZeros(AllocSize);
        break;
      }
    }
    llvm_unreachable("unsupported integer const type");
    break;

  case Type::HalfTyID:
  case Type::FloatTyID:
  case Type::DoubleTyID:
    // Floating-point values are stored as their IEEE bit pattern.
    AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
    break;

  case Type::PointerTyID: {
    // Record the referenced symbol (if any), then reserve pointer-sized space.
    if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
      AggBuffer->addSymbol(GVar, GVar);
    } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
      const Value *v = Cexpr->stripPointerCasts();
      AggBuffer->addSymbol(v, Cexpr);
    }
    AggBuffer->addZeros(AllocSize);
    break;
  }

  case Type::ArrayTyID:
  case Type::FixedVectorTyID:
  case Type::StructTyID: {
    // Nested aggregates recurse through bufferAggregateConstant, then pad up
    // to Bytes if the caller requested more space than the aggregate uses.
    if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
      bufferAggregateConstant(CPV, AggBuffer);
      if (Bytes > AllocSize)
        AggBuffer->addZeros(Bytes - AllocSize);
    } else if (isa<ConstantAggregateZero>(CPV))
      AggBuffer->addZeros(Bytes);
    else
      llvm_unreachable("Unexpected Constant type");
    break;
  }

  default:
    llvm_unreachable("unsupported type");
  }
}
1802 
1803 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
1804                                               AggBuffer *aggBuffer) {
1805   const DataLayout &DL = getDataLayout();
1806   int Bytes;
1807 
1808   // Integers of arbitrary width
1809   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1810     APInt Val = CI->getValue();
1811     for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
1812       uint8_t Byte = Val.getLoBits(8).getZExtValue();
1813       aggBuffer->addBytes(&Byte, 1, 1);
1814       Val.lshrInPlace(8);
1815     }
1816     return;
1817   }
1818 
1819   // Old constants
1820   if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
1821     if (CPV->getNumOperands())
1822       for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
1823         bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
1824     return;
1825   }
1826 
1827   if (const ConstantDataSequential *CDS =
1828           dyn_cast<ConstantDataSequential>(CPV)) {
1829     if (CDS->getNumElements())
1830       for (unsigned i = 0; i < CDS->getNumElements(); ++i)
1831         bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
1832                      aggBuffer);
1833     return;
1834   }
1835 
1836   if (isa<ConstantStruct>(CPV)) {
1837     if (CPV->getNumOperands()) {
1838       StructType *ST = cast<StructType>(CPV->getType());
1839       for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
1840         if (i == (e - 1))
1841           Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
1842                   DL.getTypeAllocSize(ST) -
1843                   DL.getStructLayout(ST)->getElementOffset(i);
1844         else
1845           Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
1846                   DL.getStructLayout(ST)->getElementOffset(i);
1847         bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
1848       }
1849     }
1850     return;
1851   }
1852   llvm_unreachable("unsupported constant type in printAggregateConstant()");
1853 }
1854 
1855 /// lowerConstantForGV - Return an MCExpr for the given Constant.  This is mostly
1856 /// a copy from AsmPrinter::lowerConstant, except customized to only handle
1857 /// expressions that are representable in PTX and create
1858 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
1859 const MCExpr *
1860 NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
1861   MCContext &Ctx = OutContext;
1862 
1863   if (CV->isNullValue() || isa<UndefValue>(CV))
1864     return MCConstantExpr::create(0, Ctx);
1865 
1866   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
1867     return MCConstantExpr::create(CI->getZExtValue(), Ctx);
1868 
1869   if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
1870     const MCSymbolRefExpr *Expr =
1871       MCSymbolRefExpr::create(getSymbol(GV), Ctx);
1872     if (ProcessingGeneric) {
1873       return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
1874     } else {
1875       return Expr;
1876     }
1877   }
1878 
1879   const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
1880   if (!CE) {
1881     llvm_unreachable("Unknown constant value to lower!");
1882   }
1883 
1884   switch (CE->getOpcode()) {
1885   default: {
1886     // If the code isn't optimized, there may be outstanding folding
1887     // opportunities. Attempt to fold the expression using DataLayout as a
1888     // last resort before giving up.
1889     Constant *C = ConstantFoldConstant(CE, getDataLayout());
1890     if (C != CE)
1891       return lowerConstantForGV(C, ProcessingGeneric);
1892 
1893     // Otherwise report the problem to the user.
1894     std::string S;
1895     raw_string_ostream OS(S);
1896     OS << "Unsupported expression in static initializer: ";
1897     CE->printAsOperand(OS, /*PrintType=*/false,
1898                    !MF ? nullptr : MF->getFunction().getParent());
1899     report_fatal_error(Twine(OS.str()));
1900   }
1901 
1902   case Instruction::AddrSpaceCast: {
1903     // Strip the addrspacecast and pass along the operand
1904     PointerType *DstTy = cast<PointerType>(CE->getType());
1905     if (DstTy->getAddressSpace() == 0) {
1906       return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
1907     }
1908     std::string S;
1909     raw_string_ostream OS(S);
1910     OS << "Unsupported expression in static initializer: ";
1911     CE->printAsOperand(OS, /*PrintType=*/ false,
1912                        !MF ? nullptr : MF->getFunction().getParent());
1913     report_fatal_error(Twine(OS.str()));
1914   }
1915 
1916   case Instruction::GetElementPtr: {
1917     const DataLayout &DL = getDataLayout();
1918 
1919     // Generate a symbolic expression for the byte address
1920     APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
1921     cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);
1922 
1923     const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
1924                                             ProcessingGeneric);
1925     if (!OffsetAI)
1926       return Base;
1927 
1928     int64_t Offset = OffsetAI.getSExtValue();
1929     return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
1930                                    Ctx);
1931   }
1932 
1933   case Instruction::Trunc:
1934     // We emit the value and depend on the assembler to truncate the generated
1935     // expression properly.  This is important for differences between
1936     // blockaddress labels.  Since the two labels are in the same function, it
1937     // is reasonable to treat their delta as a 32-bit value.
1938     LLVM_FALLTHROUGH;
1939   case Instruction::BitCast:
1940     return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
1941 
1942   case Instruction::IntToPtr: {
1943     const DataLayout &DL = getDataLayout();
1944 
1945     // Handle casts to pointers by changing them into casts to the appropriate
1946     // integer type.  This promotes constant folding and simplifies this code.
1947     Constant *Op = CE->getOperand(0);
1948     Op = ConstantExpr::getIntegerCast(Op, DL.getIntPtrType(CV->getType()),
1949                                       false/*ZExt*/);
1950     return lowerConstantForGV(Op, ProcessingGeneric);
1951   }
1952 
1953   case Instruction::PtrToInt: {
1954     const DataLayout &DL = getDataLayout();
1955 
1956     // Support only foldable casts to/from pointers that can be eliminated by
1957     // changing the pointer to the appropriately sized integer type.
1958     Constant *Op = CE->getOperand(0);
1959     Type *Ty = CE->getType();
1960 
1961     const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);
1962 
1963     // We can emit the pointer value into this slot if the slot is an
1964     // integer slot equal to the size of the pointer.
1965     if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
1966       return OpExpr;
1967 
1968     // Otherwise the pointer is smaller than the resultant integer, mask off
1969     // the high bits so we are sure to get a proper truncation if the input is
1970     // a constant expr.
1971     unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
1972     const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
1973     return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
1974   }
1975 
1976   // The MC library also has a right-shift operator, but it isn't consistently
1977   // signed or unsigned between different targets.
1978   case Instruction::Add: {
1979     const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
1980     const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
1981     switch (CE->getOpcode()) {
1982     default: llvm_unreachable("Unknown binary operator constant cast expr");
1983     case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
1984     }
1985   }
1986   }
1987 }
1988 
1989 // Copy of MCExpr::print customized for NVPTX
1990 void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
1991   switch (Expr.getKind()) {
1992   case MCExpr::Target:
1993     return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
1994   case MCExpr::Constant:
1995     OS << cast<MCConstantExpr>(Expr).getValue();
1996     return;
1997 
1998   case MCExpr::SymbolRef: {
1999     const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
2000     const MCSymbol &Sym = SRE.getSymbol();
2001     Sym.print(OS, MAI);
2002     return;
2003   }
2004 
2005   case MCExpr::Unary: {
2006     const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
2007     switch (UE.getOpcode()) {
2008     case MCUnaryExpr::LNot:  OS << '!'; break;
2009     case MCUnaryExpr::Minus: OS << '-'; break;
2010     case MCUnaryExpr::Not:   OS << '~'; break;
2011     case MCUnaryExpr::Plus:  OS << '+'; break;
2012     }
2013     printMCExpr(*UE.getSubExpr(), OS);
2014     return;
2015   }
2016 
2017   case MCExpr::Binary: {
2018     const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);
2019 
2020     // Only print parens around the LHS if it is non-trivial.
2021     if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
2022         isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
2023       printMCExpr(*BE.getLHS(), OS);
2024     } else {
2025       OS << '(';
2026       printMCExpr(*BE.getLHS(), OS);
2027       OS<< ')';
2028     }
2029 
2030     switch (BE.getOpcode()) {
2031     case MCBinaryExpr::Add:
2032       // Print "X-42" instead of "X+-42".
2033       if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
2034         if (RHSC->getValue() < 0) {
2035           OS << RHSC->getValue();
2036           return;
2037         }
2038       }
2039 
2040       OS <<  '+';
2041       break;
2042     default: llvm_unreachable("Unhandled binary operator");
2043     }
2044 
2045     // Only print parens around the LHS if it is non-trivial.
2046     if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
2047       printMCExpr(*BE.getRHS(), OS);
2048     } else {
2049       OS << '(';
2050       printMCExpr(*BE.getRHS(), OS);
2051       OS << ')';
2052     }
2053     return;
2054   }
2055   }
2056 
2057   llvm_unreachable("Invalid expression kind!");
2058 }
2059 
2060 /// PrintAsmOperand - Print out an operand for an inline asm expression.
2061 ///
2062 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2063                                       const char *ExtraCode, raw_ostream &O) {
2064   if (ExtraCode && ExtraCode[0]) {
2065     if (ExtraCode[1] != 0)
2066       return true; // Unknown modifier.
2067 
2068     switch (ExtraCode[0]) {
2069     default:
2070       // See if this is a generic print operand
2071       return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
2072     case 'r':
2073       break;
2074     }
2075   }
2076 
2077   printOperand(MI, OpNo, O);
2078 
2079   return false;
2080 }
2081 
2082 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2083                                             unsigned OpNo,
2084                                             const char *ExtraCode,
2085                                             raw_ostream &O) {
2086   if (ExtraCode && ExtraCode[0])
2087     return true; // Unknown modifier
2088 
2089   O << '[';
2090   printMemOperand(MI, OpNo, O);
2091   O << ']';
2092 
2093   return false;
2094 }
2095 
2096 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
2097                                    raw_ostream &O) {
2098   const MachineOperand &MO = MI->getOperand(opNum);
2099   switch (MO.getType()) {
2100   case MachineOperand::MO_Register:
2101     if (Register::isPhysicalRegister(MO.getReg())) {
2102       if (MO.getReg() == NVPTX::VRDepot)
2103         O << DEPOTNAME << getFunctionNumber();
2104       else
2105         O << NVPTXInstPrinter::getRegisterName(MO.getReg());
2106     } else {
2107       emitVirtualRegister(MO.getReg(), O);
2108     }
2109     break;
2110 
2111   case MachineOperand::MO_Immediate:
2112     O << MO.getImm();
2113     break;
2114 
2115   case MachineOperand::MO_FPImmediate:
2116     printFPConstant(MO.getFPImm(), O);
2117     break;
2118 
2119   case MachineOperand::MO_GlobalAddress:
2120     PrintSymbolOperand(MO, O);
2121     break;
2122 
2123   case MachineOperand::MO_MachineBasicBlock:
2124     MO.getMBB()->getSymbol()->print(O, MAI);
2125     break;
2126 
2127   default:
2128     llvm_unreachable("Operand type not supported.");
2129   }
2130 }
2131 
2132 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
2133                                       raw_ostream &O, const char *Modifier) {
2134   printOperand(MI, opNum, O);
2135 
2136   if (Modifier && strcmp(Modifier, "add") == 0) {
2137     O << ", ";
2138     printOperand(MI, opNum + 1, O);
2139   } else {
2140     if (MI->getOperand(opNum + 1).isImm() &&
2141         MI->getOperand(opNum + 1).getImm() == 0)
2142       return; // don't print ',0' or '+0'
2143     O << "+";
2144     printOperand(MI, opNum + 1, O);
2145   }
2146 }
2147 
2148 // Force static initialization.
2149 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
2150   RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2151   RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
2152 }
2153