1 //===- DXILResource.cpp - DXIL Resource helper objects --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains helper objects for working with DXIL Resources.
10 ///
11 //===----------------------------------------------------------------------===//
12 
13 #include "DXILResource.h"
14 #include "CBufferDataLayout.h"
15 #include "llvm/ADT/StringSwitch.h"
16 #include "llvm/IR/IRBuilder.h"
17 #include "llvm/IR/Metadata.h"
18 #include "llvm/IR/Module.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/Format.h"
21 
22 using namespace llvm;
23 using namespace llvm::dxil;
24 using namespace llvm::hlsl;
25 
26 template <typename T> void ResourceTable<T>::collect(Module &M) {
27   NamedMDNode *Entry = M.getNamedMetadata(MDName);
28   if (!Entry || Entry->getNumOperands() == 0)
29     return;
30 
31   uint32_t Counter = 0;
32   for (auto *Res : Entry->operands()) {
33     Data.push_back(T(Counter++, FrontendResource(cast<MDNode>(Res))));
34   }
35 }
36 
37 template <> void ResourceTable<ConstantBuffer>::collect(Module &M) {
38   NamedMDNode *Entry = M.getNamedMetadata(MDName);
39   if (!Entry || Entry->getNumOperands() == 0)
40     return;
41 
42   uint32_t Counter = 0;
43   for (auto *Res : Entry->operands()) {
44     Data.push_back(
45         ConstantBuffer(Counter++, FrontendResource(cast<MDNode>(Res))));
46   }
47   // FIXME: share CBufferDataLayout with CBuffer load lowering.
48   //   See https://github.com/llvm/llvm-project/issues/58381
49   CBufferDataLayout CBDL(M.getDataLayout(), /*IsLegacy*/ true);
50   for (auto &CB : Data)
51     CB.setSize(CBDL);
52 }
53 
54 void Resources::collect(Module &M) {
55   UAVs.collect(M);
56   CBuffers.collect(M);
57 }
58 
59 ResourceBase::ResourceBase(uint32_t I, FrontendResource R)
60     : ID(I), GV(R.getGlobalVariable()), Name(""), Space(R.getSpace()),
61       LowerBound(R.getResourceIndex()), RangeSize(1) {
62   if (auto *ArrTy = dyn_cast<ArrayType>(GV->getValueType()))
63     RangeSize = ArrTy->getNumElements();
64 }
65 
66 StringRef ResourceBase::getComponentTypeName(ComponentType CompType) {
67   switch (CompType) {
68   case ComponentType::LastEntry:
69   case ComponentType::Invalid:
70     return "invalid";
71   case ComponentType::I1:
72     return "i1";
73   case ComponentType::I16:
74     return "i16";
75   case ComponentType::U16:
76     return "u16";
77   case ComponentType::I32:
78     return "i32";
79   case ComponentType::U32:
80     return "u32";
81   case ComponentType::I64:
82     return "i64";
83   case ComponentType::U64:
84     return "u64";
85   case ComponentType::F16:
86     return "f16";
87   case ComponentType::F32:
88     return "f32";
89   case ComponentType::F64:
90     return "f64";
91   case ComponentType::SNormF16:
92     return "snorm_f16";
93   case ComponentType::UNormF16:
94     return "unorm_f16";
95   case ComponentType::SNormF32:
96     return "snorm_f32";
97   case ComponentType::UNormF32:
98     return "unorm_f32";
99   case ComponentType::SNormF64:
100     return "snorm_f64";
101   case ComponentType::UNormF64:
102     return "unorm_f64";
103   case ComponentType::PackedS8x32:
104     return "p32i8";
105   case ComponentType::PackedU8x32:
106     return "p32u8";
107   }
108   llvm_unreachable("All ComponentType enums are handled in switch");
109 }
110 
111 void ResourceBase::printComponentType(Kinds Kind, ComponentType CompType,
112                                       unsigned Alignment, raw_ostream &OS) {
113   switch (Kind) {
114   default:
115     // TODO: add vector size.
116     OS << right_justify(getComponentTypeName(CompType), Alignment);
117     break;
118   case Kinds::RawBuffer:
119     OS << right_justify("byte", Alignment);
120     break;
121   case Kinds::StructuredBuffer:
122     OS << right_justify("struct", Alignment);
123     break;
124   case Kinds::CBuffer:
125   case Kinds::Sampler:
126     OS << right_justify("NA", Alignment);
127     break;
128   case Kinds::Invalid:
129   case Kinds::NumEntries:
130     break;
131   }
132 }
133 
134 StringRef ResourceBase::getKindName(Kinds Kind) {
135   switch (Kind) {
136   case Kinds::NumEntries:
137   case Kinds::Invalid:
138     return "invalid";
139   case Kinds::Texture1D:
140     return "1d";
141   case Kinds::Texture2D:
142     return "2d";
143   case Kinds::Texture2DMS:
144     return "2dMS";
145   case Kinds::Texture3D:
146     return "3d";
147   case Kinds::TextureCube:
148     return "cube";
149   case Kinds::Texture1DArray:
150     return "1darray";
151   case Kinds::Texture2DArray:
152     return "2darray";
153   case Kinds::Texture2DMSArray:
154     return "2darrayMS";
155   case Kinds::TextureCubeArray:
156     return "cubearray";
157   case Kinds::TypedBuffer:
158     return "buf";
159   case Kinds::RawBuffer:
160     return "rawbuf";
161   case Kinds::StructuredBuffer:
162     return "structbuf";
163   case Kinds::CBuffer:
164     return "cbuffer";
165   case Kinds::Sampler:
166     return "sampler";
167   case Kinds::TBuffer:
168     return "tbuffer";
169   case Kinds::RTAccelerationStructure:
170     return "ras";
171   case Kinds::FeedbackTexture2D:
172     return "fbtex2d";
173   case Kinds::FeedbackTexture2DArray:
174     return "fbtex2darray";
175   }
176   llvm_unreachable("All Kinds enums are handled in switch");
177 }
178 
179 void ResourceBase::printKind(Kinds Kind, unsigned Alignment, raw_ostream &OS,
180                              bool SRV, bool HasCounter, uint32_t SampleCount) {
181   switch (Kind) {
182   default:
183     OS << right_justify(getKindName(Kind), Alignment);
184     break;
185 
186   case Kinds::RawBuffer:
187   case Kinds::StructuredBuffer:
188     if (SRV)
189       OS << right_justify("r/o", Alignment);
190     else {
191       if (!HasCounter)
192         OS << right_justify("r/w", Alignment);
193       else
194         OS << right_justify("r/w+cnt", Alignment);
195     }
196     break;
197   case Kinds::TypedBuffer:
198     OS << right_justify("buf", Alignment);
199     break;
200   case Kinds::Texture2DMS:
201   case Kinds::Texture2DMSArray: {
202     std::string DimName = getKindName(Kind).str();
203     if (SampleCount)
204       DimName += std::to_string(SampleCount);
205     OS << right_justify(DimName, Alignment);
206   } break;
207   case Kinds::CBuffer:
208   case Kinds::Sampler:
209     OS << right_justify("NA", Alignment);
210     break;
211   case Kinds::Invalid:
212   case Kinds::NumEntries:
213     break;
214   }
215 }
216 
217 void ResourceBase::print(raw_ostream &OS, StringRef IDPrefix,
218                          StringRef BindingPrefix) const {
219   std::string ResID = IDPrefix.str();
220   ResID += std::to_string(ID);
221   OS << right_justify(ResID, 8);
222 
223   std::string Bind = BindingPrefix.str();
224   Bind += std::to_string(LowerBound);
225   if (Space)
226     Bind += ",space" + std::to_string(Space);
227 
228   OS << right_justify(Bind, 15);
229   if (RangeSize != UINT_MAX)
230     OS << right_justify(std::to_string(RangeSize), 6) << "\n";
231   else
232     OS << right_justify("unbounded", 6) << "\n";
233 }
234 
235 UAVResource::UAVResource(uint32_t I, FrontendResource R)
236     : ResourceBase(I, R),
237       Shape(static_cast<ResourceBase::Kinds>(R.getResourceKind())),
238       GloballyCoherent(false), HasCounter(false), IsROV(false), ExtProps() {
239   parseSourceType(R.getSourceType());
240 }
241 
242 void UAVResource::print(raw_ostream &OS) const {
243   OS << "; " << left_justify(Name, 31);
244 
245   OS << right_justify("UAV", 10);
246 
247   printComponentType(
248       Shape, ExtProps.ElementType.value_or(ComponentType::Invalid), 8, OS);
249 
250   // FIXME: support SampleCount.
251   // See https://github.com/llvm/llvm-project/issues/58175
252   printKind(Shape, 12, OS, /*SRV*/ false, HasCounter);
253   // Print the binding part.
254   ResourceBase::print(OS, "U", "u");
255 }
256 
257 // FIXME: Capture this in HLSL source. I would go do this right now, but I want
258 // to get this in first so that I can make sure to capture all the extra
259 // information we need to remove the source type string from here (See issue:
260 // https://github.com/llvm/llvm-project/issues/57991).
261 void UAVResource::parseSourceType(StringRef S) {
262   IsROV = S.startswith("RasterizerOrdered");
263   if (IsROV)
264     S = S.substr(strlen("RasterizerOrdered"));
265   if (S.startswith("RW"))
266     S = S.substr(strlen("RW"));
267 
268   // Note: I'm deliberately not handling any of the Texture buffer types at the
269   // moment. I want to resolve the issue above before adding Texture or Sampler
270   // support.
271   Shape = StringSwitch<ResourceBase::Kinds>(S)
272               .StartsWith("Buffer<", Kinds::TypedBuffer)
273               .StartsWith("ByteAddressBuffer<", Kinds::RawBuffer)
274               .StartsWith("StructuredBuffer<", Kinds::StructuredBuffer)
275               .Default(Kinds::Invalid);
276   assert(Shape != Kinds::Invalid && "Unsupported buffer type");
277 
278   S = S.substr(S.find("<") + 1);
279 
280   constexpr size_t PrefixLen = StringRef("vector<").size();
281   if (S.startswith("vector<"))
282     S = S.substr(PrefixLen, S.find(",") - PrefixLen);
283   else
284     S = S.substr(0, S.find(">"));
285 
286   ComponentType ElTy = StringSwitch<ResourceBase::ComponentType>(S)
287                            .Case("bool", ComponentType::I1)
288                            .Case("int16_t", ComponentType::I16)
289                            .Case("uint16_t", ComponentType::U16)
290                            .Case("int32_t", ComponentType::I32)
291                            .Case("uint32_t", ComponentType::U32)
292                            .Case("int64_t", ComponentType::I64)
293                            .Case("uint64_t", ComponentType::U64)
294                            .Case("half", ComponentType::F16)
295                            .Case("float", ComponentType::F32)
296                            .Case("double", ComponentType::F64)
297                            .Default(ComponentType::Invalid);
298   if (ElTy != ComponentType::Invalid)
299     ExtProps.ElementType = ElTy;
300 }
301 
302 ConstantBuffer::ConstantBuffer(uint32_t I, hlsl::FrontendResource R)
303     : ResourceBase(I, R) {}
304 
305 void ConstantBuffer::setSize(CBufferDataLayout &DL) {
306   CBufferSizeInBytes = DL.getTypeAllocSizeInBytes(GV->getValueType());
307 }
308 
309 void ConstantBuffer::print(raw_ostream &OS) const {
310   OS << "; " << left_justify(Name, 31);
311 
312   OS << right_justify("cbuffer", 10);
313 
314   printComponentType(Kinds::CBuffer, ComponentType::Invalid, 8, OS);
315 
316   printKind(Kinds::CBuffer, 12, OS, /*SRV*/ false, /*HasCounter*/ false);
317   // Print the binding part.
318   ResourceBase::print(OS, "CB", "cb");
319 }
320 
321 template <typename T> void ResourceTable<T>::print(raw_ostream &OS) const {
322   for (auto &Res : Data)
323     Res.print(OS);
324 }
325 
326 MDNode *ResourceBase::ExtendedProperties::write(LLVMContext &Ctx) const {
327   IRBuilder<> B(Ctx);
328   SmallVector<Metadata *> Entries;
329   if (ElementType) {
330     Entries.emplace_back(
331         ConstantAsMetadata::get(B.getInt32(TypedBufferElementType)));
332     Entries.emplace_back(ConstantAsMetadata::get(
333         B.getInt32(static_cast<uint32_t>(*ElementType))));
334   }
335   if (Entries.empty())
336     return nullptr;
337   return MDNode::get(Ctx, Entries);
338 }
339 
340 void ResourceBase::write(LLVMContext &Ctx,
341                          MutableArrayRef<Metadata *> Entries) const {
342   IRBuilder<> B(Ctx);
343   Entries[0] = ConstantAsMetadata::get(B.getInt32(ID));
344   Entries[1] = ConstantAsMetadata::get(GV);
345   Entries[2] = MDString::get(Ctx, Name);
346   Entries[3] = ConstantAsMetadata::get(B.getInt32(Space));
347   Entries[4] = ConstantAsMetadata::get(B.getInt32(LowerBound));
348   Entries[5] = ConstantAsMetadata::get(B.getInt32(RangeSize));
349 }
350 
351 MDNode *UAVResource::write() const {
352   auto &Ctx = GV->getContext();
353   IRBuilder<> B(Ctx);
354   Metadata *Entries[11];
355   ResourceBase::write(Ctx, Entries);
356   Entries[6] =
357       ConstantAsMetadata::get(B.getInt32(static_cast<uint32_t>(Shape)));
358   Entries[7] = ConstantAsMetadata::get(B.getInt1(GloballyCoherent));
359   Entries[8] = ConstantAsMetadata::get(B.getInt1(HasCounter));
360   Entries[9] = ConstantAsMetadata::get(B.getInt1(IsROV));
361   Entries[10] = ExtProps.write(Ctx);
362   return MDNode::get(Ctx, Entries);
363 }
364 
365 MDNode *ConstantBuffer::write() const {
366   auto &Ctx = GV->getContext();
367   IRBuilder<> B(Ctx);
368   Metadata *Entries[7];
369   ResourceBase::write(Ctx, Entries);
370 
371   Entries[6] = ConstantAsMetadata::get(B.getInt32(CBufferSizeInBytes));
372   return MDNode::get(Ctx, Entries);
373 }
374 
375 template <typename T> MDNode *ResourceTable<T>::write(Module &M) const {
376   if (Data.empty())
377     return nullptr;
378   SmallVector<Metadata *> MDs;
379   for (auto &Res : Data)
380     MDs.emplace_back(Res.write());
381 
382   NamedMDNode *Entry = M.getNamedMetadata(MDName);
383   if (Entry)
384     Entry->eraseFromParent();
385 
386   return MDNode::get(M.getContext(), MDs);
387 }
388 
389 void Resources::write(Module &M) const {
390   Metadata *ResourceMDs[4] = {nullptr, nullptr, nullptr, nullptr};
391 
392   ResourceMDs[1] = UAVs.write(M);
393 
394   ResourceMDs[2] = CBuffers.write(M);
395 
396   bool HasResource = ResourceMDs[0] != nullptr || ResourceMDs[1] != nullptr ||
397                      ResourceMDs[2] != nullptr || ResourceMDs[3] != nullptr;
398 
399   if (HasResource) {
400     NamedMDNode *DXResMD = M.getOrInsertNamedMetadata("dx.resources");
401     DXResMD->addOperand(MDNode::get(M.getContext(), ResourceMDs));
402   }
403 
404   NamedMDNode *Entry = M.getNamedMetadata("hlsl.uavs");
405   if (Entry)
406     Entry->eraseFromParent();
407 }
408 
409 void Resources::print(raw_ostream &O) const {
410   O << ";\n"
411     << "; Resource Bindings:\n"
412     << ";\n"
413     << "; Name                                 Type  Format         Dim      "
414        "ID      HLSL Bind  Count\n"
415     << "; ------------------------------ ---------- ------- ----------- "
416        "------- -------------- ------\n";
417 
418   CBuffers.print(O);
419   UAVs.print(O);
420 }
421 
422 void Resources::dump() const { print(dbgs()); }
423