1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "NVPTXAsmPrinter.h"
15 #include "MCTargetDesc/NVPTXBaseInfo.h"
16 #include "MCTargetDesc/NVPTXInstPrinter.h"
17 #include "MCTargetDesc/NVPTXMCAsmInfo.h"
18 #include "MCTargetDesc/NVPTXTargetStreamer.h"
19 #include "NVPTX.h"
20 #include "NVPTXMCExpr.h"
21 #include "NVPTXMachineFunctionInfo.h"
22 #include "NVPTXRegisterInfo.h"
23 #include "NVPTXSubtarget.h"
24 #include "NVPTXTargetMachine.h"
25 #include "NVPTXUtilities.h"
26 #include "TargetInfo/NVPTXTargetInfo.h"
27 #include "cl_common_defines.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/Twine.h"
37 #include "llvm/Analysis/ConstantFolding.h"
38 #include "llvm/CodeGen/Analysis.h"
39 #include "llvm/CodeGen/MachineBasicBlock.h"
40 #include "llvm/CodeGen/MachineFrameInfo.h"
41 #include "llvm/CodeGen/MachineFunction.h"
42 #include "llvm/CodeGen/MachineInstr.h"
43 #include "llvm/CodeGen/MachineLoopInfo.h"
44 #include "llvm/CodeGen/MachineModuleInfo.h"
45 #include "llvm/CodeGen/MachineOperand.h"
46 #include "llvm/CodeGen/MachineRegisterInfo.h"
47 #include "llvm/CodeGen/MachineValueType.h"
48 #include "llvm/CodeGen/TargetRegisterInfo.h"
49 #include "llvm/CodeGen/ValueTypes.h"
50 #include "llvm/IR/Attributes.h"
51 #include "llvm/IR/BasicBlock.h"
52 #include "llvm/IR/Constant.h"
53 #include "llvm/IR/Constants.h"
54 #include "llvm/IR/DataLayout.h"
55 #include "llvm/IR/DebugInfo.h"
56 #include "llvm/IR/DebugInfoMetadata.h"
57 #include "llvm/IR/DebugLoc.h"
58 #include "llvm/IR/DerivedTypes.h"
59 #include "llvm/IR/Function.h"
60 #include "llvm/IR/GlobalValue.h"
61 #include "llvm/IR/GlobalVariable.h"
62 #include "llvm/IR/Instruction.h"
63 #include "llvm/IR/LLVMContext.h"
64 #include "llvm/IR/Module.h"
65 #include "llvm/IR/Operator.h"
66 #include "llvm/IR/Type.h"
67 #include "llvm/IR/User.h"
68 #include "llvm/MC/MCExpr.h"
69 #include "llvm/MC/MCInst.h"
70 #include "llvm/MC/MCInstrDesc.h"
71 #include "llvm/MC/MCStreamer.h"
72 #include "llvm/MC/MCSymbol.h"
73 #include "llvm/MC/TargetRegistry.h"
74 #include "llvm/Support/Casting.h"
75 #include "llvm/Support/CommandLine.h"
76 #include "llvm/Support/Endian.h"
77 #include "llvm/Support/ErrorHandling.h"
78 #include "llvm/Support/NativeFormatting.h"
79 #include "llvm/Support/Path.h"
80 #include "llvm/Support/raw_ostream.h"
81 #include "llvm/Target/TargetLoweringObjectFile.h"
82 #include "llvm/Target/TargetMachine.h"
83 #include "llvm/TargetParser/Triple.h"
84 #include "llvm/Transforms/Utils/UnrollLoop.h"
85 #include <cassert>
86 #include <cstdint>
87 #include <cstring>
88 #include <new>
89 #include <string>
90 #include <utility>
91 #include <vector>
92
93 using namespace llvm;
94
// Command-line escape hatch: when set, GPU global ctors/dtors are lowered to
// device-side globals instead of being rejected in doInitialization().
static cl::opt<bool>
    LowerCtorDtor("nvptx-lower-global-ctor-dtor",
                  cl::desc("Lower GPU ctor / dtors to globals on the device."),
                  cl::init(false), cl::Hidden);

