//===- AMDGPUMetadataVerifier.cpp - HSA metadata verifier -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Implements a verifier for AMDGPU HSA metadata.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
15 
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/StringSwitch.h"
18 #include "llvm/BinaryFormat/MsgPackDocument.h"
19 
20 #include <utility>
21 
22 namespace llvm {
23 namespace AMDGPU {
24 namespace HSAMD {
25 namespace V3 {
26 
27 bool MetadataVerifier::verifyScalar(
28     msgpack::DocNode &Node, msgpack::Type SKind,
29     function_ref<bool(msgpack::DocNode &)> verifyValue) {
30   if (!Node.isScalar())
31     return false;
32   if (Node.getKind() != SKind) {
33     if (Strict)
34       return false;
35     // If we are not strict, we interpret string values as "implicitly typed"
36     // and attempt to coerce them to the expected type here.
37     if (Node.getKind() != msgpack::Type::String)
38       return false;
39     StringRef StringValue = Node.getString();
40     Node.fromString(StringValue);
41     if (Node.getKind() != SKind)
42       return false;
43   }
44   if (verifyValue)
45     return verifyValue(Node);
46   return true;
47 }
48 
49 bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
50   if (!verifyScalar(Node, msgpack::Type::UInt))
51     if (!verifyScalar(Node, msgpack::Type::Int))
52       return false;
53   return true;
54 }
55 
56 bool MetadataVerifier::verifyArray(
57     msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
58     std::optional<size_t> Size) {
59   if (!Node.isArray())
60     return false;
61   auto &Array = Node.getArray();
62   if (Size && Array.size() != *Size)
63     return false;
64   return llvm::all_of(Array, verifyNode);
65 }
66 
67 bool MetadataVerifier::verifyEntry(
68     msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
69     function_ref<bool(msgpack::DocNode &)> verifyNode) {
70   auto Entry = MapNode.find(Key);
71   if (Entry == MapNode.end())
72     return !Required;
73   return verifyNode(Entry->second);
74 }
75 
76 bool MetadataVerifier::verifyScalarEntry(
77     msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
78     msgpack::Type SKind,
79     function_ref<bool(msgpack::DocNode &)> verifyValue) {
80   return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
81     return verifyScalar(Node, SKind, verifyValue);
82   });
83 }
84 
85 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
86                                           StringRef Key, bool Required) {
87   return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
88     return verifyInteger(Node);
89   });
90 }
91 
92 bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
93   if (!Node.isMap())
94     return false;
95   auto &ArgsMap = Node.getMap();
96 
97   if (!verifyScalarEntry(ArgsMap, ".name", false,
98                          msgpack::Type::String))
99     return false;
100   if (!verifyScalarEntry(ArgsMap, ".type_name", false,
101                          msgpack::Type::String))
102     return false;
103   if (!verifyIntegerEntry(ArgsMap, ".size", true))
104     return false;
105   if (!verifyIntegerEntry(ArgsMap, ".offset", true))
106     return false;
107   if (!verifyScalarEntry(ArgsMap, ".value_kind", true, msgpack::Type::String,
108                          [](msgpack::DocNode &SNode) {
109                            return StringSwitch<bool>(SNode.getString())
110                                .Case("by_value", true)
111                                .Case("global_buffer", true)
112                                .Case("dynamic_shared_pointer", true)
113                                .Case("sampler", true)
114                                .Case("image", true)
115                                .Case("pipe", true)
116                                .Case("queue", true)
117                                .Case("hidden_block_count_x", true)
118                                .Case("hidden_block_count_y", true)
119                                .Case("hidden_block_count_z", true)
120                                .Case("hidden_group_size_x", true)
121                                .Case("hidden_group_size_y", true)
122                                .Case("hidden_group_size_z", true)
123                                .Case("hidden_remainder_x", true)
124                                .Case("hidden_remainder_y", true)
125                                .Case("hidden_remainder_z", true)
126                                .Case("hidden_global_offset_x", true)
127                                .Case("hidden_global_offset_y", true)
128                                .Case("hidden_global_offset_z", true)
129                                .Case("hidden_grid_dims", true)
130                                .Case("hidden_none", true)
131                                .Case("hidden_printf_buffer", true)
132                                .Case("hidden_hostcall_buffer", true)
133                                .Case("hidden_heap_v1", true)
134                                .Case("hidden_default_queue", true)
135                                .Case("hidden_completion_action", true)
136                                .Case("hidden_multigrid_sync_arg", true)
137                                .Case("hidden_private_base", true)
138                                .Case("hidden_shared_base", true)
139                                .Case("hidden_queue_ptr", true)
140                                .Default(false);
141                          }))
142     return false;
143   if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
144     return false;
145   if (!verifyScalarEntry(ArgsMap, ".address_space", false,
146                          msgpack::Type::String,
147                          [](msgpack::DocNode &SNode) {
148                            return StringSwitch<bool>(SNode.getString())
149                                .Case("private", true)
150                                .Case("global", true)
151                                .Case("constant", true)
152                                .Case("local", true)
153                                .Case("generic", true)
154                                .Case("region", true)
155                                .Default(false);
156                          }))
157     return false;
158   if (!verifyScalarEntry(ArgsMap, ".access", false,
159                          msgpack::Type::String,
160                          [](msgpack::DocNode &SNode) {
161                            return StringSwitch<bool>(SNode.getString())
162                                .Case("read_only", true)
163                                .Case("write_only", true)
164                                .Case("read_write", true)
165                                .Default(false);
166                          }))
167     return false;
168   if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
169                          msgpack::Type::String,
170                          [](msgpack::DocNode &SNode) {
171                            return StringSwitch<bool>(SNode.getString())
172                                .Case("read_only", true)
173                                .Case("write_only", true)
174                                .Case("read_write", true)
175                                .Default(false);
176                          }))
177     return false;
178   if (!verifyScalarEntry(ArgsMap, ".is_const", false,
179                          msgpack::Type::Boolean))
180     return false;
181   if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
182                          msgpack::Type::Boolean))
183     return false;
184   if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
185                          msgpack::Type::Boolean))
186     return false;
187   if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
188                          msgpack::Type::Boolean))
189     return false;
190 
191   return true;
192 }
193 
194 bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
195   if (!Node.isMap())
196     return false;
197   auto &KernelMap = Node.getMap();
198 
199   if (!verifyScalarEntry(KernelMap, ".name", true,
200                          msgpack::Type::String))
201     return false;
202   if (!verifyScalarEntry(KernelMap, ".symbol", true,
203                          msgpack::Type::String))
204     return false;
205   if (!verifyScalarEntry(KernelMap, ".language", false,
206                          msgpack::Type::String,
207                          [](msgpack::DocNode &SNode) {
208                            return StringSwitch<bool>(SNode.getString())
209                                .Case("OpenCL C", true)
210                                .Case("OpenCL C++", true)
211                                .Case("HCC", true)
212                                .Case("HIP", true)
213                                .Case("OpenMP", true)
214                                .Case("Assembler", true)
215                                .Default(false);
216                          }))
217     return false;
218   if (!verifyEntry(
219           KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
220             return verifyArray(
221                 Node,
222                 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
223           }))
224     return false;
225   if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
226         return verifyArray(Node, [this](msgpack::DocNode &Node) {
227           return verifyKernelArgs(Node);
228         });
229       }))
230     return false;
231   if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
232                    [this](msgpack::DocNode &Node) {
233                      return verifyArray(Node,
234                                         [this](msgpack::DocNode &Node) {
235                                           return verifyInteger(Node);
236                                         },
237                                         3);
238                    }))
239     return false;
240   if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
241                    [this](msgpack::DocNode &Node) {
242                      return verifyArray(Node,
243                                         [this](msgpack::DocNode &Node) {
244                                           return verifyInteger(Node);
245                                         },
246                                         3);
247                    }))
248     return false;
249   if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
250                          msgpack::Type::String))
251     return false;
252   if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
253                          msgpack::Type::String))
254     return false;
255   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
256     return false;
257   if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
258     return false;
259   if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
260     return false;
261   if (!verifyScalarEntry(KernelMap, ".uses_dynamic_stack", false,
262                          msgpack::Type::Boolean))
263     return false;
264   if (!verifyIntegerEntry(KernelMap, ".workgroup_processor_mode", false))
265     return false;
266   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
267     return false;
268   if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
269     return false;
270   if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
271     return false;
272   if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
273     return false;
274   if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
275     return false;
276   if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
277     return false;
278   if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
279     return false;
280   if (!verifyIntegerEntry(KernelMap, ".uniform_work_group_size", false))
281     return false;
282 
283 
284   return true;
285 }
286 
287 bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
288   if (!HSAMetadataRoot.isMap())
289     return false;
290   auto &RootMap = HSAMetadataRoot.getMap();
291 
292   if (!verifyEntry(
293           RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
294             return verifyArray(
295                 Node,
296                 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
297           }))
298     return false;
299   if (!verifyEntry(
300           RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
301             return verifyArray(Node, [this](msgpack::DocNode &Node) {
302               return verifyScalar(Node, msgpack::Type::String);
303             });
304           }))
305     return false;
306   if (!verifyEntry(RootMap, "amdhsa.kernels", true,
307                    [this](msgpack::DocNode &Node) {
308                      return verifyArray(Node, [this](msgpack::DocNode &Node) {
309                        return verifyKernel(Node);
310                      });
311                    }))
312     return false;
313 
314   return true;
315 }
316 
317 } // end namespace V3
318 } // end namespace HSAMD
319 } // end namespace AMDGPU
320 } // end namespace llvm
321