1 //===- DXContainerYAML.cpp - DXContainer YAMLIO implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines classes for handling the YAML representation of
// DXContainer files.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ObjectYAML/DXContainerYAML.h"
15 #include "llvm/BinaryFormat/DXContainer.h"
16 
17 namespace llvm {
18 
19 // This assert is duplicated here to leave a breadcrumb of the places that need
20 // to be updated if flags grow past 64-bits.
21 static_assert((uint64_t)dxbc::FeatureFlags::NextUnusedBit <= 1ull << 63,
22               "Shader flag bits exceed enum size.");
23 
24 DXContainerYAML::ShaderFlags::ShaderFlags(uint64_t FlagData) {
25 #define SHADER_FLAG(Num, Val, Str)                                             \
26   Val = (FlagData & (uint64_t)dxbc::FeatureFlags::Val) > 0;
27 #include "llvm/BinaryFormat/DXContainerConstants.def"
28 }
29 
30 uint64_t DXContainerYAML::ShaderFlags::getEncodedFlags() {
31   uint64_t Flag = 0;
32 #define SHADER_FLAG(Num, Val, Str)                                             \
33   if (Val)                                                                     \
34     Flag |= (uint64_t)dxbc::FeatureFlags::Val;
35 #include "llvm/BinaryFormat/DXContainerConstants.def"
36   return Flag;
37 }
38 
39 DXContainerYAML::ShaderHash::ShaderHash(const dxbc::ShaderHash &Data)
40     : IncludesSource((Data.Flags & static_cast<uint32_t>(
41                                        dxbc::HashFlags::IncludesSource)) != 0),
42       Digest(16, 0) {
43   memcpy(Digest.data(), &Data.Digest[0], 16);
44 }
45 
46 DXContainerYAML::PSVInfo::PSVInfo() : Version(0) {
47   memset(&Info, 0, sizeof(Info));
48 }
49 
50 DXContainerYAML::PSVInfo::PSVInfo(const dxbc::PSV::v0::RuntimeInfo *P,
51                                   uint16_t Stage)
52     : Version(0) {
53   memset(&Info, 0, sizeof(Info));
54   memcpy(&Info, P, sizeof(dxbc::PSV::v0::RuntimeInfo));
55 
56   assert(Stage < std::numeric_limits<uint8_t>::max() &&
57          "Stage should be a very small number");
58   // We need to bring the stage in separately since it isn't part of the v1 data
59   // structure.
60   Info.ShaderStage = static_cast<uint8_t>(Stage);
61 }
62 
63 DXContainerYAML::PSVInfo::PSVInfo(const dxbc::PSV::v1::RuntimeInfo *P)
64     : Version(1) {
65   memset(&Info, 0, sizeof(Info));
66   memcpy(&Info, P, sizeof(dxbc::PSV::v1::RuntimeInfo));
67 }
68 
69 DXContainerYAML::PSVInfo::PSVInfo(const dxbc::PSV::v2::RuntimeInfo *P)
70     : Version(2) {
71   memset(&Info, 0, sizeof(Info));
72   memcpy(&Info, P, sizeof(dxbc::PSV::v2::RuntimeInfo));
73 }
74 
75 namespace yaml {
76 
77 void MappingTraits<DXContainerYAML::VersionTuple>::mapping(
78     IO &IO, DXContainerYAML::VersionTuple &Version) {
79   IO.mapRequired("Major", Version.Major);
80   IO.mapRequired("Minor", Version.Minor);
81 }
82 
83 void MappingTraits<DXContainerYAML::FileHeader>::mapping(
84     IO &IO, DXContainerYAML::FileHeader &Header) {
85   IO.mapRequired("Hash", Header.Hash);
86   IO.mapRequired("Version", Header.Version);
87   IO.mapOptional("FileSize", Header.FileSize);
88   IO.mapRequired("PartCount", Header.PartCount);
89   IO.mapOptional("PartOffsets", Header.PartOffsets);
90 }
91 
92 void MappingTraits<DXContainerYAML::DXILProgram>::mapping(
93     IO &IO, DXContainerYAML::DXILProgram &Program) {
94   IO.mapRequired("MajorVersion", Program.MajorVersion);
95   IO.mapRequired("MinorVersion", Program.MinorVersion);
96   IO.mapRequired("ShaderKind", Program.ShaderKind);
97   IO.mapOptional("Size", Program.Size);
98   IO.mapRequired("DXILMajorVersion", Program.DXILMajorVersion);
99   IO.mapRequired("DXILMinorVersion", Program.DXILMinorVersion);
100   IO.mapOptional("DXILSize", Program.DXILSize);
101   IO.mapOptional("DXIL", Program.DXIL);
102 }
103 
104 void MappingTraits<DXContainerYAML::ShaderFlags>::mapping(
105     IO &IO, DXContainerYAML::ShaderFlags &Flags) {
106 #define SHADER_FLAG(Num, Val, Str) IO.mapRequired(#Val, Flags.Val);
107 #include "llvm/BinaryFormat/DXContainerConstants.def"
108 }
109 
110 void MappingTraits<DXContainerYAML::ShaderHash>::mapping(
111     IO &IO, DXContainerYAML::ShaderHash &Hash) {
112   IO.mapRequired("IncludesSource", Hash.IncludesSource);
113   IO.mapRequired("Digest", Hash.Digest);
114 }
115 
116 void MappingTraits<DXContainerYAML::PSVInfo>::mapping(
117     IO &IO, DXContainerYAML::PSVInfo &PSV) {
118   IO.mapRequired("Version", PSV.Version);
119 
120   // Store the PSV version in the YAML context.
121   void *OldContext = IO.getContext();
122   uint32_t Version = PSV.Version;
123   IO.setContext(&Version);
124 
125   // Shader stage is only included in binaries for v1 and later, but we always
126   // include it since it simplifies parsing and file construction.
127   IO.mapRequired("ShaderStage", PSV.Info.ShaderStage);
128   PSV.mapInfoForVersion(IO);
129 
130   IO.mapRequired("ResourceStride", PSV.ResourceStride);
131   IO.mapRequired("Resources", PSV.Resources);
132 
133   // Restore the YAML context.
134   IO.setContext(OldContext);
135 }
136 
137 void MappingTraits<DXContainerYAML::Part>::mapping(IO &IO,
138                                                    DXContainerYAML::Part &P) {
139   IO.mapRequired("Name", P.Name);
140   IO.mapRequired("Size", P.Size);
141   IO.mapOptional("Program", P.Program);
142   IO.mapOptional("Flags", P.Flags);
143   IO.mapOptional("Hash", P.Hash);
144   IO.mapOptional("PSVInfo", P.Info);
145 }
146 
147 void MappingTraits<DXContainerYAML::Object>::mapping(
148     IO &IO, DXContainerYAML::Object &Obj) {
149   IO.mapTag("!dxcontainer", true);
150   IO.mapRequired("Header", Obj.Header);
151   IO.mapRequired("Parts", Obj.Parts);
152 }
153 
154 void MappingTraits<DXContainerYAML::ResourceBindInfo>::mapping(
155     IO &IO, DXContainerYAML::ResourceBindInfo &Res) {
156   IO.mapRequired("Type", Res.Type);
157   IO.mapRequired("Space", Res.Space);
158   IO.mapRequired("LowerBound", Res.LowerBound);
159   IO.mapRequired("UpperBound", Res.UpperBound);
160 
161   const uint32_t *PSVVersion = static_cast<uint32_t *>(IO.getContext());
162   if (*PSVVersion < 2)
163     return;
164 
165   IO.mapRequired("Kind", Res.Kind);
166   IO.mapRequired("Flags", Res.Flags);
167 }
168 
169 } // namespace yaml
170 
171 void DXContainerYAML::PSVInfo::mapInfoForVersion(yaml::IO &IO) {
172   dxbc::PipelinePSVInfo &StageInfo = Info.StageInfo;
173   Triple::EnvironmentType Stage = dxbc::getShaderStage(Info.ShaderStage);
174 
175   switch (Stage) {
176   case Triple::EnvironmentType::Pixel:
177     IO.mapRequired("DepthOutput", StageInfo.PS.DepthOutput);
178     IO.mapRequired("SampleFrequency", StageInfo.PS.SampleFrequency);
179     break;
180   case Triple::EnvironmentType::Vertex:
181     IO.mapRequired("OutputPositionPresent", StageInfo.VS.OutputPositionPresent);
182     break;
183   case Triple::EnvironmentType::Geometry:
184     IO.mapRequired("InputPrimitive", StageInfo.GS.InputPrimitive);
185     IO.mapRequired("OutputTopology", StageInfo.GS.OutputTopology);
186     IO.mapRequired("OutputStreamMask", StageInfo.GS.OutputStreamMask);
187     IO.mapRequired("OutputPositionPresent", StageInfo.GS.OutputPositionPresent);
188     break;
189   case Triple::EnvironmentType::Hull:
190     IO.mapRequired("InputControlPointCount",
191                    StageInfo.HS.InputControlPointCount);
192     IO.mapRequired("OutputControlPointCount",
193                    StageInfo.HS.OutputControlPointCount);
194     IO.mapRequired("TessellatorDomain", StageInfo.HS.TessellatorDomain);
195     IO.mapRequired("TessellatorOutputPrimitive",
196                    StageInfo.HS.TessellatorOutputPrimitive);
197     break;
198   case Triple::EnvironmentType::Domain:
199     IO.mapRequired("InputControlPointCount",
200                    StageInfo.DS.InputControlPointCount);
201     IO.mapRequired("OutputPositionPresent", StageInfo.DS.OutputPositionPresent);
202     IO.mapRequired("TessellatorDomain", StageInfo.DS.TessellatorDomain);
203     break;
204   case Triple::EnvironmentType::Mesh:
205     IO.mapRequired("GroupSharedBytesUsed", StageInfo.MS.GroupSharedBytesUsed);
206     IO.mapRequired("GroupSharedBytesDependentOnViewID",
207                    StageInfo.MS.GroupSharedBytesDependentOnViewID);
208     IO.mapRequired("PayloadSizeInBytes", StageInfo.MS.PayloadSizeInBytes);
209     IO.mapRequired("MaxOutputVertices", StageInfo.MS.MaxOutputVertices);
210     IO.mapRequired("MaxOutputPrimitives", StageInfo.MS.MaxOutputPrimitives);
211     break;
212   case Triple::EnvironmentType::Amplification:
213     IO.mapRequired("PayloadSizeInBytes", StageInfo.AS.PayloadSizeInBytes);
214     break;
215   default:
216     break;
217   }
218 
219   IO.mapRequired("MinimumWaveLaneCount", Info.MinimumWaveLaneCount);
220   IO.mapRequired("MaximumWaveLaneCount", Info.MaximumWaveLaneCount);
221 
222   if (Version == 0)
223     return;
224 
225   IO.mapRequired("UsesViewID", Info.UsesViewID);
226 
227   switch (Stage) {
228   case Triple::EnvironmentType::Geometry:
229     IO.mapRequired("MaxVertexCount", Info.GeomData.MaxVertexCount);
230     break;
231   case Triple::EnvironmentType::Hull:
232   case Triple::EnvironmentType::Domain:
233     IO.mapRequired("SigPatchConstOrPrimVectors",
234                    Info.GeomData.SigPatchConstOrPrimVectors);
235     break;
236   case Triple::EnvironmentType::Mesh:
237     IO.mapRequired("SigPrimVectors", Info.GeomData.MeshInfo.SigPrimVectors);
238     IO.mapRequired("MeshOutputTopology",
239                    Info.GeomData.MeshInfo.MeshOutputTopology);
240     break;
241   default:
242     break;
243   }
244 
245   IO.mapRequired("SigInputElements", Info.SigInputElements);
246   IO.mapRequired("SigOutputElements", Info.SigOutputElements);
247   IO.mapRequired("SigPatchConstOrPrimElements",
248                  Info.SigPatchConstOrPrimElements);
249   IO.mapRequired("SigInputVectors", Info.SigInputVectors);
250   MutableArrayRef<uint8_t> Vec(Info.SigOutputVectors);
251   IO.mapRequired("SigOutputVectors", Vec);
252 
253   if (Version == 1)
254     return;
255 
256   IO.mapRequired("NumThreadsX", Info.NumThreadsX);
257   IO.mapRequired("NumThreadsY", Info.NumThreadsY);
258   IO.mapRequired("NumThreadsZ", Info.NumThreadsZ);
259 }
260 
261 } // namespace llvm
262