1 //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This provides an abstract class for HLSL code generation. Concrete
10 // subclasses of this implement code generation for specific HLSL
11 // runtime libraries.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "CGHLSLRuntime.h"
16 #include "CGDebugInfo.h"
17 #include "CodeGenModule.h"
18 #include "clang/AST/Decl.h"
19 #include "clang/Basic/TargetOptions.h"
20 #include "llvm/IR/IntrinsicsDirectX.h"
21 #include "llvm/IR/Metadata.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/Support/FormatVariadic.h"
24
25 using namespace clang;
26 using namespace CodeGen;
27 using namespace clang::hlsl;
28 using namespace llvm;
29
30 namespace {
31
addDxilValVersion(StringRef ValVersionStr,llvm::Module & M)32 void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
33 // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
34 // Assume ValVersionStr is legal here.
35 VersionTuple Version;
36 if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
37 Version.getSubminor() || !Version.getMinor()) {
38 return;
39 }
40
41 uint64_t Major = Version.getMajor();
42 uint64_t Minor = *Version.getMinor();
43
44 auto &Ctx = M.getContext();
45 IRBuilder<> B(M.getContext());
46 MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
47 ConstantAsMetadata::get(B.getInt32(Minor))});
48 StringRef DXILValKey = "dx.valver";
49 auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);
50 DXILValMD->addOperand(Val);
51 }
addDisableOptimizations(llvm::Module & M)52 void addDisableOptimizations(llvm::Module &M) {
53 StringRef Key = "dx.disable_optimizations";
54 M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);
55 }
56 // cbuffer will be translated into global variable in special address space.
57 // If translate into C,
58 // cbuffer A {
59 // float a;
60 // float b;
61 // }
62 // float foo() { return a + b; }
63 //
64 // will be translated into
65 //
66 // struct A {
67 // float a;
68 // float b;
69 // } cbuffer_A __attribute__((address_space(4)));
70 // float foo() { return cbuffer_A.a + cbuffer_A.b; }
71 //
72 // layoutBuffer will create the struct A type.
73 // replaceBuffer will replace use of global variable a and b with cbuffer_A.a
74 // and cbuffer_A.b.
75 //
layoutBuffer(CGHLSLRuntime::Buffer & Buf,const DataLayout & DL)76 void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) {
77 if (Buf.Constants.empty())
78 return;
79
80 std::vector<llvm::Type *> EltTys;
81 for (auto &Const : Buf.Constants) {
82 GlobalVariable *GV = Const.first;
83 Const.second = EltTys.size();
84 llvm::Type *Ty = GV->getValueType();
85 EltTys.emplace_back(Ty);
86 }
87 Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);
88 }
89
replaceBuffer(CGHLSLRuntime::Buffer & Buf)90 GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) {
91 // Create global variable for CB.
92 GlobalVariable *CBGV = new GlobalVariable(
93 Buf.LayoutStruct, /*isConstant*/ true,
94 GlobalValue::LinkageTypes::ExternalLinkage, nullptr,
95 llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),
96 GlobalValue::NotThreadLocal);
97
98 IRBuilder<> B(CBGV->getContext());
99 Value *ZeroIdx = B.getInt32(0);
100 // Replace Const use with CB use.
101 for (auto &[GV, Offset] : Buf.Constants) {
102 Value *GEP =
103 B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});
104
105 assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&
106 "constant type mismatch");
107
108 // Replace.
109 GV->replaceAllUsesWith(GEP);
110 // Erase GV.
111 GV->removeDeadConstantUsers();
112 GV->eraseFromParent();
113 }
114 return CBGV;
115 }
116
117 } // namespace
118
addConstant(VarDecl * D,Buffer & CB)119 void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {
120 if (D->getStorageClass() == SC_Static) {
121 // For static inside cbuffer, take as global static.
122 // Don't add to cbuffer.
123 CGM.EmitGlobal(D);
124 return;
125 }
126
127 auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D));
128 // Add debug info for constVal.
129 if (CGDebugInfo *DI = CGM.getModuleDebugInfo())
130 if (CGM.getCodeGenOpts().getDebugInfo() >=
131 codegenoptions::DebugInfoKind::LimitedDebugInfo)
132 DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D);
133
134 // FIXME: support packoffset.
135 // See https://github.com/llvm/llvm-project/issues/57914.
136 uint32_t Offset = 0;
137 bool HasUserOffset = false;
138
139 unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX;
140 CB.Constants.emplace_back(std::make_pair(GV, LowerBound));
141 }
142
addBufferDecls(const DeclContext * DC,Buffer & CB)143 void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {
144 for (Decl *it : DC->decls()) {
145 if (auto *ConstDecl = dyn_cast<VarDecl>(it)) {
146 addConstant(ConstDecl, CB);
147 } else if (isa<CXXRecordDecl, EmptyDecl>(it)) {
148 // Nothing to do for this declaration.
149 } else if (isa<FunctionDecl>(it)) {
150 // A function within an cbuffer is effectively a top-level function,
151 // as it only refers to globally scoped declarations.
152 CGM.EmitTopLevelDecl(it);
153 }
154 }
155 }
156
addBuffer(const HLSLBufferDecl * D)157 void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) {
158 Buffers.emplace_back(Buffer(D));
159 addBufferDecls(D, Buffers.back());
160 }
161
finishCodeGen()162 void CGHLSLRuntime::finishCodeGen() {
163 auto &TargetOpts = CGM.getTarget().getTargetOpts();
164 llvm::Module &M = CGM.getModule();
165 Triple T(M.getTargetTriple());
166 if (T.getArch() == Triple::ArchType::dxil)
167 addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
168
169 generateGlobalCtorDtorCalls();
170 if (CGM.getCodeGenOpts().OptimizationLevel == 0)
171 addDisableOptimizations(M);
172
173 const DataLayout &DL = M.getDataLayout();
174
175 for (auto &Buf : Buffers) {
176 layoutBuffer(Buf, DL);
177 GlobalVariable *GV = replaceBuffer(Buf);
178 M.insertGlobalVariable(GV);
179 llvm::hlsl::ResourceClass RC = Buf.IsCBuffer
180 ? llvm::hlsl::ResourceClass::CBuffer
181 : llvm::hlsl::ResourceClass::SRV;
182 llvm::hlsl::ResourceKind RK = Buf.IsCBuffer
183 ? llvm::hlsl::ResourceKind::CBuffer
184 : llvm::hlsl::ResourceKind::TBuffer;
185 addBufferResourceAnnotation(GV, RC, RK, /*IsROV=*/false,
186 llvm::hlsl::ElementType::Invalid, Buf.Binding);
187 }
188 }
189
Buffer(const HLSLBufferDecl * D)190 CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D)
191 : Name(D->getName()), IsCBuffer(D->isCBuffer()),
192 Binding(D->getAttr<HLSLResourceBindingAttr>()) {}
193
addBufferResourceAnnotation(llvm::GlobalVariable * GV,llvm::hlsl::ResourceClass RC,llvm::hlsl::ResourceKind RK,bool IsROV,llvm::hlsl::ElementType ET,BufferResBinding & Binding)194 void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
195 llvm::hlsl::ResourceClass RC,
196 llvm::hlsl::ResourceKind RK,
197 bool IsROV,
198 llvm::hlsl::ElementType ET,
199 BufferResBinding &Binding) {
200 llvm::Module &M = CGM.getModule();
201
202 NamedMDNode *ResourceMD = nullptr;
203 switch (RC) {
204 case llvm::hlsl::ResourceClass::UAV:
205 ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");
206 break;
207 case llvm::hlsl::ResourceClass::SRV:
208 ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");
209 break;
210 case llvm::hlsl::ResourceClass::CBuffer:
211 ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");
212 break;
213 default:
214 assert(false && "Unsupported buffer type!");
215 return;
216 }
217 assert(ResourceMD != nullptr &&
218 "ResourceMD must have been set by the switch above.");
219
220 llvm::hlsl::FrontendResource Res(
221 GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space);
222 ResourceMD->addOperand(Res.getMetadata());
223 }
224
225 static llvm::hlsl::ElementType
calculateElementType(const ASTContext & Context,const clang::Type * ResourceTy)226 calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy) {
227 using llvm::hlsl::ElementType;
228
229 // TODO: We may need to update this when we add things like ByteAddressBuffer
230 // that don't have a template parameter (or, indeed, an element type).
231 const auto *TST = ResourceTy->getAs<TemplateSpecializationType>();
232 assert(TST && "Resource types must be template specializations");
233 ArrayRef<TemplateArgument> Args = TST->template_arguments();
234 assert(!Args.empty() && "Resource has no element type");
235
236 // At this point we have a resource with an element type, so we can assume
237 // that it's valid or we would have diagnosed the error earlier.
238 QualType ElTy = Args[0].getAsType();
239
240 // We should either have a basic type or a vector of a basic type.
241 if (const auto *VecTy = ElTy->getAs<clang::VectorType>())
242 ElTy = VecTy->getElementType();
243
244 if (ElTy->isSignedIntegerType()) {
245 switch (Context.getTypeSize(ElTy)) {
246 case 16:
247 return ElementType::I16;
248 case 32:
249 return ElementType::I32;
250 case 64:
251 return ElementType::I64;
252 }
253 } else if (ElTy->isUnsignedIntegerType()) {
254 switch (Context.getTypeSize(ElTy)) {
255 case 16:
256 return ElementType::U16;
257 case 32:
258 return ElementType::U32;
259 case 64:
260 return ElementType::U64;
261 }
262 } else if (ElTy->isSpecificBuiltinType(BuiltinType::Half))
263 return ElementType::F16;
264 else if (ElTy->isSpecificBuiltinType(BuiltinType::Float))
265 return ElementType::F32;
266 else if (ElTy->isSpecificBuiltinType(BuiltinType::Double))
267 return ElementType::F64;
268
269 // TODO: We need to handle unorm/snorm float types here once we support them
270 llvm_unreachable("Invalid element type for resource");
271 }
272
annotateHLSLResource(const VarDecl * D,GlobalVariable * GV)273 void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
274 const Type *Ty = D->getType()->getPointeeOrArrayElementType();
275 if (!Ty)
276 return;
277 const auto *RD = Ty->getAsCXXRecordDecl();
278 if (!RD)
279 return;
280 const auto *Attr = RD->getAttr<HLSLResourceAttr>();
281 if (!Attr)
282 return;
283
284 llvm::hlsl::ResourceClass RC = Attr->getResourceClass();
285 llvm::hlsl::ResourceKind RK = Attr->getResourceKind();
286 bool IsROV = Attr->getIsROV();
287 llvm::hlsl::ElementType ET = calculateElementType(CGM.getContext(), Ty);
288
289 BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>());
290 addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding);
291 }
292
BufferResBinding(HLSLResourceBindingAttr * Binding)293 CGHLSLRuntime::BufferResBinding::BufferResBinding(
294 HLSLResourceBindingAttr *Binding) {
295 if (Binding) {
296 llvm::APInt RegInt(64, 0);
297 Binding->getSlot().substr(1).getAsInteger(10, RegInt);
298 Reg = RegInt.getLimitedValue();
299 llvm::APInt SpaceInt(64, 0);
300 Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);
301 Space = SpaceInt.getLimitedValue();
302 } else {
303 Space = 0;
304 }
305 }
306
setHLSLEntryAttributes(const FunctionDecl * FD,llvm::Function * Fn)307 void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
308 const FunctionDecl *FD, llvm::Function *Fn) {
309 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
310 assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
311 const StringRef ShaderAttrKindStr = "hlsl.shader";
312 Fn->addFnAttr(ShaderAttrKindStr,
313 ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
314 if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
315 const StringRef NumThreadsKindStr = "hlsl.numthreads";
316 std::string NumThreadsStr =
317 formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
318 NumThreadsAttr->getZ());
319 Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
320 }
321 }
322
buildVectorInput(IRBuilder<> & B,Function * F,llvm::Type * Ty)323 static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
324 if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) {
325 Value *Result = PoisonValue::get(Ty);
326 for (unsigned I = 0; I < VT->getNumElements(); ++I) {
327 Value *Elt = B.CreateCall(F, {B.getInt32(I)});
328 Result = B.CreateInsertElement(Result, Elt, I);
329 }
330 return Result;
331 }
332 return B.CreateCall(F, {B.getInt32(0)});
333 }
334
emitInputSemantic(IRBuilder<> & B,const ParmVarDecl & D,llvm::Type * Ty)335 llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
336 const ParmVarDecl &D,
337 llvm::Type *Ty) {
338 assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
339 if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
340 llvm::Function *DxGroupIndex =
341 CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);
342 return B.CreateCall(FunctionCallee(DxGroupIndex));
343 }
344 if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
345 llvm::Function *DxThreadID = CGM.getIntrinsic(Intrinsic::dx_thread_id);
346 return buildVectorInput(B, DxThreadID, Ty);
347 }
348 assert(false && "Unhandled parameter attribute");
349 return nullptr;
350 }
351
emitEntryFunction(const FunctionDecl * FD,llvm::Function * Fn)352 void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
353 llvm::Function *Fn) {
354 llvm::Module &M = CGM.getModule();
355 llvm::LLVMContext &Ctx = M.getContext();
356 auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
357 Function *EntryFn =
358 Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);
359
360 // Copy function attributes over, we have no argument or return attributes
361 // that can be valid on the real entry.
362 AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
363 Fn->getAttributes().getFnAttrs());
364 EntryFn->setAttributes(NewAttrs);
365 setHLSLEntryAttributes(FD, EntryFn);
366
367 // Set the called function as internal linkage.
368 Fn->setLinkage(GlobalValue::InternalLinkage);
369
370 BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
371 IRBuilder<> B(BB);
372 llvm::SmallVector<Value *> Args;
373 // FIXME: support struct parameters where semantics are on members.
374 // See: https://github.com/llvm/llvm-project/issues/57874
375 unsigned SRetOffset = 0;
376 for (const auto &Param : Fn->args()) {
377 if (Param.hasStructRetAttr()) {
378 // FIXME: support output.
379 // See: https://github.com/llvm/llvm-project/issues/57874
380 SRetOffset = 1;
381 Args.emplace_back(PoisonValue::get(Param.getType()));
382 continue;
383 }
384 const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
385 Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
386 }
387
388 CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
389 (void)CI;
390 // FIXME: Handle codegen for return type semantics.
391 // See: https://github.com/llvm/llvm-project/issues/57875
392 B.CreateRetVoid();
393 }
394
gatherFunctions(SmallVectorImpl<Function * > & Fns,llvm::Module & M,bool CtorOrDtor)395 static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,
396 bool CtorOrDtor) {
397 const auto *GV =
398 M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
399 if (!GV)
400 return;
401 const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer());
402 if (!CA)
403 return;
404 // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
405 // HLSL neither supports priorities or COMDat values, so we will check those
406 // in an assert but not handle them.
407
408 llvm::SmallVector<Function *> CtorFns;
409 for (const auto &Ctor : CA->operands()) {
410 if (isa<ConstantAggregateZero>(Ctor))
411 continue;
412 ConstantStruct *CS = cast<ConstantStruct>(Ctor);
413
414 assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
415 "HLSL doesn't support setting priority for global ctors.");
416 assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
417 "HLSL doesn't support COMDat for global ctors.");
418 Fns.push_back(cast<Function>(CS->getOperand(1)));
419 }
420 }
421
generateGlobalCtorDtorCalls()422 void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
423 llvm::Module &M = CGM.getModule();
424 SmallVector<Function *> CtorFns;
425 SmallVector<Function *> DtorFns;
426 gatherFunctions(CtorFns, M, true);
427 gatherFunctions(DtorFns, M, false);
428
429 // Insert a call to the global constructor at the beginning of the entry block
430 // to externally exported functions. This is a bit of a hack, but HLSL allows
431 // global constructors, but doesn't support driver initialization of globals.
432 for (auto &F : M.functions()) {
433 if (!F.hasFnAttribute("hlsl.shader"))
434 continue;
435 IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());
436 for (auto *Fn : CtorFns)
437 B.CreateCall(FunctionCallee(Fn));
438
439 // Insert global dtors before the terminator of the last instruction
440 B.SetInsertPoint(F.back().getTerminator());
441 for (auto *Fn : DtorFns)
442 B.CreateCall(FunctionCallee(Fn));
443 }
444
445 // No need to keep global ctors/dtors for non-lib profile after call to
446 // ctors/dtors added for entry.
447 Triple T(M.getTargetTriple());
448 if (T.getEnvironment() != Triple::EnvironmentType::Library) {
449 if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))
450 GV->eraseFromParent();
451 if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))
452 GV->eraseFromParent();
453 }
454 }
455