1 //===- DXContainer.h - DXContainer file implementation ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares the DXContainerFile class, which implements the ObjectFile
10 // interface for DXContainer files.
11 //
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_OBJECT_DXCONTAINER_H
16 #define LLVM_OBJECT_DXCONTAINER_H
17 
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringRef.h"
20 #include "llvm/BinaryFormat/DXContainer.h"
21 #include "llvm/Support/Error.h"
22 #include "llvm/Support/MemoryBufferRef.h"
23 #include "llvm/TargetParser/Triple.h"
24 #include <variant>
25 
26 namespace llvm {
27 namespace object {
28 
29 namespace DirectX {
30 class PSVRuntimeInfo {
31 
32   // This class provides a view into the underlying resource array. The Resource
33   // data is little-endian encoded and may not be properly aligned to read
34   // directly from. The dereference operator creates a copy of the data and byte
35   // swaps it as appropriate.
36   struct ResourceArray {
37     StringRef Data;
38     uint32_t Stride; // size of each element in the list.
39 
40     ResourceArray() = default;
41     ResourceArray(StringRef D, size_t S) : Data(D), Stride(S) {}
42 
43     using value_type = dxbc::PSV::v2::ResourceBindInfo;
44     static constexpr uint32_t MaxStride() {
45       return static_cast<uint32_t>(sizeof(value_type));
46     }
47 
48     struct iterator {
49       StringRef Data;
50       uint32_t Stride; // size of each element in the list.
51       const char *Current;
52 
53       iterator(const ResourceArray &A, const char *C)
54           : Data(A.Data), Stride(A.Stride), Current(C) {}
55       iterator(const iterator &) = default;
56 
57       value_type operator*() {
58         // Explicitly zero the structure so that unused fields are zeroed. It is
59         // up to the user to know if the fields are used by verifying the PSV
60         // version.
61         value_type Val = {{0, 0, 0, 0}, 0, 0};
62         if (Current >= Data.end())
63           return Val;
64         memcpy(static_cast<void *>(&Val), Current,
65                std::min(Stride, MaxStride()));
66         if (sys::IsBigEndianHost)
67           Val.swapBytes();
68         return Val;
69       }
70 
71       iterator operator++() {
72         if (Current < Data.end())
73           Current += Stride;
74         return *this;
75       }
76 
77       iterator operator++(int) {
78         iterator Tmp = *this;
79         ++*this;
80         return Tmp;
81       }
82 
83       iterator operator--() {
84         if (Current > Data.begin())
85           Current -= Stride;
86         return *this;
87       }
88 
89       iterator operator--(int) {
90         iterator Tmp = *this;
91         --*this;
92         return Tmp;
93       }
94 
95       bool operator==(const iterator I) { return I.Current == Current; }
96       bool operator!=(const iterator I) { return !(*this == I); }
97     };
98 
99     iterator begin() const { return iterator(*this, Data.begin()); }
100 
101     iterator end() const { return iterator(*this, Data.end()); }
102 
103     size_t size() const { return Data.size() / Stride; }
104   };
105 
106   StringRef Data;
107   uint32_t Size;
108   using InfoStruct =
109       std::variant<std::monostate, dxbc::PSV::v0::RuntimeInfo,
110                    dxbc::PSV::v1::RuntimeInfo, dxbc::PSV::v2::RuntimeInfo>;
111   InfoStruct BasicInfo;
112   ResourceArray Resources;
113 
114 public:
115   PSVRuntimeInfo(StringRef D) : Data(D), Size(0) {}
116 
117   // Parsing depends on the shader kind
118   Error parse(uint16_t ShaderKind);
119 
120   uint32_t getSize() const { return Size; }
121   uint32_t getResourceCount() const { return Resources.size(); }
122   ResourceArray getResources() const { return Resources; }
123 
124   uint32_t getVersion() const {
125     return Size >= sizeof(dxbc::PSV::v2::RuntimeInfo)
126                ? 2
127                : (Size >= sizeof(dxbc::PSV::v1::RuntimeInfo) ? 1 : 0);
128   }
129 
130   uint32_t getResourceStride() const { return Resources.Stride; }
131 
132   const InfoStruct &getInfo() const { return BasicInfo; }
133 };
134 
135 } // namespace DirectX
136 
137 class DXContainer {
138 public:
139   using DXILData = std::pair<dxbc::ProgramHeader, const char *>;
140 
141 private:
142   DXContainer(MemoryBufferRef O);
143 
144   MemoryBufferRef Data;
145   dxbc::Header Header;
146   SmallVector<uint32_t, 4> PartOffsets;
147   std::optional<DXILData> DXIL;
148   std::optional<uint64_t> ShaderFlags;
149   std::optional<dxbc::ShaderHash> Hash;
150   std::optional<DirectX::PSVRuntimeInfo> PSVInfo;
151 
152   Error parseHeader();
153   Error parsePartOffsets();
154   Error parseDXILHeader(StringRef Part);
155   Error parseShaderFlags(StringRef Part);
156   Error parseHash(StringRef Part);
157   Error parsePSVInfo(StringRef Part);
158   friend class PartIterator;
159 
160 public:
161   // The PartIterator is a wrapper around the iterator for the PartOffsets
162   // member of the DXContainer. It contains a refernce to the container, and the
163   // current iterator value, as well as storage for a parsed part header.
164   class PartIterator {
165     const DXContainer &Container;
166     SmallVectorImpl<uint32_t>::const_iterator OffsetIt;
167     struct PartData {
168       dxbc::PartHeader Part;
169       uint32_t Offset;
170       StringRef Data;
171     } IteratorState;
172 
173     friend class DXContainer;
174 
175     PartIterator(const DXContainer &C,
176                  SmallVectorImpl<uint32_t>::const_iterator It)
177         : Container(C), OffsetIt(It) {
178       if (OffsetIt == Container.PartOffsets.end())
179         updateIteratorImpl(Container.PartOffsets.back());
180       else
181         updateIterator();
182     }
183 
184     // Updates the iterator's state data. This results in copying the part
185     // header into the iterator and handling any required byte swapping. This is
186     // called when incrementing or decrementing the iterator.
187     void updateIterator() {
188       if (OffsetIt != Container.PartOffsets.end())
189         updateIteratorImpl(*OffsetIt);
190     }
191 
192     // Implementation for updating the iterator state based on a specified
193     // offest.
194     void updateIteratorImpl(const uint32_t Offset);
195 
196   public:
197     PartIterator &operator++() {
198       if (OffsetIt == Container.PartOffsets.end())
199         return *this;
200       ++OffsetIt;
201       updateIterator();
202       return *this;
203     }
204 
205     PartIterator operator++(int) {
206       PartIterator Tmp = *this;
207       ++(*this);
208       return Tmp;
209     }
210 
211     bool operator==(const PartIterator &RHS) const {
212       return OffsetIt == RHS.OffsetIt;
213     }
214 
215     bool operator!=(const PartIterator &RHS) const {
216       return OffsetIt != RHS.OffsetIt;
217     }
218 
219     const PartData &operator*() { return IteratorState; }
220     const PartData *operator->() { return &IteratorState; }
221   };
222 
223   PartIterator begin() const {
224     return PartIterator(*this, PartOffsets.begin());
225   }
226 
227   PartIterator end() const { return PartIterator(*this, PartOffsets.end()); }
228 
229   StringRef getData() const { return Data.getBuffer(); }
230   static Expected<DXContainer> create(MemoryBufferRef Object);
231 
232   const dxbc::Header &getHeader() const { return Header; }
233 
234   const std::optional<DXILData> &getDXIL() const { return DXIL; }
235 
236   std::optional<uint64_t> getShaderFlags() const { return ShaderFlags; }
237 
238   std::optional<dxbc::ShaderHash> getShaderHash() const { return Hash; }
239 
240   const std::optional<DirectX::PSVRuntimeInfo> &getPSVInfo() const {
241     return PSVInfo;
242   };
243 };
244 
245 } // namespace object
246 } // namespace llvm
247 
248 #endif // LLVM_OBJECT_DXCONTAINER_H
249