1 //===- DXILMetadata.cpp - DXIL Metadata helper objects --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains helper objects for working with DXIL metadata.
10 ///
11 //===----------------------------------------------------------------------===//
12 
13 #include "DXILMetadata.h"
14 #include "llvm/IR/Constants.h"
15 #include "llvm/IR/IRBuilder.h"
16 #include "llvm/IR/Metadata.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/Support/VersionTuple.h"
19 #include "llvm/TargetParser/Triple.h"
20 
21 using namespace llvm;
22 using namespace llvm::dxil;
23 
24 ValidatorVersionMD::ValidatorVersionMD(Module &M)
25     : Entry(M.getOrInsertNamedMetadata("dx.valver")) {}
26 
27 void ValidatorVersionMD::update(VersionTuple ValidatorVer) {
28   auto &Ctx = Entry->getParent()->getContext();
29   IRBuilder<> B(Ctx);
30   Metadata *MDVals[2];
31   MDVals[0] = ConstantAsMetadata::get(B.getInt32(ValidatorVer.getMajor()));
32   MDVals[1] =
33       ConstantAsMetadata::get(B.getInt32(ValidatorVer.getMinor().value_or(0)));
34 
35   if (isEmpty())
36     Entry->addOperand(MDNode::get(Ctx, MDVals));
37   else
38     Entry->setOperand(0, MDNode::get(Ctx, MDVals));
39 }
40 
41 bool ValidatorVersionMD::isEmpty() { return Entry->getNumOperands() == 0; }
42 
43 static StringRef getShortShaderStage(Triple::EnvironmentType Env) {
44   switch (Env) {
45   case Triple::Pixel:
46     return "ps";
47   case Triple::Vertex:
48     return "vs";
49   case Triple::Geometry:
50     return "gs";
51   case Triple::Hull:
52     return "hs";
53   case Triple::Domain:
54     return "ds";
55   case Triple::Compute:
56     return "cs";
57   case Triple::Library:
58     return "lib";
59   case Triple::Mesh:
60     return "ms";
61   case Triple::Amplification:
62     return "as";
63   default:
64     break;
65   }
66   llvm_unreachable("Unsupported environment for DXIL generation.");
67   return "";
68 }
69 
70 void dxil::createShaderModelMD(Module &M) {
71   NamedMDNode *Entry = M.getOrInsertNamedMetadata("dx.shaderModel");
72   Triple TT(M.getTargetTriple());
73   VersionTuple Ver = TT.getOSVersion();
74   LLVMContext &Ctx = M.getContext();
75   IRBuilder<> B(Ctx);
76 
77   Metadata *Vals[3];
78   Vals[0] = MDString::get(Ctx, getShortShaderStage(TT.getEnvironment()));
79   Vals[1] = ConstantAsMetadata::get(B.getInt32(Ver.getMajor()));
80   Vals[2] = ConstantAsMetadata::get(B.getInt32(Ver.getMinor().value_or(0)));
81   Entry->addOperand(MDNode::get(Ctx, Vals));
82 }
83 
84 static uint32_t getShaderStage(Triple::EnvironmentType Env) {
85   return (uint32_t)Env - (uint32_t)llvm::Triple::Pixel;
86 }
87 
88 namespace {
89 
90 struct EntryProps {
91   Triple::EnvironmentType ShaderKind;
92   // FIXME: support more shader profiles.
93   // See https://github.com/llvm/llvm-project/issues/57927.
94   struct {
95     unsigned NumThreads[3];
96   } CS;
97 
98   EntryProps(Function &F, Triple::EnvironmentType ModuleShaderKind)
99       : ShaderKind(ModuleShaderKind) {
100 
101     if (ShaderKind == Triple::EnvironmentType::Library) {
102       Attribute EntryAttr = F.getFnAttribute("hlsl.shader");
103       StringRef EntryProfile = EntryAttr.getValueAsString();
104       Triple T("", "", "", EntryProfile);
105       ShaderKind = T.getEnvironment();
106     }
107 
108     if (ShaderKind == Triple::EnvironmentType::Compute) {
109       auto NumThreadsStr =
110           F.getFnAttribute("hlsl.numthreads").getValueAsString();
111       SmallVector<StringRef> NumThreads;
112       NumThreadsStr.split(NumThreads, ',');
113       assert(NumThreads.size() == 3 && "invalid numthreads");
114       auto Zip =
115           llvm::zip(NumThreads, MutableArrayRef<unsigned>(CS.NumThreads));
116       for (auto It : Zip) {
117         StringRef Str = std::get<0>(It);
118         APInt V;
119         [[maybe_unused]] bool Result = Str.getAsInteger(10, V);
120         assert(!Result && "Failed to parse numthreads");
121 
122         unsigned &Num = std::get<1>(It);
123         Num = V.getLimitedValue();
124       }
125     }
126   }
127 
128   MDTuple *emitDXILEntryProps(uint64_t RawShaderFlag, LLVMContext &Ctx,
129                               bool IsLib) {
130     std::vector<Metadata *> MDVals;
131 
132     if (RawShaderFlag != 0)
133       appendShaderFlags(MDVals, RawShaderFlag, Ctx);
134 
135     // Add shader kind for lib entrys.
136     if (IsLib && ShaderKind != Triple::EnvironmentType::Library)
137       appendShaderKind(MDVals, Ctx);
138 
139     if (ShaderKind == Triple::EnvironmentType::Compute)
140       appendNumThreads(MDVals, Ctx);
141     // FIXME: support more props.
142     // See https://github.com/llvm/llvm-project/issues/57948.
143     return MDNode::get(Ctx, MDVals);
144   }
145 
146   static MDTuple *emitEntryPropsForEmptyEntry(uint64_t RawShaderFlag,
147                                               LLVMContext &Ctx) {
148     if (RawShaderFlag == 0)
149       return nullptr;
150 
151     std::vector<Metadata *> MDVals;
152 
153     appendShaderFlags(MDVals, RawShaderFlag, Ctx);
154     // FIXME: support more props.
155     // See https://github.com/llvm/llvm-project/issues/57948.
156     return MDNode::get(Ctx, MDVals);
157   }
158 
159 private:
160   enum EntryPropsTag {
161     ShaderFlagsTag = 0,
162     GSStateTag,
163     DSStateTag,
164     HSStateTag,
165     NumThreadsTag,
166     AutoBindingSpaceTag,
167     RayPayloadSizeTag,
168     RayAttribSizeTag,
169     ShaderKindTag,
170     MSStateTag,
171     ASStateTag,
172     WaveSizeTag,
173     EntryRootSigTag,
174   };
175 
176   void appendNumThreads(std::vector<Metadata *> &MDVals, LLVMContext &Ctx) {
177     MDVals.emplace_back(ConstantAsMetadata::get(
178         ConstantInt::get(Type::getInt32Ty(Ctx), NumThreadsTag)));
179 
180     std::vector<Metadata *> NumThreadVals;
181     for (auto Num : ArrayRef<unsigned>(CS.NumThreads))
182       NumThreadVals.emplace_back(ConstantAsMetadata::get(
183           ConstantInt::get(Type::getInt32Ty(Ctx), Num)));
184     MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
185   }
186 
187   static void appendShaderFlags(std::vector<Metadata *> &MDVals,
188                                 uint64_t RawShaderFlag, LLVMContext &Ctx) {
189     MDVals.emplace_back(ConstantAsMetadata::get(
190         ConstantInt::get(Type::getInt32Ty(Ctx), ShaderFlagsTag)));
191     MDVals.emplace_back(ConstantAsMetadata::get(
192         ConstantInt::get(Type::getInt64Ty(Ctx), RawShaderFlag)));
193   }
194 
195   void appendShaderKind(std::vector<Metadata *> &MDVals, LLVMContext &Ctx) {
196     MDVals.emplace_back(ConstantAsMetadata::get(
197         ConstantInt::get(Type::getInt32Ty(Ctx), ShaderKindTag)));
198     MDVals.emplace_back(ConstantAsMetadata::get(
199         ConstantInt::get(Type::getInt32Ty(Ctx), getShaderStage(ShaderKind))));
200   }
201 };
202 
203 class EntryMD {
204   Function &F;
205   LLVMContext &Ctx;
206   EntryProps Props;
207 
208 public:
209   EntryMD(Function &F, Triple::EnvironmentType ModuleShaderKind)
210       : F(F), Ctx(F.getContext()), Props(F, ModuleShaderKind) {}
211 
212   MDTuple *emitEntryTuple(MDTuple *Resources, uint64_t RawShaderFlag) {
213     // FIXME: add signature for profile other than CS.
214     // See https://github.com/llvm/llvm-project/issues/57928.
215     MDTuple *Signatures = nullptr;
216     return emitDxilEntryPointTuple(
217         &F, F.getName().str(), Signatures, Resources,
218         Props.emitDXILEntryProps(RawShaderFlag, Ctx, /*IsLib*/ false), Ctx);
219   }
220 
221   MDTuple *emitEntryTupleForLib(uint64_t RawShaderFlag) {
222     // FIXME: add signature for profile other than CS.
223     // See https://github.com/llvm/llvm-project/issues/57928.
224     MDTuple *Signatures = nullptr;
225     return emitDxilEntryPointTuple(
226         &F, F.getName().str(), Signatures,
227         /*entry in lib doesn't need resources metadata*/ nullptr,
228         Props.emitDXILEntryProps(RawShaderFlag, Ctx, /*IsLib*/ true), Ctx);
229   }
230 
231   // Library will have empty entry metadata which only store the resource table
232   // metadata.
233   static MDTuple *emitEmptyEntryForLib(MDTuple *Resources,
234                                        uint64_t RawShaderFlag,
235                                        LLVMContext &Ctx) {
236     return emitDxilEntryPointTuple(
237         nullptr, "", nullptr, Resources,
238         EntryProps::emitEntryPropsForEmptyEntry(RawShaderFlag, Ctx), Ctx);
239   }
240 
241 private:
242   static MDTuple *emitDxilEntryPointTuple(Function *Fn, const std::string &Name,
243                                           MDTuple *Signatures,
244                                           MDTuple *Resources,
245                                           MDTuple *Properties,
246                                           LLVMContext &Ctx) {
247     Metadata *MDVals[5];
248     MDVals[0] = Fn ? ValueAsMetadata::get(Fn) : nullptr;
249     MDVals[1] = MDString::get(Ctx, Name.c_str());
250     MDVals[2] = Signatures;
251     MDVals[3] = Resources;
252     MDVals[4] = Properties;
253     return MDNode::get(Ctx, MDVals);
254   }
255 };
256 } // namespace
257 
258 void dxil::createEntryMD(Module &M, const uint64_t ShaderFlags) {
259   SmallVector<Function *> EntryList;
260   for (auto &F : M.functions()) {
261     if (!F.hasFnAttribute("hlsl.shader"))
262       continue;
263     EntryList.emplace_back(&F);
264   }
265 
266   auto &Ctx = M.getContext();
267   // FIXME: generate metadata for resource.
268   // See https://github.com/llvm/llvm-project/issues/57926.
269   MDTuple *MDResources = nullptr;
270   if (auto *NamedResources = M.getNamedMetadata("dx.resources"))
271     MDResources = dyn_cast<MDTuple>(NamedResources->getOperand(0));
272 
273   std::vector<MDNode *> Entries;
274   Triple T = Triple(M.getTargetTriple());
275   switch (T.getEnvironment()) {
276   case Triple::EnvironmentType::Library: {
277     // Add empty entry to put resource metadata.
278     MDTuple *EmptyEntry =
279         EntryMD::emitEmptyEntryForLib(MDResources, ShaderFlags, Ctx);
280     Entries.emplace_back(EmptyEntry);
281 
282     for (Function *Entry : EntryList) {
283       EntryMD MD(*Entry, T.getEnvironment());
284       Entries.emplace_back(MD.emitEntryTupleForLib(0));
285     }
286   } break;
287   case Triple::EnvironmentType::Compute:
288   case Triple::EnvironmentType::Amplification:
289   case Triple::EnvironmentType::Mesh:
290   case Triple::EnvironmentType::Vertex:
291   case Triple::EnvironmentType::Hull:
292   case Triple::EnvironmentType::Domain:
293   case Triple::EnvironmentType::Geometry:
294   case Triple::EnvironmentType::Pixel: {
295     assert(EntryList.size() == 1 &&
296            "non-lib profiles should only have one entry");
297     EntryMD MD(*EntryList.front(), T.getEnvironment());
298     Entries.emplace_back(MD.emitEntryTuple(MDResources, ShaderFlags));
299   } break;
300   default:
301     assert(0 && "invalid profile");
302     break;
303   }
304 
305   NamedMDNode *EntryPointsNamedMD =
306       M.getOrInsertNamedMetadata("dx.entryPoints");
307   for (auto *Entry : Entries)
308     EntryPointsNamedMD->addOperand(Entry);
309 }
310