// Name prefix for the per-function local stack ("depot") symbol; see
// getFunctionFrameSymbol().
#define DEPOTNAME "__local_depot"
101
102 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
103 /// depends.
104 static void
DiscoverDependentGlobals(const Value * V,DenseSet<const GlobalVariable * > & Globals)105 DiscoverDependentGlobals(const Value *V,
106 DenseSet<const GlobalVariable *> &Globals) {
107 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
108 Globals.insert(GV);
109 else {
110 if (const User *U = dyn_cast<User>(V)) {
111 for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
112 DiscoverDependentGlobals(U->getOperand(i), Globals);
113 }
114 }
115 }
116 }
117
118 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
119 /// instances to be emitted, but only after any dependents have been added
/// first.
121 static void
VisitGlobalVariableForEmission(const GlobalVariable * GV,SmallVectorImpl<const GlobalVariable * > & Order,DenseSet<const GlobalVariable * > & Visited,DenseSet<const GlobalVariable * > & Visiting)122 VisitGlobalVariableForEmission(const GlobalVariable *GV,
123 SmallVectorImpl<const GlobalVariable *> &Order,
124 DenseSet<const GlobalVariable *> &Visited,
125 DenseSet<const GlobalVariable *> &Visiting) {
126 // Have we already visited this one?
127 if (Visited.count(GV))
128 return;
129
130 // Do we have a circular dependency?
131 if (!Visiting.insert(GV).second)
132 report_fatal_error("Circular dependency found in global variable set");
133
134 // Make sure we visit all dependents first
135 DenseSet<const GlobalVariable *> Others;
136 for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
137 DiscoverDependentGlobals(GV->getOperand(i), Others);
138
139 for (const GlobalVariable *GV : Others)
140 VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
141
142 // Now we can visit ourself
143 Order.push_back(GV);
144 Visited.insert(GV);
145 Visiting.erase(GV);
146 }
147
/// Lower a single MachineInstr to an MCInst and hand it to the streamer.
void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
  // Check that the current subtarget's feature bits actually allow this
  // opcode before emitting it.
  NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(),
                                        getSubtargetInfo().getFeatureBits());

  MCInst Inst;
  lowerToMCInst(MI, Inst);
  EmitToStreamer(*OutStreamer, Inst);
}
156
157 // Handle symbol backtracking for targets that do not support image handles
lowerImageHandleOperand(const MachineInstr * MI,unsigned OpNo,MCOperand & MCOp)158 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
159 unsigned OpNo, MCOperand &MCOp) {
160 const MachineOperand &MO = MI->getOperand(OpNo);
161 const MCInstrDesc &MCID = MI->getDesc();
162
163 if (MCID.TSFlags & NVPTXII::IsTexFlag) {
164 // This is a texture fetch, so operand 4 is a texref and operand 5 is
165 // a samplerref
166 if (OpNo == 4 && MO.isImm()) {
167 lowerImageHandleSymbol(MO.getImm(), MCOp);
168 return true;
169 }
170 if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
171 lowerImageHandleSymbol(MO.getImm(), MCOp);
172 return true;
173 }
174
175 return false;
176 } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
177 unsigned VecSize =
178 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
179
180 // For a surface load of vector size N, the Nth operand will be the surfref
181 if (OpNo == VecSize && MO.isImm()) {
182 lowerImageHandleSymbol(MO.getImm(), MCOp);
183 return true;
184 }
185
186 return false;
187 } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
188 // This is a surface store, so operand 0 is a surfref
189 if (OpNo == 0 && MO.isImm()) {
190 lowerImageHandleSymbol(MO.getImm(), MCOp);
191 return true;
192 }
193
194 return false;
195 } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
196 // This is a query, so operand 1 is a surfref/texref
197 if (OpNo == 1 && MO.isImm()) {
198 lowerImageHandleSymbol(MO.getImm(), MCOp);
199 return true;
200 }
201
202 return false;
203 }
204
205 return false;
206 }
207
lowerImageHandleSymbol(unsigned Index,MCOperand & MCOp)208 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
209 // Ewwww
210 LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget());
211 NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM);
212 const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
213 const char *Sym = MFI->getImageHandleSymbol(Index);
214 StringRef SymName = nvTM.getStrPool().save(Sym);
215 MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
216 }
217
lowerToMCInst(const MachineInstr * MI,MCInst & OutMI)218 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
219 OutMI.setOpcode(MI->getOpcode());
220 // Special: Do not mangle symbol operand of CALL_PROTOTYPE
221 if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
222 const MachineOperand &MO = MI->getOperand(0);
223 OutMI.addOperand(GetSymbolRef(
224 OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
225 return;
226 }
227
228 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
229 for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
230 const MachineOperand &MO = MI->getOperand(i);
231
232 MCOperand MCOp;
233 if (!STI.hasImageHandles()) {
234 if (lowerImageHandleOperand(MI, i, MCOp)) {
235 OutMI.addOperand(MCOp);
236 continue;
237 }
238 }
239
240 if (lowerOperand(MO, MCOp))
241 OutMI.addOperand(MCOp);
242 }
243 }
244
/// Convert a MachineOperand into an MCOperand. Returns true when \p MCOp was
/// populated; unknown operand kinds abort via llvm_unreachable.
bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
                                   MCOperand &MCOp) {
  switch (MO.getType()) {
  default: llvm_unreachable("unknown operand type");
  case MachineOperand::MO_Register:
    // Virtual registers are re-encoded with their class in the top nibble;
    // see encodeVirtualRegister().
    MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
    break;
  case MachineOperand::MO_Immediate:
    MCOp = MCOperand::createImm(MO.getImm());
    break;
  case MachineOperand::MO_MachineBasicBlock:
    // Branch targets become symbol references to the block's label.
    MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
        MO.getMBB()->getSymbol(), OutContext));
    break;
  case MachineOperand::MO_ExternalSymbol:
    MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
    break;
  case MachineOperand::MO_GlobalAddress:
    MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
    break;
  case MachineOperand::MO_FPImmediate: {
    const ConstantFP *Cnt = MO.getFPImm();
    const APFloat &Val = Cnt->getValueAPF();

    // FP immediates are wrapped in NVPTX-specific MC expressions that print
    // the constant in the form PTX expects for each width.
    switch (Cnt->getType()->getTypeID()) {
    default: report_fatal_error("Unsupported FP type"); break;
    case Type::HalfTyID:
      MCOp = MCOperand::createExpr(
          NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
      break;
    case Type::BFloatTyID:
      MCOp = MCOperand::createExpr(
          NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
      break;
    case Type::FloatTyID:
      MCOp = MCOperand::createExpr(
          NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
      break;
    case Type::DoubleTyID:
      MCOp = MCOperand::createExpr(
          NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
      break;
    }
    break;
  }
  }
  return true;
}
293
encodeVirtualRegister(unsigned Reg)294 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
295 if (Register::isVirtualRegister(Reg)) {
296 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
297
298 DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
299 unsigned RegNum = RegMap[Reg];
300
301 // Encode the register class in the upper 4 bits
302 // Must be kept in sync with NVPTXInstPrinter::printRegName
303 unsigned Ret = 0;
304 if (RC == &NVPTX::Int1RegsRegClass) {
305 Ret = (1 << 28);
306 } else if (RC == &NVPTX::Int16RegsRegClass) {
307 Ret = (2 << 28);
308 } else if (RC == &NVPTX::Int32RegsRegClass) {
309 Ret = (3 << 28);
310 } else if (RC == &NVPTX::Int64RegsRegClass) {
311 Ret = (4 << 28);
312 } else if (RC == &NVPTX::Float32RegsRegClass) {
313 Ret = (5 << 28);
314 } else if (RC == &NVPTX::Float64RegsRegClass) {
315 Ret = (6 << 28);
316 } else {
317 report_fatal_error("Bad register class");
318 }
319
320 // Insert the vreg number
321 Ret |= (RegNum & 0x0FFFFFFF);
322 return Ret;
323 } else {
324 // Some special-use registers are actually physical registers.
325 // Encode this as the register class ID of 0 and the real register ID.
326 return Reg & 0x0FFFFFFF;
327 }
328 }
329
GetSymbolRef(const MCSymbol * Symbol)330 MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
331 const MCExpr *Expr;
332 Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
333 OutContext);
334 return MCOperand::createExpr(Expr);
335 }
336
ShouldPassAsArray(Type * Ty)337 static bool ShouldPassAsArray(Type *Ty) {
338 return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
339 Ty->isHalfTy() || Ty->isBFloatTy();
340 }
341
// Print the PTX return-value declaration for \p F to \p O, e.g.
// " (.param .b32 func_retval0) ". Emits nothing for void returns.
void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
  const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());

  Type *Ty = F->getReturnType();

  // sm_20+ uses the .param ABI; older targets return values in .reg's.
  bool isABI = (STI.getSmVersion() >= 20);

  if (Ty->getTypeID() == Type::VoidTyID)
    return;
  O << " (";

  if (isABI) {
    if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
        !ShouldPassAsArray(Ty)) {
      // Scalar return: small scalars are promoted to the minimum PTX
      // parameter width by promoteScalarArgumentSize().
      unsigned size = 0;
      if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
        size = ITy->getBitWidth();
      } else {
        assert(Ty->isFloatingPointTy() && "Floating point type expected here");
        size = Ty->getPrimitiveSizeInBits();
      }
      size = promoteScalarArgumentSize(size);
      O << ".param .b" << size << " func_retval0";
    } else if (isa<PointerType>(Ty)) {
      // Pointers are returned as an integer of pointer width.
      O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
        << " func_retval0";
    } else if (ShouldPassAsArray(Ty)) {
      // Aggregate-like returns become a byte array with explicit alignment:
      // either the alignment recorded in the IR, or the target's optimized
      // default for this type.
      unsigned totalsz = DL.getTypeAllocSize(Ty);
      unsigned retAlignment = 0;
      if (!getAlign(*F, 0, retAlignment))
        retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value();
      O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
        << "]";
    } else
      llvm_unreachable("Unknown return type");
  } else {
    // Pre-ABI path: expand the return type into its legal value types and
    // emit one .reg per scalar element, comma-separated.
    SmallVector<EVT, 16> vtparts;
    ComputeValueVTs(*TLI, DL, Ty, vtparts);
    unsigned idx = 0;
    for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
      unsigned elems = 1;
      EVT elemtype = vtparts[i];
      if (vtparts[i].isVector()) {
        elems = vtparts[i].getVectorNumElements();
        elemtype = vtparts[i].getVectorElementType();
      }

      for (unsigned j = 0, je = elems; j != je; ++j) {
        unsigned sz = elemtype.getSizeInBits();
        if (elemtype.isInteger())
          sz = promoteScalarArgumentSize(sz);
        O << ".reg .b" << sz << " func_retval" << idx;
        if (j < je - 1)
          O << ", ";
        ++idx;
      }
      if (i < e - 1)
        O << ", ";
    }
  }
  O << ") ";
}
406
printReturnValStr(const MachineFunction & MF,raw_ostream & O)407 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
408 raw_ostream &O) {
409 const Function &F = MF.getFunction();
410 printReturnValStr(&F, O);
411 }
412
// Return true if MBB is the header of a loop marked with
// llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
    const MachineBasicBlock &MBB) const {
  MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>();
  // We insert .pragma "nounroll" only to the loop header.
  if (!LI.isLoopHeader(&MBB))
    return false;

  // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
  // we iterate through each back edge of the loop with header MBB, and check
  // whether its metadata contains llvm.loop.unroll.disable.
  for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
    if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
      // Edges from other loops to MBB are not back edges.
      continue;
    }
    // Inspect the IR terminator of the predecessor's source basic block, if
    // there is one, for loop metadata.
    if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
      if (MDNode *LoopID =
              PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
        if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
          return true;
        // An unroll count of exactly 1 is equivalent to "do not unroll".
        if (MDNode *UnrollCountMD =
                GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
          if (mdconst::extract<ConstantInt>(UnrollCountMD->getOperand(1))
                  ->isOne())
            return true;
        }
      }
    }
  }
  return false;
}
446
emitBasicBlockStart(const MachineBasicBlock & MBB)447 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
448 AsmPrinter::emitBasicBlockStart(MBB);
449 if (isLoopHeaderOfNoUnroll(MBB))
450 OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
451 }
452
// Print the PTX function signature (".entry" for kernels, ".func" plus
// return-value string otherwise), kernel directives, and the opening brace
// of the function body. Also triggers lazy emission of module globals the
// first time any function is emitted.
void NVPTXAsmPrinter::emitFunctionEntryLabel() {
  SmallString<128> Str;
  raw_svector_ostream O(Str);

  // Globals must precede any function in the PTX output; emit them once,
  // just before the first function.
  if (!GlobalsEmitted) {
    emitGlobals(*MF->getFunction().getParent());
    GlobalsEmitted = true;
  }

  // Set up per-function state used by operand lowering.
  MRI = &MF->getRegInfo();
  F = &MF->getFunction();
  emitLinkageDirective(F, O);
  if (isKernelFunction(*F))
    O << ".entry ";
  else {
    O << ".func ";
    printReturnValStr(*MF, O);
  }

  CurrentFnSym->print(O, MAI);

  emitFunctionParamList(F, O);
  O << "\n";

  // Kernel-only launch-bound directives (.reqntid, .maxntid, ...).
  if (isKernelFunction(*F))
    emitKernelFunctionDirectives(*F, O);

  if (shouldEmitPTXNoReturn(F, TM))
    O << ".noreturn";

  OutStreamer->emitRawText(O.str());

  VRegMapping.clear();
  // Emit open brace for function body.
  OutStreamer->emitRawText(StringRef("{\n"));
  setAndEmitFunctionVirtualRegisters(*MF);
  // Emit initial .loc debug directive for correct relocation symbol data.
  if (MMI && MMI->hasDebugInfo())
    emitInitialRawDwarfLocDirective(*MF);
}
494
runOnMachineFunction(MachineFunction & F)495 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
496 bool Result = AsmPrinter::runOnMachineFunction(F);
497 // Emit closing brace for the body of function F.
498 // The closing brace must be emitted here because we need to emit additional
499 // debug labels/data after the last basic block.
500 // We need to emit the closing brace here because we don't have function that
501 // finished emission of the function body.
502 OutStreamer->emitRawText(StringRef("}\n"));
503 return Result;
504 }
505
emitFunctionBodyStart()506 void NVPTXAsmPrinter::emitFunctionBodyStart() {
507 SmallString<128> Str;
508 raw_svector_ostream O(Str);
509 emitDemotedVars(&MF->getFunction(), O);
510 OutStreamer->emitRawText(O.str());
511 }
512
// Drop the per-function virtual-register mapping when the body is done.
void NVPTXAsmPrinter::emitFunctionBodyEnd() {
  VRegMapping.clear();
}
516
getFunctionFrameSymbol() const517 const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
518 SmallString<128> Str;
519 raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
520 return OutContext.getOrCreateSymbol(Str);
521 }
522
// Emit an "implicit-def: <reg>" assembly comment for IMPLICIT_DEF
// instructions, naming either the mapped virtual register or the physical
// register as appropriate.
void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
  Register RegNo = MI->getOperand(0).getReg();
  if (RegNo.isVirtual()) {
    OutStreamer->AddComment(Twine("implicit-def: ") +
                            getVirtualRegisterName(RegNo));
  } else {
    const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
    OutStreamer->AddComment(Twine("implicit-def: ") +
                            STI.getRegisterInfo()->getName(RegNo));
  }
  OutStreamer->addBlankLine();
}
535
// Print kernel launch-bound directives (.reqntid, .maxntid, .minnctapersm,
// .maxnreg, .maxclusterrank) derived from the function's NVVM annotations.
void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
                                                   raw_ostream &O) const {
  // If the NVVM IR has some of reqntid* specified, then output
  // the reqntid directive, and set the unspecified ones to 1.
  // If none of Reqntid* is specified, don't output reqntid directive.
  unsigned Reqntidx, Reqntidy, Reqntidz;
  Reqntidx = Reqntidy = Reqntidz = 1;
  bool ReqSpecified = false;
  ReqSpecified |= getReqNTIDx(F, Reqntidx);
  ReqSpecified |= getReqNTIDy(F, Reqntidy);
  ReqSpecified |= getReqNTIDz(F, Reqntidz);

  if (ReqSpecified)
    O << ".reqntid " << Reqntidx << ", " << Reqntidy << ", " << Reqntidz
      << "\n";

  // If the NVVM IR has some of maxntid* specified, then output
  // the maxntid directive, and set the unspecified ones to 1.
  // If none of maxntid* is specified, don't output maxntid directive.
  unsigned Maxntidx, Maxntidy, Maxntidz;
  Maxntidx = Maxntidy = Maxntidz = 1;
  bool MaxSpecified = false;
  MaxSpecified |= getMaxNTIDx(F, Maxntidx);
  MaxSpecified |= getMaxNTIDy(F, Maxntidy);
  MaxSpecified |= getMaxNTIDz(F, Maxntidz);

  if (MaxSpecified)
    O << ".maxntid " << Maxntidx << ", " << Maxntidy << ", " << Maxntidz
      << "\n";

  unsigned Mincta = 0;
  if (getMinCTASm(F, Mincta))
    O << ".minnctapersm " << Mincta << "\n";

  unsigned Maxnreg = 0;
  if (getMaxNReg(F, Maxnreg))
    O << ".maxnreg " << Maxnreg << "\n";

  // .maxclusterrank directive requires SM_90 or higher, make sure that we
  // filter it out for lower SM versions, as it causes a hard ptxas crash.
  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
  unsigned Maxclusterrank = 0;
  if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)
    O << ".maxclusterrank " << Maxclusterrank << "\n";
}
582
getVirtualRegisterName(unsigned Reg) const583 std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
584 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
585
586 std::string Name;
587 raw_string_ostream NameStr(Name);
588
589 VRegRCMap::const_iterator I = VRegMapping.find(RC);
590 assert(I != VRegMapping.end() && "Bad register class");
591 const DenseMap<unsigned, unsigned> &RegMap = I->second;
592
593 VRegMap::const_iterator VI = RegMap.find(Reg);
594 assert(VI != RegMap.end() && "Bad virtual register");
595 unsigned MappedVR = VI->second;
596
597 NameStr << getNVPTXRegClassStr(RC) << MappedVR;
598
599 NameStr.flush();
600 return Name;
601 }
602
emitVirtualRegister(unsigned int vr,raw_ostream & O)603 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
604 raw_ostream &O) {
605 O << getVirtualRegisterName(vr);
606 }
607
// Emit a PTX forward declaration (prototype) for \p F: linkage directive,
// ".entry"/".func", return-value string, name, and parameter list.
void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
  emitLinkageDirective(F, O);
  if (isKernelFunction(*F))
    O << ".entry ";
  else
    O << ".func ";
  printReturnValStr(F, O);
  getSymbol(F)->print(O, MAI);
  O << "\n";
  emitFunctionParamList(F, O);
  O << "\n";
  if (shouldEmitPTXNoReturn(F, TM))
    O << ".noreturn";
  O << ";\n";
}
623
usedInGlobalVarDef(const Constant * C)624 static bool usedInGlobalVarDef(const Constant *C) {
625 if (!C)
626 return false;
627
628 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
629 return GV->getName() != "llvm.used";
630 }
631
632 for (const User *U : C->users())
633 if (const Constant *C = dyn_cast<Constant>(U))
634 if (usedInGlobalVarDef(C))
635 return true;
636
637 return false;
638 }
639
// Returns true if every (transitive) use of \p U lies within a single
// function, recording that function in \p oneFunc. Appearing in the
// llvm.used array does not disqualify.
static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
  if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
    if (othergv->getName() == "llvm.used")
      return true;
  }

  if (const Instruction *instr = dyn_cast<Instruction>(U)) {
    if (instr->getParent() && instr->getParent()->getParent()) {
      const Function *curFunc = instr->getParent()->getParent();
      // A use from a second, different function disqualifies.
      if (oneFunc && (curFunc != oneFunc))
        return false;
      oneFunc = curFunc;
      return true;
    } else
      return false;
  }

  // Non-instruction users (e.g. constant expressions): all of *their* users
  // must also resolve to the same single function.
  for (const User *UU : U->users())
    if (!usedInOneFunc(UU, oneFunc))
      return false;

  return true;
}
663
664 /* Find out if a global variable can be demoted to local scope.
665 * Currently, this is valid for CUDA shared variables, which have local
666 * scope and global lifetime. So the conditions to check are :
667 * 1. Is the global variable in shared address space?
668 * 2. Does it have local linkage?
669 * 3. Is the global variable referenced only in one function?
670 */
canDemoteGlobalVar(const GlobalVariable * gv,Function const * & f)671 static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
672 if (!gv->hasLocalLinkage())
673 return false;
674 PointerType *Pty = gv->getType();
675 if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
676 return false;
677
678 const Function *oneFunc = nullptr;
679
680 bool flag = usedInOneFunc(gv, oneFunc);
681 if (!flag)
682 return false;
683 if (!oneFunc)
684 return false;
685 f = oneFunc;
686 return true;
687 }
688
useFuncSeen(const Constant * C,DenseMap<const Function *,bool> & seenMap)689 static bool useFuncSeen(const Constant *C,
690 DenseMap<const Function *, bool> &seenMap) {
691 for (const User *U : C->users()) {
692 if (const Constant *cu = dyn_cast<Constant>(U)) {
693 if (useFuncSeen(cu, seenMap))
694 return true;
695 } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
696 const BasicBlock *bb = I->getParent();
697 if (!bb)
698 continue;
699 const Function *caller = bb->getParent();
700 if (!caller)
701 continue;
702 if (seenMap.contains(caller))
703 return true;
704 }
705 }
706 return false;
707 }
708
// Emit forward declarations for functions that are referenced before their
// definitions appear in the module (ptxas has no implicit forward refs).
// seenMap tracks which function definitions have already been walked.
void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
  DenseMap<const Function *, bool> seenMap;
  for (const Function &F : M) {
    // Libcall callees always get a declaration up front.
    if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
      emitDeclaration(&F, O);
      continue;
    }

    if (F.isDeclaration()) {
      // Skip unused external declarations and intrinsics; declare the rest.
      if (F.use_empty())
        continue;
      if (F.getIntrinsicID())
        continue;
      emitDeclaration(&F, O);
      continue;
    }
    for (const User *U : F.users()) {
      if (const Constant *C = dyn_cast<Constant>(U)) {
        if (usedInGlobalVarDef(C)) {
          // The use is in the initialization of a global variable
          // that is a function pointer, so print a declaration
          // for the original function
          emitDeclaration(&F, O);
          break;
        }
        // Emit a declaration of this function if the function that
        // uses this constant expr has already been seen.
        if (useFuncSeen(C, seenMap)) {
          emitDeclaration(&F, O);
          break;
        }
      }

      if (!isa<Instruction>(U))
        continue;
      const Instruction *instr = cast<Instruction>(U);
      const BasicBlock *bb = instr->getParent();
      if (!bb)
        continue;
      const Function *caller = bb->getParent();
      if (!caller)
        continue;

      // If a caller has already been seen, then the caller is
      // appearing in the module before the callee. so print out
      // a declaration for the callee.
      if (seenMap.contains(caller)) {
        emitDeclaration(&F, O);
        break;
      }
    }
    seenMap[&F] = true;
  }
}
763
isEmptyXXStructor(GlobalVariable * GV)764 static bool isEmptyXXStructor(GlobalVariable *GV) {
765 if (!GV) return true;
766 const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
767 if (!InitList) return true; // Not an array; we don't know how to parse.
768 return InitList->getNumOperands() == 0;
769 }
770
emitStartOfAsmFile(Module & M)771 void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
772 // Construct a default subtarget off of the TargetMachine defaults. The
773 // rest of NVPTX isn't friendly to change subtargets per function and
774 // so the default TargetMachine will have all of the options.
775 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
776 const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
777 SmallString<128> Str1;
778 raw_svector_ostream OS1(Str1);
779
780 // Emit header before any dwarf directives are emitted below.
781 emitHeader(M, OS1, *STI);
782 OutStreamer->emitRawText(OS1.str());
783 }
784
// Module-level setup: validate feature requirements (.alias support,
// global ctors/dtors) before delegating to AsmPrinter::doInitialization.
bool NVPTXAsmPrinter::doInitialization(Module &M) {
  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const NVPTXSubtarget &STI =
      *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
  // PTX .alias needs both a recent ISA and a recent SM.
  if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))
    report_fatal_error(".alias requires PTX version >= 6.3 and sm_30");

  // OpenMP supports NVPTX global constructors and destructors.
  bool IsOpenMP = M.getModuleFlag("openmp") != nullptr;

  // NOTE: report_fatal_error does not return, so the `return true` lines
  // below are effectively unreachable; they are kept for clarity.
  if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors")) &&
      !LowerCtorDtor && !IsOpenMP) {
    report_fatal_error(
        "Module has a nontrivial global ctor, which NVPTX does not support.");
    return true; // error
  }
  if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors")) &&
      !LowerCtorDtor && !IsOpenMP) {
    report_fatal_error(
        "Module has a nontrivial global dtor, which NVPTX does not support.");
    return true; // error
  }

  // We need to call the parent's one explicitly.
  bool Result = AsmPrinter::doInitialization(M);

  // Globals are emitted lazily: on the first function entry label, or in
  // doFinalization if the module has no functions.
  GlobalsEmitted = false;

  return Result;
}
815
emitGlobals(const Module & M)816 void NVPTXAsmPrinter::emitGlobals(const Module &M) {
817 SmallString<128> Str2;
818 raw_svector_ostream OS2(Str2);
819
820 emitDeclarations(M, OS2);
821
822 // As ptxas does not support forward references of globals, we need to first
823 // sort the list of module-level globals in def-use order. We visit each
824 // global variable in order, and ensure that we emit it *after* its dependent
825 // globals. We use a little extra memory maintaining both a set and a list to
826 // have fast searches while maintaining a strict ordering.
827 SmallVector<const GlobalVariable *, 8> Globals;
828 DenseSet<const GlobalVariable *> GVVisited;
829 DenseSet<const GlobalVariable *> GVVisiting;
830
831 // Visit each global variable, in order
832 for (const GlobalVariable &I : M.globals())
833 VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);
834
835 assert(GVVisited.size() == M.global_size() && "Missed a global variable");
836 assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
837
838 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
839 const NVPTXSubtarget &STI =
840 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
841
842 // Print out module-level global variables in proper order
843 for (unsigned i = 0, e = Globals.size(); i != e; ++i)
844 printModuleLevelGV(Globals[i], OS2, /*processDemoted=*/false, STI);
845
846 OS2 << '\n';
847
848 OutStreamer->emitRawText(OS2.str());
849 }
850
// Emit a PTX .alias for \p GA: first a prototype of the alias (mirroring
// the aliasee's signature), then the ".alias name, aliasee" directive.
// Only non-kernel function aliasees with strong linkage are representable.
void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
  SmallString<128> Str;
  raw_svector_ostream OS(Str);

  MCSymbol *Name = getSymbol(&GA);
  const Function *F = dyn_cast<Function>(GA.getAliasee());
  if (!F || isKernelFunction(*F))
    report_fatal_error("NVPTX aliasee must be a non-kernel function");

  if (GA.hasLinkOnceLinkage() || GA.hasWeakLinkage() ||
      GA.hasAvailableExternallyLinkage() || GA.hasCommonLinkage())
    report_fatal_error("NVPTX aliasee must not be '.weak'");

  // Prototype of the alias, using the aliasee's signature.
  OS << "\n";
  emitLinkageDirective(F, OS);
  OS << ".func ";
  printReturnValStr(F, OS);
  OS << Name->getName();
  emitFunctionParamList(F, OS);
  if (shouldEmitPTXNoReturn(F, TM))
    OS << "\n.noreturn";
  OS << ";\n";

  OS << ".alias " << Name->getName() << ", " << F->getName() << ";\n";

  OutStreamer->emitRawText(OS.str());
}
878
emitHeader(Module & M,raw_ostream & O,const NVPTXSubtarget & STI)879 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
880 const NVPTXSubtarget &STI) {
881 O << "//\n";
882 O << "// Generated by LLVM NVPTX Back-End\n";
883 O << "//\n";
884 O << "\n";
885
886 unsigned PTXVersion = STI.getPTXVersion();
887 O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
888
889 O << ".target ";
890 O << STI.getTargetName();
891
892 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
893 if (NTM.getDrvInterface() == NVPTX::NVCL)
894 O << ", texmode_independent";
895
896 bool HasFullDebugInfo = false;
897 for (DICompileUnit *CU : M.debug_compile_units()) {
898 switch(CU->getEmissionKind()) {
899 case DICompileUnit::NoDebug:
900 case DICompileUnit::DebugDirectivesOnly:
901 break;
902 case DICompileUnit::LineTablesOnly:
903 case DICompileUnit::FullDebug:
904 HasFullDebugInfo = true;
905 break;
906 }
907 if (HasFullDebugInfo)
908 break;
909 }
910 if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)
911 O << ", debug";
912
913 O << "\n";
914
915 O << ".address_size ";
916 if (NTM.is64Bit())
917 O << "64";
918 else
919 O << "32";
920 O << "\n";
921
922 O << "\n";
923 }
924
doFinalization(Module & M)925 bool NVPTXAsmPrinter::doFinalization(Module &M) {
926 bool HasDebugInfo = MMI && MMI->hasDebugInfo();
927
928 // If we did not emit any functions, then the global declarations have not
929 // yet been emitted.
930 if (!GlobalsEmitted) {
931 emitGlobals(M);
932 GlobalsEmitted = true;
933 }
934
935 // If we have any aliases we emit them at the end.
936 SmallVector<GlobalAlias *> AliasesToRemove;
937 for (GlobalAlias &Alias : M.aliases()) {
938 emitGlobalAlias(M, Alias);
939 AliasesToRemove.push_back(&Alias);
940 }
941
942 for (GlobalAlias *A : AliasesToRemove)
943 A->eraseFromParent();
944
945 // call doFinalization
946 bool ret = AsmPrinter::doFinalization(M);
947
948 clearAnnotationCache(&M);
949
950 auto *TS =
951 static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
952 // Close the last emitted section
953 if (HasDebugInfo) {
954 TS->closeLastSection();
955 // Emit empty .debug_loc section for better support of the empty files.
956 OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}");
957 }
958
959 // Output last DWARF .file directives, if any.
960 TS->outputDwarfFileDirectives();
961
962 return ret;
963 }
964
965 // This function emits appropriate linkage directives for
966 // functions and global variables.
967 //
968 // extern function declaration -> .extern
969 // extern function definition -> .visible
970 // external global variable with init -> .visible
971 // external without init -> .extern
972 // appending -> not allowed, assert.
973 // for any linkage other than
974 // internal, private, linker_private,
975 // linker_private_weak, linker_private_weak_def_auto,
976 // we emit -> .weak.
977
emitLinkageDirective(const GlobalValue * V,raw_ostream & O)978 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
979 raw_ostream &O) {
980 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
981 if (V->hasExternalLinkage()) {
982 if (isa<GlobalVariable>(V)) {
983 const GlobalVariable *GVar = cast<GlobalVariable>(V);
984 if (GVar) {
985 if (GVar->hasInitializer())
986 O << ".visible ";
987 else
988 O << ".extern ";
989 }
990 } else if (V->isDeclaration())
991 O << ".extern ";
992 else
993 O << ".visible ";
994 } else if (V->hasAppendingLinkage()) {
995 std::string msg;
996 msg.append("Error: ");
997 msg.append("Symbol ");
998 if (V->hasName())
999 msg.append(std::string(V->getName()));
1000 msg.append("has unsupported appending linkage type");
1001 llvm_unreachable(msg.c_str());
1002 } else if (!V->hasInternalLinkage() &&
1003 !V->hasPrivateLinkage()) {
1004 O << ".weak ";
1005 }
1006 }
1007 }
1008
// Print one module-level global variable declaration/definition to \p O.
// Handles textures, surfaces, samplers, extern declarations, demotion to
// function-local scope, and scalar vs. aggregate initializers. When
// \p processDemoted is false, demotable globals are deferred (recorded in
// localDecls) instead of being printed here.
void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
                                         raw_ostream &O, bool processDemoted,
                                         const NVPTXSubtarget &STI) {
  // Skip meta data
  if (GVar->hasSection()) {
    if (GVar->getSection() == "llvm.metadata")
      return;
  }

  // Skip LLVM intrinsic global variables
  if (GVar->getName().starts_with("llvm.") ||
      GVar->getName().starts_with("nvvm."))
    return;

  const DataLayout &DL = getDataLayout();

  // GlobalVariables are always constant pointers themselves.
  PointerType *PTy = GVar->getType();
  Type *ETy = GVar->getValueType();

  // Linkage directive: definition -> .visible, declaration -> .extern,
  // weak-ish linkages -> .weak. Internal/private linkage emits nothing.
  if (GVar->hasExternalLinkage()) {
    if (GVar->hasInitializer())
      O << ".visible ";
    else
      O << ".extern ";
  } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
             GVar->hasAvailableExternallyLinkage() ||
             GVar->hasCommonLinkage()) {
    O << ".weak ";
  }

  if (isTexture(*GVar)) {
    O << ".global .texref " << getTextureName(*GVar) << ";\n";
    return;
  }

  if (isSurface(*GVar)) {
    O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
    return;
  }

  if (GVar->isDeclaration()) {
    // (extern) declarations, no definition or initializer
    // Currently the only known declaration is for an automatic __local
    // (.shared) promoted to global.
    emitPTXGlobalVariable(GVar, O, STI);
    O << ";\n";
    return;
  }

  if (isSampler(*GVar)) {
    O << ".global .samplerref " << getSamplerName(*GVar);

    // Decode the OpenCL sampler bits (see cl_common_defines.h) from the
    // integer initializer, if present.
    const Constant *Initializer = nullptr;
    if (GVar->hasInitializer())
      Initializer = GVar->getInitializer();
    const ConstantInt *CI = nullptr;
    if (Initializer)
      CI = dyn_cast<ConstantInt>(Initializer);
    if (CI) {
      unsigned sample = CI->getZExtValue();

      O << " = { ";

      // NOTE(review): 'addr' is decoded once and reused for all three
      // addr_mode_<i> fields — presumably intentional; confirm against the
      // CL sampler encoding.
      for (int i = 0,
               addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
           i < 3; i++) {
        O << "addr_mode_" << i << " = ";
        switch (addr) {
        case 0:
          O << "wrap";
          break;
        case 1:
          O << "clamp_to_border";
          break;
        case 2:
          O << "clamp_to_edge";
          break;
        case 3:
          O << "wrap";
          break;
        case 4:
          O << "mirror";
          break;
        }
        O << ", ";
      }
      O << "filter_mode = ";
      switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
      case 0:
        O << "nearest";
        break;
      case 1:
        O << "linear";
        break;
      case 2:
        llvm_unreachable("Anisotropic filtering is not supported");
      default:
        O << "nearest";
        break;
      }
      if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
        O << ", force_unnormalized_coords = 1";
      }
      O << " }";
    }

    O << ";\n";
    return;
  }

  // Frontend-generated private helper globals that should not reach PTX.
  if (GVar->hasPrivateLinkage()) {
    if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
      return;

    // FIXME - need better way (e.g. Metadata) to avoid generating this global
    if (strncmp(GVar->getName().data(), "filename", 8) == 0)
      return;
    if (GVar->use_empty())
      return;
  }

  // If the global is only used inside one function, defer it: record it in
  // localDecls so emitDemotedVars() prints it inside that function instead.
  const Function *demotedFunc = nullptr;
  if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
    O << "// " << GVar->getName() << " has been demoted\n";
    if (localDecls.find(demotedFunc) != localDecls.end())
      localDecls[demotedFunc].push_back(GVar);
    else {
      std::vector<const GlobalVariable *> temp;
      temp.push_back(GVar);
      localDecls[demotedFunc] = temp;
    }
    return;
  }

  O << ".";
  emitPTXAddressSpace(PTy->getAddressSpace(), O);

  if (isManaged(*GVar)) {
    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
      report_fatal_error(
          ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
    }
    O << " .attribute(.managed)";
  }

  if (MaybeAlign A = GVar->getAlign())
    O << " .align " << A->value();
  else
    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();

  // Scalars (FP, pointers, integers up to 64 bits) use a fundamental PTX
  // type; everything else falls through to the byte-array path below.
  if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
      (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
    O << " .";
    // Special case: ABI requires that we use .u8 for predicates
    if (ETy->isIntegerTy(1))
      O << "u8";
    else
      O << getPTXFundamentalTypeStr(ETy, false);
    O << " ";
    getSymbol(GVar)->print(O, MAI);

    // PTX allows variable initialization only for constant and global state
    // spaces.
    if (GVar->hasInitializer()) {
      if ((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
          (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) {
        const Constant *Initializer = GVar->getInitializer();
        // 'undef' is treated as there is no value specified.
        if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
          O << " = ";
          printScalarConstant(Initializer, O);
        }
      } else {
        // The frontend adds zero-initializer to device and constant variables
        // that don't have an initial value, and UndefValue to shared
        // variables, so skip warning for this case.
        if (!GVar->getInitializer()->isNullValue() &&
            !isa<UndefValue>(GVar->getInitializer())) {
          report_fatal_error("initial value of '" + GVar->getName() +
                             "' is not allowed in addrspace(" +
                             Twine(PTy->getAddressSpace()) + ")");
        }
      }
    }
  } else {
    uint64_t ElementSize = 0;

    // Although PTX has direct support for struct type and array type and
    // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
    // targets that support these high level field accesses. Structs, arrays
    // and vectors are lowered into arrays of bytes.
    switch (ETy->getTypeID()) {
    case Type::IntegerTyID: // Integers larger than 64 bits
    case Type::StructTyID:
    case Type::ArrayTyID:
    case Type::FixedVectorTyID:
      ElementSize = DL.getTypeStoreSize(ETy);
      // PTX allows variable initialization only for constant and
      // global state spaces.
      if (((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
           (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
          GVar->hasInitializer()) {
        const Constant *Initializer = GVar->getInitializer();
        if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
          AggBuffer aggBuffer(ElementSize, *this);
          bufferAggregateConstant(Initializer, &aggBuffer);
          if (aggBuffer.numSymbols()) {
            // Initializer embeds symbol addresses. Emit pointer-sized words
            // when everything lines up; otherwise fall back to bytes with
            // the PTX mask() operator (needs PTX ISA >= 7.1).
            unsigned int ptrSize = MAI->getCodePointerSize();
            if (ElementSize % ptrSize ||
                !aggBuffer.allSymbolsAligned(ptrSize)) {
              // Print in bytes and use the mask() operator for pointers.
              if (!STI.hasMaskOperator())
                report_fatal_error(
                    "initialized packed aggregate with pointers '" +
                    GVar->getName() +
                    "' requires at least PTX ISA version 7.1");
              O << " .u8 ";
              getSymbol(GVar)->print(O, MAI);
              O << "[" << ElementSize << "] = {";
              aggBuffer.printBytes(O);
              O << "}";
            } else {
              O << " .u" << ptrSize * 8 << " ";
              getSymbol(GVar)->print(O, MAI);
              O << "[" << ElementSize / ptrSize << "] = {";
              aggBuffer.printWords(O);
              O << "}";
            }
          } else {
            O << " .b8 ";
            getSymbol(GVar)->print(O, MAI);
            O << "[" << ElementSize << "] = {";
            aggBuffer.printBytes(O);
            O << "}";
          }
        } else {
          // Zero/undef initializer: emit only the size, no initializer list.
          O << " .b8 ";
          getSymbol(GVar)->print(O, MAI);
          if (ElementSize) {
            O << "[";
            O << ElementSize;
            O << "]";
          }
        }
      } else {
        O << " .b8 ";
        getSymbol(GVar)->print(O, MAI);
        if (ElementSize) {
          O << "[";
          O << ElementSize;
          O << "]";
        }
      }
      break;
    default:
      llvm_unreachable("type not supported yet");
    }
  }
  O << ";\n";
}
1270
printSymbol(unsigned nSym,raw_ostream & os)1271 void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1272 const Value *v = Symbols[nSym];
1273 const Value *v0 = SymbolsBeforeStripping[nSym];
1274 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1275 MCSymbol *Name = AP.getSymbol(GVar);
1276 PointerType *PTy = dyn_cast<PointerType>(v0->getType());
1277 // Is v0 a generic pointer?
1278 bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1279 if (EmitGeneric && isGenericPointer && !isa<Function>(v)) {
1280 os << "generic(";
1281 Name->print(os, AP.MAI);
1282 os << ")";
1283 } else {
1284 Name->print(os, AP.MAI);
1285 }
1286 } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
1287 const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
1288 AP.printMCExpr(*Expr, os);
1289 } else
1290 llvm_unreachable("symbol type unknown");
1291 }
1292
// Print the buffered aggregate one byte at a time. Bytes covered by a
// recorded symbol are emitted as per-byte mask() expressions over that
// symbol; all other bytes are printed as plain decimal values.
void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
  unsigned int ptrSize = AP.MAI->getCodePointerSize();
  // Append a sentinel position (== size) so the loop can always read
  // symbolPosInBuffer[nSym] without running off the end.
  symbolPosInBuffer.push_back(size);
  unsigned int nSym = 0;
  unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
  for (unsigned int pos = 0; pos < size;) {
    if (pos)
      os << ", ";
    if (pos != nextSymbolPos) {
      // Plain data byte.
      os << (unsigned int)buffer[pos];
      ++pos;
      continue;
    }
    // Generate a per-byte mask() operator for the symbol, which looks like:
    // .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
    // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
    std::string symText;
    llvm::raw_string_ostream oss(symText);
    printSymbol(nSym, oss);
    for (unsigned i = 0; i < ptrSize; ++i) {
      if (i)
        os << ", ";
      llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper);
      os << "(" << symText << ")";
    }
    // A symbol occupies a full pointer's worth of bytes; advance past it and
    // look up where the next symbol (or the sentinel) begins.
    pos += ptrSize;
    nextSymbolPos = symbolPosInBuffer[++nSym];
    assert(nextSymbolPos >= pos);
  }
}
1323
// Print the buffered aggregate one pointer-sized word at a time. Requires
// every recorded symbol to be pointer-aligned (callers check
// allSymbolsAligned() first); symbol slots print the symbol itself, other
// slots print the little-endian word value.
void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
  unsigned int ptrSize = AP.MAI->getCodePointerSize();
  // Append a sentinel position (== size) so the loop can always read
  // symbolPosInBuffer[nSym] without running off the end.
  symbolPosInBuffer.push_back(size);
  unsigned int nSym = 0;
  unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
  assert(nextSymbolPos % ptrSize == 0);
  for (unsigned int pos = 0; pos < size; pos += ptrSize) {
    if (pos)
      os << ", ";
    if (pos == nextSymbolPos) {
      printSymbol(nSym, os);
      nextSymbolPos = symbolPosInBuffer[++nSym];
      assert(nextSymbolPos % ptrSize == 0);
      assert(nextSymbolPos >= pos + ptrSize);
    } else if (ptrSize == 4)
      os << support::endian::read32le(&buffer[pos]);
    else
      os << support::endian::read64le(&buffer[pos]);
  }
}
1344
emitDemotedVars(const Function * f,raw_ostream & O)1345 void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
1346 if (localDecls.find(f) == localDecls.end())
1347 return;
1348
1349 std::vector<const GlobalVariable *> &gvars = localDecls[f];
1350
1351 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1352 const NVPTXSubtarget &STI =
1353 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
1354
1355 for (const GlobalVariable *GV : gvars) {
1356 O << "\t// demoted variable\n\t";
1357 printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
1358 }
1359 }
1360
emitPTXAddressSpace(unsigned int AddressSpace,raw_ostream & O) const1361 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1362 raw_ostream &O) const {
1363 switch (AddressSpace) {
1364 case ADDRESS_SPACE_LOCAL:
1365 O << "local";
1366 break;
1367 case ADDRESS_SPACE_GLOBAL:
1368 O << "global";
1369 break;
1370 case ADDRESS_SPACE_CONST:
1371 O << "const";
1372 break;
1373 case ADDRESS_SPACE_SHARED:
1374 O << "shared";
1375 break;
1376 default:
1377 report_fatal_error("Bad address space found while emitting PTX: " +
1378 llvm::Twine(AddressSpace));
1379 break;
1380 }
1381 }
1382
1383 std::string
getPTXFundamentalTypeStr(Type * Ty,bool useB4PTR) const1384 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1385 switch (Ty->getTypeID()) {
1386 case Type::IntegerTyID: {
1387 unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
1388 if (NumBits == 1)
1389 return "pred";
1390 else if (NumBits <= 64) {
1391 std::string name = "u";
1392 return name + utostr(NumBits);
1393 } else {
1394 llvm_unreachable("Integer too large");
1395 break;
1396 }
1397 break;
1398 }
1399 case Type::BFloatTyID:
1400 case Type::HalfTyID:
1401 // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
1402 // PTX assembly.
1403 return "b16";
1404 case Type::FloatTyID:
1405 return "f32";
1406 case Type::DoubleTyID:
1407 return "f64";
1408 case Type::PointerTyID: {
1409 unsigned PtrSize = TM.getPointerSizeInBits(Ty->getPointerAddressSpace());
1410 assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");
1411
1412 if (PtrSize == 64)
1413 if (useB4PTR)
1414 return "b64";
1415 else
1416 return "u64";
1417 else if (useB4PTR)
1418 return "b32";
1419 else
1420 return "u32";
1421 }
1422 default:
1423 break;
1424 }
1425 llvm_unreachable("unexpected type");
1426 }
1427
emitPTXGlobalVariable(const GlobalVariable * GVar,raw_ostream & O,const NVPTXSubtarget & STI)1428 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1429 raw_ostream &O,
1430 const NVPTXSubtarget &STI) {
1431 const DataLayout &DL = getDataLayout();
1432
1433 // GlobalVariables are always constant pointers themselves.
1434 Type *ETy = GVar->getValueType();
1435
1436 O << ".";
1437 emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
1438 if (isManaged(*GVar)) {
1439 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1440 report_fatal_error(
1441 ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1442 }
1443 O << " .attribute(.managed)";
1444 }
1445 if (MaybeAlign A = GVar->getAlign())
1446 O << " .align " << A->value();
1447 else
1448 O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
1449
1450 // Special case for i128
1451 if (ETy->isIntegerTy(128)) {
1452 O << " .b8 ";
1453 getSymbol(GVar)->print(O, MAI);
1454 O << "[16]";
1455 return;
1456 }
1457
1458 if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
1459 O << " .";
1460 O << getPTXFundamentalTypeStr(ETy);
1461 O << " ";
1462 getSymbol(GVar)->print(O, MAI);
1463 return;
1464 }
1465
1466 int64_t ElementSize = 0;
1467
1468 // Although PTX has direct support for struct type and array type and LLVM IR
1469 // is very similar to PTX, the LLVM CodeGen does not support for targets that
1470 // support these high level field accesses. Structs and arrays are lowered
1471 // into arrays of bytes.
1472 switch (ETy->getTypeID()) {
1473 case Type::StructTyID:
1474 case Type::ArrayTyID:
1475 case Type::FixedVectorTyID:
1476 ElementSize = DL.getTypeStoreSize(ETy);
1477 O << " .b8 ";
1478 getSymbol(GVar)->print(O, MAI);
1479 O << "[";
1480 if (ElementSize) {
1481 O << ElementSize;
1482 }
1483 O << "]";
1484 break;
1485 default:
1486 llvm_unreachable("type not supported yet");
1487 }
1488 }
1489
// Print the PTX parameter list "( ... )" for function \p F, covering:
// image/sampler kernel parameters, array-passed aggregates, pointer and
// scalar kernel parameters, non-kernel scalars (.param for ABI, .reg
// otherwise), byval aggregates, and a trailing vararg array.
void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const AttributeList &PAL = F->getAttributes();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
  const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());

  Function::const_arg_iterator I, E;
  // paramIndex tracks the PTX parameter slot; it is advanced extra times in
  // the split-byval path below, where one IR argument expands to several
  // .reg parameters.
  unsigned paramIndex = 0;
  bool first = true;
  bool isKernelFunc = isKernelFunction(*F);
  // sm_20+ uses the .param ABI for non-kernel functions; older targets pass
  // arguments in virtual registers.
  bool isABI = (STI.getSmVersion() >= 20);
  bool hasImageHandles = STI.hasImageHandles();

  if (F->arg_empty() && !F->isVarArg()) {
    O << "()";
    return;
  }

  O << "(\n";

  for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
    Type *Ty = I->getType();

    if (!first)
      O << ",\n";

    first = false;

    // Handle image/sampler parameters
    if (isKernelFunction(*F)) {
      if (isSampler(*I) || isImage(*I)) {
        if (isImage(*I)) {
          // NOTE(review): 'sname' is never used below — candidate for
          // removal.
          std::string sname = std::string(I->getName());
          if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
            if (hasImageHandles)
              O << "\t.param .u64 .ptr .surfref ";
            else
              O << "\t.param .surfref ";
            O << TLI->getParamName(F, paramIndex);
          }
          else { // Default image is read_only
            if (hasImageHandles)
              O << "\t.param .u64 .ptr .texref ";
            else
              O << "\t.param .texref ";
            O << TLI->getParamName(F, paramIndex);
          }
        } else {
          if (hasImageHandles)
            O << "\t.param .u64 .ptr .samplerref ";
          else
            O << "\t.param .samplerref ";
          O << TLI->getParamName(F, paramIndex);
        }
        continue;
      }
    }

    // Returns max(natural alignment of Ty, explicit param alignment).
    auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
                                    paramIndex](Type *Ty) -> Align {
      Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
      MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
      return std::max(TypeAlign, ParamAlign.valueOrOne());
    };

    if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
      if (ShouldPassAsArray(Ty)) {
        // Just print .param .align <a> .b8 .param[size];
        // <a> = optimal alignment for the element type; always multiple of
        // PAL.getParamAlignment
        // size = typeallocsize of element type
        Align OptimalAlign = getOptimalAlignForParam(Ty);

        O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
        O << TLI->getParamName(F, paramIndex);
        O << "[" << DL.getTypeAllocSize(Ty) << "]";

        continue;
      }
      // Just a scalar
      auto *PTy = dyn_cast<PointerType>(Ty);
      unsigned PTySizeInBits = 0;
      if (PTy) {
        PTySizeInBits =
            TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
        assert(PTySizeInBits && "Invalid pointer size");
      }

      if (isKernelFunc) {
        if (PTy) {
          // Special handling for pointer arguments to kernel
          O << "\t.param .u" << PTySizeInBits << " ";

          // Only the non-CUDA (NVCL) interface annotates state space and
          // alignment on kernel pointer parameters.
          if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
              NVPTX::CUDA) {
            int addrSpace = PTy->getAddressSpace();
            switch (addrSpace) {
            default:
              O << ".ptr ";
              break;
            case ADDRESS_SPACE_CONST:
              O << ".ptr .const ";
              break;
            case ADDRESS_SPACE_SHARED:
              O << ".ptr .shared ";
              break;
            case ADDRESS_SPACE_GLOBAL:
              O << ".ptr .global ";
              break;
            }
            Align ParamAlign = I->getParamAlign().valueOrOne();
            O << ".align " << ParamAlign.value() << " ";
          }
          O << TLI->getParamName(F, paramIndex);
          continue;
        }

        // non-pointer scalar to kernel func
        O << "\t.param .";
        // Special case: predicate operands become .u8 types
        if (Ty->isIntegerTy(1))
          O << "u8";
        else
          O << getPTXFundamentalTypeStr(Ty);
        O << " ";
        O << TLI->getParamName(F, paramIndex);
        continue;
      }
      // Non-kernel function, just print .param .b<size> for ABI
      // and .reg .b<size> for non-ABI
      unsigned sz = 0;
      if (isa<IntegerType>(Ty)) {
        sz = cast<IntegerType>(Ty)->getBitWidth();
        sz = promoteScalarArgumentSize(sz);
      } else if (PTy) {
        assert(PTySizeInBits && "Invalid pointer size");
        sz = PTySizeInBits;
      } else
        sz = Ty->getPrimitiveSizeInBits();
      if (isABI)
        O << "\t.param .b" << sz << " ";
      else
        O << "\t.reg .b" << sz << " ";
      O << TLI->getParamName(F, paramIndex);
      continue;
    }

    // param has byVal attribute.
    Type *ETy = PAL.getParamByValType(paramIndex);
    assert(ETy && "Param should have byval type");

    if (isABI || isKernelFunc) {
      // Just print .param .align <a> .b8 .param[size];
      // <a> = optimal alignment for the element type; always multiple of
      // PAL.getParamAlignment
      // size = typeallocsize of element type
      Align OptimalAlign =
          isKernelFunc
              ? getOptimalAlignForParam(ETy)
              : TLI->getFunctionByValParamAlign(
                    F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);

      unsigned sz = DL.getTypeAllocSize(ETy);
      O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
      O << TLI->getParamName(F, paramIndex);
      O << "[" << sz << "]";
      continue;
    } else {
      // Split the ETy into constituent parts and
      // print .param .b<size> <name> for each part.
      // Further, if a part is vector, print the above for
      // each vector element.
      SmallVector<EVT, 16> vtparts;
      ComputeValueVTs(*TLI, DL, ETy, vtparts);
      for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
        unsigned elems = 1;
        EVT elemtype = vtparts[i];
        if (vtparts[i].isVector()) {
          elems = vtparts[i].getVectorNumElements();
          elemtype = vtparts[i].getVectorElementType();
        }

        for (unsigned j = 0, je = elems; j != je; ++j) {
          unsigned sz = elemtype.getSizeInBits();
          if (elemtype.isInteger())
            sz = promoteScalarArgumentSize(sz);
          O << "\t.reg .b" << sz << " ";
          O << TLI->getParamName(F, paramIndex);
          if (j < je - 1)
            O << ",\n";
          // Each expanded part consumes its own parameter slot.
          ++paramIndex;
        }
        if (i < e - 1)
          O << ",\n";
      }
      // Compensate for the loop's paramIndex++ — the inner loops already
      // advanced past all parts of this argument.
      --paramIndex;
      continue;
    }
  }

  if (F->isVarArg()) {
    if (!first)
      O << ",\n";
    O << "\t.param .align " << STI.getMaxRequiredAlignment();
    O << " .b8 ";
    O << TLI->getParamName(F, /* vararg */ -1) << "[]";
  }

  O << "\n)";
}
1700
setAndEmitFunctionVirtualRegisters(const MachineFunction & MF)1701 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1702 const MachineFunction &MF) {
1703 SmallString<128> Str;
1704 raw_svector_ostream O(Str);
1705
1706 // Map the global virtual register number to a register class specific
1707 // virtual register number starting from 1 with that class.
1708 const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1709 //unsigned numRegClasses = TRI->getNumRegClasses();
1710
1711 // Emit the Fake Stack Object
1712 const MachineFrameInfo &MFI = MF.getFrameInfo();
1713 int NumBytes = (int) MFI.getStackSize();
1714 if (NumBytes) {
1715 O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
1716 << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
1717 if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1718 O << "\t.reg .b64 \t%SP;\n";
1719 O << "\t.reg .b64 \t%SPL;\n";
1720 } else {
1721 O << "\t.reg .b32 \t%SP;\n";
1722 O << "\t.reg .b32 \t%SPL;\n";
1723 }
1724 }
1725
1726 // Go through all virtual registers to establish the mapping between the
1727 // global virtual
1728 // register number and the per class virtual register number.
1729 // We use the per class virtual register number in the ptx output.
1730 unsigned int numVRs = MRI->getNumVirtRegs();
1731 for (unsigned i = 0; i < numVRs; i++) {
1732 Register vr = Register::index2VirtReg(i);
1733 const TargetRegisterClass *RC = MRI->getRegClass(vr);
1734 DenseMap<unsigned, unsigned> ®map = VRegMapping[RC];
1735 int n = regmap.size();
1736 regmap.insert(std::make_pair(vr, n + 1));
1737 }
1738
1739 // Emit register declarations
1740 // @TODO: Extract out the real register usage
1741 // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1742 // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1743 // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1744 // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1745 // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1746 // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1747 // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1748
1749 // Emit declaration of the virtual registers or 'physical' registers for
1750 // each register class
1751 for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
1752 const TargetRegisterClass *RC = TRI->getRegClass(i);
1753 DenseMap<unsigned, unsigned> ®map = VRegMapping[RC];
1754 std::string rcname = getNVPTXRegClassName(RC);
1755 std::string rcStr = getNVPTXRegClassStr(RC);
1756 int n = regmap.size();
1757
1758 // Only declare those registers that may be used.
1759 if (n) {
1760 O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
1761 << ">;\n";
1762 }
1763 }
1764
1765 OutStreamer->emitRawText(O.str());
1766 }
1767
printFPConstant(const ConstantFP * Fp,raw_ostream & O)1768 void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
1769 APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1770 bool ignored;
1771 unsigned int numHex;
1772 const char *lead;
1773
1774 if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1775 numHex = 8;
1776 lead = "0f";
1777 APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
1778 } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1779 numHex = 16;
1780 lead = "0d";
1781 APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
1782 } else
1783 llvm_unreachable("unsupported fp type");
1784
1785 APInt API = APF.bitcastToAPInt();
1786 O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
1787 }
1788
printScalarConstant(const Constant * CPV,raw_ostream & O)1789 void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1790 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1791 O << CI->getValue();
1792 return;
1793 }
1794 if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1795 printFPConstant(CFP, O);
1796 return;
1797 }
1798 if (isa<ConstantPointerNull>(CPV)) {
1799 O << "0";
1800 return;
1801 }
1802 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1803 bool IsNonGenericPointer = false;
1804 if (GVar->getType()->getAddressSpace() != 0) {
1805 IsNonGenericPointer = true;
1806 }
1807 if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
1808 O << "generic(";
1809 getSymbol(GVar)->print(O, MAI);
1810 O << ")";
1811 } else {
1812 getSymbol(GVar)->print(O, MAI);
1813 }
1814 return;
1815 }
1816 if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1817 const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false);
1818 printMCExpr(*E, O);
1819 return;
1820 }
1821 llvm_unreachable("Not scalar type found in printScalarConstant()");
1822 }
1823
bufferLEByte(const Constant * CPV,int Bytes,AggBuffer * AggBuffer)1824 void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1825 AggBuffer *AggBuffer) {
1826 const DataLayout &DL = getDataLayout();
1827 int AllocSize = DL.getTypeAllocSize(CPV->getType());
1828 if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
1829 // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
1830 // only the space allocated by CPV.
1831 AggBuffer->addZeros(Bytes ? Bytes : AllocSize);
1832 return;
1833 }
1834
1835 // Helper for filling AggBuffer with APInts.
1836 auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
1837 size_t NumBytes = (Val.getBitWidth() + 7) / 8;
1838 SmallVector<unsigned char, 16> Buf(NumBytes);
1839 for (unsigned I = 0; I < NumBytes; ++I) {
1840 Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);
1841 }
1842 AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);
1843 };
1844
1845 switch (CPV->getType()->getTypeID()) {
1846 case Type::IntegerTyID:
1847 if (const auto CI = dyn_cast<ConstantInt>(CPV)) {
1848 AddIntToBuffer(CI->getValue());
1849 break;
1850 }
1851 if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1852 if (const auto *CI =
1853 dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {
1854 AddIntToBuffer(CI->getValue());
1855 break;
1856 }
1857 if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1858 Value *V = Cexpr->getOperand(0)->stripPointerCasts();
1859 AggBuffer->addSymbol(V, Cexpr->getOperand(0));
1860 AggBuffer->addZeros(AllocSize);
1861 break;
1862 }
1863 }
1864 llvm_unreachable("unsupported integer const type");
1865 break;
1866
1867 case Type::HalfTyID:
1868 case Type::BFloatTyID:
1869 case Type::FloatTyID:
1870 case Type::DoubleTyID:
1871 AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
1872 break;
1873
1874 case Type::PointerTyID: {
1875 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1876 AggBuffer->addSymbol(GVar, GVar);
1877 } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1878 const Value *v = Cexpr->stripPointerCasts();
1879 AggBuffer->addSymbol(v, Cexpr);
1880 }
1881 AggBuffer->addZeros(AllocSize);
1882 break;
1883 }
1884
1885 case Type::ArrayTyID:
1886 case Type::FixedVectorTyID:
1887 case Type::StructTyID: {
1888 if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
1889 bufferAggregateConstant(CPV, AggBuffer);
1890 if (Bytes > AllocSize)
1891 AggBuffer->addZeros(Bytes - AllocSize);
1892 } else if (isa<ConstantAggregateZero>(CPV))
1893 AggBuffer->addZeros(Bytes);
1894 else
1895 llvm_unreachable("Unexpected Constant type");
1896 break;
1897 }
1898
1899 default:
1900 llvm_unreachable("unsupported type");
1901 }
1902 }
1903
bufferAggregateConstant(const Constant * CPV,AggBuffer * aggBuffer)1904 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
1905 AggBuffer *aggBuffer) {
1906 const DataLayout &DL = getDataLayout();
1907 int Bytes;
1908
1909 // Integers of arbitrary width
1910 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1911 APInt Val = CI->getValue();
1912 for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
1913 uint8_t Byte = Val.getLoBits(8).getZExtValue();
1914 aggBuffer->addBytes(&Byte, 1, 1);
1915 Val.lshrInPlace(8);
1916 }
1917 return;
1918 }
1919
1920 // Old constants
1921 if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
1922 if (CPV->getNumOperands())
1923 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
1924 bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
1925 return;
1926 }
1927
1928 if (const ConstantDataSequential *CDS =
1929 dyn_cast<ConstantDataSequential>(CPV)) {
1930 if (CDS->getNumElements())
1931 for (unsigned i = 0; i < CDS->getNumElements(); ++i)
1932 bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
1933 aggBuffer);
1934 return;
1935 }
1936
1937 if (isa<ConstantStruct>(CPV)) {
1938 if (CPV->getNumOperands()) {
1939 StructType *ST = cast<StructType>(CPV->getType());
1940 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
1941 if (i == (e - 1))
1942 Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
1943 DL.getTypeAllocSize(ST) -
1944 DL.getStructLayout(ST)->getElementOffset(i);
1945 else
1946 Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
1947 DL.getStructLayout(ST)->getElementOffset(i);
1948 bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
1949 }
1950 }
1951 return;
1952 }
1953 llvm_unreachable("unsupported constant type in printAggregateConstant()");
1954 }
1955
/// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly
/// a copy from AsmPrinter::lowerConstant, except customized to only handle
/// expressions that are representable in PTX and create
/// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
///
/// \p ProcessingGeneric is true when we are underneath an addrspacecast to the
/// generic address space; symbol references found there are wrapped in
/// NVPTXGenericMCSymbolRefExpr so they print with the generic() modifier.
const MCExpr *
NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
  MCContext &Ctx = OutContext;

  // Null / undef lower to a literal 0.
  if (CV->isNullValue() || isa<UndefValue>(CV))
    return MCConstantExpr::create(0, Ctx);

  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
    return MCConstantExpr::create(CI->getZExtValue(), Ctx);

  // A global's address: plain symbol ref, or a generic-wrapped one when we
  // arrived here through an addrspacecast to generic.
  if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
    const MCSymbolRefExpr *Expr =
      MCSymbolRefExpr::create(getSymbol(GV), Ctx);
    if (ProcessingGeneric) {
      return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
    } else {
      return Expr;
    }
  }

  // Everything below requires a constant expression.
  const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
  if (!CE) {
    llvm_unreachable("Unknown constant value to lower!");
  }

  switch (CE->getOpcode()) {
  default:
    break; // Error

  case Instruction::AddrSpaceCast: {
    // Strip the addrspacecast and pass along the operand. Only casts *to* the
    // generic address space (0) are representable; anything else falls
    // through to the error path.
    PointerType *DstTy = cast<PointerType>(CE->getType());
    if (DstTy->getAddressSpace() == 0)
      return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);

    break; // Error
  }

  case Instruction::GetElementPtr: {
    const DataLayout &DL = getDataLayout();

    // Generate a symbolic expression for the byte address: lower the base
    // pointer, then add the constant byte offset accumulated from the GEP.
    APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
    cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);

    const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
                                            ProcessingGeneric);
    if (!OffsetAI)
      return Base;

    int64_t Offset = OffsetAI.getSExtValue();
    return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
                                   Ctx);
  }

  case Instruction::Trunc:
    // We emit the value and depend on the assembler to truncate the generated
    // expression properly. This is important for differences between
    // blockaddress labels. Since the two labels are in the same function, it
    // is reasonable to treat their delta as a 32-bit value.
    [[fallthrough]];
  case Instruction::BitCast:
    // Bitcasts are value-preserving; just lower the operand.
    return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);

  case Instruction::IntToPtr: {
    const DataLayout &DL = getDataLayout();

    // Handle casts to pointers by changing them into casts to the appropriate
    // integer type. This promotes constant folding and simplifies this code.
    Constant *Op = CE->getOperand(0);
    Op = ConstantFoldIntegerCast(Op, DL.getIntPtrType(CV->getType()),
                                 /*IsSigned*/ false, DL);
    if (Op)
      return lowerConstantForGV(Op, ProcessingGeneric);

    break; // Error
  }

  case Instruction::PtrToInt: {
    const DataLayout &DL = getDataLayout();

    // Support only foldable casts to/from pointers that can be eliminated by
    // changing the pointer to the appropriately sized integer type.
    Constant *Op = CE->getOperand(0);
    Type *Ty = CE->getType();

    const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);

    // We can emit the pointer value into this slot if the slot is an
    // integer slot equal to the size of the pointer.
    if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
      return OpExpr;

    // Otherwise the pointer is smaller than the resultant integer, mask off
    // the high bits so we are sure to get a proper truncation if the input is
    // a constant expr.
    unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
    const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
    return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
  }

  // The MC library also has a right-shift operator, but it isn't consistently
  // signed or unsigned between different targets.
  case Instruction::Add: {
    const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
    const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
    switch (CE->getOpcode()) {
    default: llvm_unreachable("Unknown binary operator constant cast expr");
    case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
    }
  }
  }

  // If the code isn't optimized, there may be outstanding folding
  // opportunities. Attempt to fold the expression using DataLayout as a
  // last resort before giving up.
  Constant *C = ConstantFoldConstant(CE, getDataLayout());
  if (C != CE)
    return lowerConstantForGV(C, ProcessingGeneric);

  // Otherwise report the problem to the user.
  std::string S;
  raw_string_ostream OS(S);
  OS << "Unsupported expression in static initializer: ";
  CE->printAsOperand(OS, /*PrintType=*/false,
                     !MF ? nullptr : MF->getFunction().getParent());
  report_fatal_error(Twine(OS.str()));
}
2088
// Copy of MCExpr::print customized for NVPTX. Recursively renders \p Expr
// into PTX-compatible text, parenthesizing only non-trivial sub-expressions.
void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
  switch (Expr.getKind()) {
  case MCExpr::Target:
    // Target-specific expressions (e.g. generic() wrappers) print themselves.
    return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
  case MCExpr::Constant:
    OS << cast<MCConstantExpr>(Expr).getValue();
    return;

  case MCExpr::SymbolRef: {
    const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
    const MCSymbol &Sym = SRE.getSymbol();
    Sym.print(OS, MAI);
    return;
  }

  case MCExpr::Unary: {
    // Prefix operator followed by the recursively-printed sub-expression.
    const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
    switch (UE.getOpcode()) {
    case MCUnaryExpr::LNot: OS << '!'; break;
    case MCUnaryExpr::Minus: OS << '-'; break;
    case MCUnaryExpr::Not: OS << '~'; break;
    case MCUnaryExpr::Plus: OS << '+'; break;
    }
    printMCExpr(*UE.getSubExpr(), OS);
    return;
  }

  case MCExpr::Binary: {
    const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);

    // Only print parens around the LHS if it is non-trivial.
    if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
        isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
      printMCExpr(*BE.getLHS(), OS);
    } else {
      OS << '(';
      printMCExpr(*BE.getLHS(), OS);
      OS<< ')';
    }

    switch (BE.getOpcode()) {
    case MCBinaryExpr::Add:
      // Print "X-42" instead of "X+-42".
      if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
        if (RHSC->getValue() < 0) {
          OS << RHSC->getValue();
          return;
        }
      }

      OS << '+';
      break;
    default: llvm_unreachable("Unhandled binary operator");
    }

    // Only print parens around the RHS if it is non-trivial.
    if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
      printMCExpr(*BE.getRHS(), OS);
    } else {
      OS << '(';
      printMCExpr(*BE.getRHS(), OS);
      OS << ')';
    }
    return;
  }
  }

  llvm_unreachable("Invalid expression kind!");
}
2159
2160 /// PrintAsmOperand - Print out an operand for an inline asm expression.
2161 ///
PrintAsmOperand(const MachineInstr * MI,unsigned OpNo,const char * ExtraCode,raw_ostream & O)2162 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2163 const char *ExtraCode, raw_ostream &O) {
2164 if (ExtraCode && ExtraCode[0]) {
2165 if (ExtraCode[1] != 0)
2166 return true; // Unknown modifier.
2167
2168 switch (ExtraCode[0]) {
2169 default:
2170 // See if this is a generic print operand
2171 return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
2172 case 'r':
2173 break;
2174 }
2175 }
2176
2177 printOperand(MI, OpNo, O);
2178
2179 return false;
2180 }
2181
PrintAsmMemoryOperand(const MachineInstr * MI,unsigned OpNo,const char * ExtraCode,raw_ostream & O)2182 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2183 unsigned OpNo,
2184 const char *ExtraCode,
2185 raw_ostream &O) {
2186 if (ExtraCode && ExtraCode[0])
2187 return true; // Unknown modifier
2188
2189 O << '[';
2190 printMemOperand(MI, OpNo, O);
2191 O << ']';
2192
2193 return false;
2194 }
2195
printOperand(const MachineInstr * MI,unsigned OpNum,raw_ostream & O)2196 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, unsigned OpNum,
2197 raw_ostream &O) {
2198 const MachineOperand &MO = MI->getOperand(OpNum);
2199 switch (MO.getType()) {
2200 case MachineOperand::MO_Register:
2201 if (MO.getReg().isPhysical()) {
2202 if (MO.getReg() == NVPTX::VRDepot)
2203 O << DEPOTNAME << getFunctionNumber();
2204 else
2205 O << NVPTXInstPrinter::getRegisterName(MO.getReg());
2206 } else {
2207 emitVirtualRegister(MO.getReg(), O);
2208 }
2209 break;
2210
2211 case MachineOperand::MO_Immediate:
2212 O << MO.getImm();
2213 break;
2214
2215 case MachineOperand::MO_FPImmediate:
2216 printFPConstant(MO.getFPImm(), O);
2217 break;
2218
2219 case MachineOperand::MO_GlobalAddress:
2220 PrintSymbolOperand(MO, O);
2221 break;
2222
2223 case MachineOperand::MO_MachineBasicBlock:
2224 MO.getMBB()->getSymbol()->print(O, MAI);
2225 break;
2226
2227 default:
2228 llvm_unreachable("Operand type not supported.");
2229 }
2230 }
2231
printMemOperand(const MachineInstr * MI,unsigned OpNum,raw_ostream & O,const char * Modifier)2232 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, unsigned OpNum,
2233 raw_ostream &O, const char *Modifier) {
2234 printOperand(MI, OpNum, O);
2235
2236 if (Modifier && strcmp(Modifier, "add") == 0) {
2237 O << ", ";
2238 printOperand(MI, OpNum + 1, O);
2239 } else {
2240 if (MI->getOperand(OpNum + 1).isImm() &&
2241 MI->getOperand(OpNum + 1).getImm() == 0)
2242 return; // don't print ',0' or '+0'
2243 O << "+";
2244 printOperand(MI, OpNum + 1, O);
2245 }
2246 }
2247
2248 // Force static initialization.
LLVMInitializeNVPTXAsmPrinter()2249 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
2250 RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2251 RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
2252 }
2253