//===- AMDGPUMetadataVerifier.cpp - HSA metadata verifier -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Implements a verifier for AMDGPU HSA metadata.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
15 
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/StringSwitch.h"
18 #include "llvm/BinaryFormat/MsgPackDocument.h"
19 
20 #include <map>
21 #include <utility>
22 
23 namespace llvm {
24 namespace AMDGPU {
25 namespace HSAMD {
26 namespace V3 {
27 
28 bool MetadataVerifier::verifyScalar(
29     msgpack::DocNode &Node, msgpack::Type SKind,
30     function_ref<bool(msgpack::DocNode &)> verifyValue) {
31   if (!Node.isScalar())
32     return false;
33   if (Node.getKind() != SKind) {
34     if (Strict)
35       return false;
36     // If we are not strict, we interpret string values as "implicitly typed"
37     // and attempt to coerce them to the expected type here.
38     if (Node.getKind() != msgpack::Type::String)
39       return false;
40     StringRef StringValue = Node.getString();
41     Node.fromString(StringValue);
42     if (Node.getKind() != SKind)
43       return false;
44   }
45   if (verifyValue)
46     return verifyValue(Node);
47   return true;
48 }
49 
50 bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
51   if (!verifyScalar(Node, msgpack::Type::UInt))
52     if (!verifyScalar(Node, msgpack::Type::Int))
53       return false;
54   return true;
55 }
56 
57 bool MetadataVerifier::verifyArray(
58     msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
59     std::optional<size_t> Size) {
60   if (!Node.isArray())
61     return false;
62   auto &Array = Node.getArray();
63   if (Size && Array.size() != *Size)
64     return false;
65   return llvm::all_of(Array, verifyNode);
66 }
67 
68 bool MetadataVerifier::verifyEntry(
69     msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
70     function_ref<bool(msgpack::DocNode &)> verifyNode) {
71   auto Entry = MapNode.find(Key);
72   if (Entry == MapNode.end())
73     return !Required;
74   return verifyNode(Entry->second);
75 }
76 
77 bool MetadataVerifier::verifyScalarEntry(
78     msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
79     msgpack::Type SKind,
80     function_ref<bool(msgpack::DocNode &)> verifyValue) {
81   return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
82     return verifyScalar(Node, SKind, verifyValue);
83   });
84 }
85 
86 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
87                                           StringRef Key, bool Required) {
88   return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
89     return verifyInteger(Node);
90   });
91 }
92 
93 bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
94   if (!Node.isMap())
95     return false;
96   auto &ArgsMap = Node.getMap();
97 
98   if (!verifyScalarEntry(ArgsMap, ".name", false,
99                          msgpack::Type::String))
100     return false;
101   if (!verifyScalarEntry(ArgsMap, ".type_name", false,
102                          msgpack::Type::String))
103     return false;
104   if (!verifyIntegerEntry(ArgsMap, ".size", true))
105     return false;
106   if (!verifyIntegerEntry(ArgsMap, ".offset", true))
107     return false;
108   if (!verifyScalarEntry(ArgsMap, ".value_kind", true, msgpack::Type::String,
109                          [](msgpack::DocNode &SNode) {
110                            return StringSwitch<bool>(SNode.getString())
111                                .Case("by_value", true)
112                                .Case("global_buffer", true)
113                                .Case("dynamic_shared_pointer", true)
114                                .Case("sampler", true)
115                                .Case("image", true)
116                                .Case("pipe", true)
117                                .Case("queue", true)
118                                .Case("hidden_block_count_x", true)
119                                .Case("hidden_block_count_y", true)
120                                .Case("hidden_block_count_z", true)
121                                .Case("hidden_group_size_x", true)
122                                .Case("hidden_group_size_y", true)
123                                .Case("hidden_group_size_z", true)
124                                .Case("hidden_remainder_x", true)
125                                .Case("hidden_remainder_y", true)
126                                .Case("hidden_remainder_z", true)
127                                .Case("hidden_global_offset_x", true)
128                                .Case("hidden_global_offset_y", true)
129                                .Case("hidden_global_offset_z", true)
130                                .Case("hidden_grid_dims", true)
131                                .Case("hidden_none", true)
132                                .Case("hidden_printf_buffer", true)
133                                .Case("hidden_hostcall_buffer", true)
134                                .Case("hidden_heap_v1", true)
135                                .Case("hidden_default_queue", true)
136                                .Case("hidden_completion_action", true)
137                                .Case("hidden_multigrid_sync_arg", true)
138                                .Case("hidden_private_base", true)
139                                .Case("hidden_shared_base", true)
140                                .Case("hidden_queue_ptr", true)
141                                .Default(false);
142                          }))
143     return false;
144   if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
145     return false;
146   if (!verifyScalarEntry(ArgsMap, ".address_space", false,
147                          msgpack::Type::String,
148                          [](msgpack::DocNode &SNode) {
149                            return StringSwitch<bool>(SNode.getString())
150                                .Case("private", true)
151                                .Case("global", true)
152                                .Case("constant", true)
153                                .Case("local", true)
154                                .Case("generic", true)
155                                .Case("region", true)
156                                .Default(false);
157                          }))
158     return false;
159   if (!verifyScalarEntry(ArgsMap, ".access", false,
160                          msgpack::Type::String,
161                          [](msgpack::DocNode &SNode) {
162                            return StringSwitch<bool>(SNode.getString())
163                                .Case("read_only", true)
164                                .Case("write_only", true)
165                                .Case("read_write", true)
166                                .Default(false);
167                          }))
168     return false;
169   if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
170                          msgpack::Type::String,
171                          [](msgpack::DocNode &SNode) {
172                            return StringSwitch<bool>(SNode.getString())
173                                .Case("read_only", true)
174                                .Case("write_only", true)
175                                .Case("read_write", true)
176                                .Default(false);
177                          }))
178     return false;
179   if (!verifyScalarEntry(ArgsMap, ".is_const", false,
180                          msgpack::Type::Boolean))
181     return false;
182   if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
183                          msgpack::Type::Boolean))
184     return false;
185   if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
186                          msgpack::Type::Boolean))
187     return false;
188   if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
189                          msgpack::Type::Boolean))
190     return false;
191 
192   return true;
193 }
194 
195 bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
196   if (!Node.isMap())
197     return false;
198   auto &KernelMap = Node.getMap();
199 
200   if (!verifyScalarEntry(KernelMap, ".name", true,
201                          msgpack::Type::String))
202     return false;
203   if (!verifyScalarEntry(KernelMap, ".symbol", true,
204                          msgpack::Type::String))
205     return false;
206   if (!verifyScalarEntry(KernelMap, ".language", false,
207                          msgpack::Type::String,
208                          [](msgpack::DocNode &SNode) {
209                            return StringSwitch<bool>(SNode.getString())
210                                .Case("OpenCL C", true)
211                                .Case("OpenCL C++", true)
212                                .Case("HCC", true)
213                                .Case("HIP", true)
214                                .Case("OpenMP", true)
215                                .Case("Assembler", true)
216                                .Default(false);
217                          }))
218     return false;
219   if (!verifyEntry(
220           KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
221             return verifyArray(
222                 Node,
223                 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
224           }))
225     return false;
226   if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
227         return verifyArray(Node, [this](msgpack::DocNode &Node) {
228           return verifyKernelArgs(Node);
229         });
230       }))
231     return false;
232   if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
233                    [this](msgpack::DocNode &Node) {
234                      return verifyArray(Node,
235                                         [this](msgpack::DocNode &Node) {
236                                           return verifyInteger(Node);
237                                         },
238                                         3);
239                    }))
240     return false;
241   if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
242                    [this](msgpack::DocNode &Node) {
243                      return verifyArray(Node,
244                                         [this](msgpack::DocNode &Node) {
245                                           return verifyInteger(Node);
246                                         },
247                                         3);
248                    }))
249     return false;
250   if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
251                          msgpack::Type::String))
252     return false;
253   if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
254                          msgpack::Type::String))
255     return false;
256   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
257     return false;
258   if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
259     return false;
260   if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
261     return false;
262   if (!verifyScalarEntry(KernelMap, ".uses_dynamic_stack", false,
263                          msgpack::Type::Boolean))
264     return false;
265   if (!verifyIntegerEntry(KernelMap, ".workgroup_processor_mode", false))
266     return false;
267   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
268     return false;
269   if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
270     return false;
271   if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
272     return false;
273   if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
274     return false;
275   if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
276     return false;
277   if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
278     return false;
279   if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
280     return false;
281   if (!verifyIntegerEntry(KernelMap, ".uniform_work_group_size", false))
282     return false;
283 
284 
285   return true;
286 }
287 
288 bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
289   if (!HSAMetadataRoot.isMap())
290     return false;
291   auto &RootMap = HSAMetadataRoot.getMap();
292 
293   if (!verifyEntry(
294           RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
295             return verifyArray(
296                 Node,
297                 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
298           }))
299     return false;
300   if (!verifyEntry(
301           RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
302             return verifyArray(Node, [this](msgpack::DocNode &Node) {
303               return verifyScalar(Node, msgpack::Type::String);
304             });
305           }))
306     return false;
307   if (!verifyEntry(RootMap, "amdhsa.kernels", true,
308                    [this](msgpack::DocNode &Node) {
309                      return verifyArray(Node, [this](msgpack::DocNode &Node) {
310                        return verifyKernel(Node);
311                      });
312                    }))
313     return false;
314 
315   return true;
316 }
317 
318 } // end namespace V3
319 } // end namespace HSAMD
320 } // end namespace AMDGPU
321 } // end namespace llvm
322