//===- AMDGPUMetadataVerifier.cpp - HSA Metadata Verifier -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Implements a verifier for AMDGPU HSA metadata.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
15 #include "llvm/Support/AMDGPUMetadata.h"
16 
17 namespace llvm {
18 namespace AMDGPU {
19 namespace HSAMD {
20 namespace V3 {
21 
22 bool MetadataVerifier::verifyScalar(
23     msgpack::DocNode &Node, msgpack::Type SKind,
24     function_ref<bool(msgpack::DocNode &)> verifyValue) {
25   if (!Node.isScalar())
26     return false;
27   if (Node.getKind() != SKind) {
28     if (Strict)
29       return false;
30     // If we are not strict, we interpret string values as "implicitly typed"
31     // and attempt to coerce them to the expected type here.
32     if (Node.getKind() != msgpack::Type::String)
33       return false;
34     StringRef StringValue = Node.getString();
35     Node.fromString(StringValue);
36     if (Node.getKind() != SKind)
37       return false;
38   }
39   if (verifyValue)
40     return verifyValue(Node);
41   return true;
42 }
43 
44 bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
45   if (!verifyScalar(Node, msgpack::Type::UInt))
46     if (!verifyScalar(Node, msgpack::Type::Int))
47       return false;
48   return true;
49 }
50 
51 bool MetadataVerifier::verifyArray(
52     msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
53     Optional<size_t> Size) {
54   if (!Node.isArray())
55     return false;
56   auto &Array = Node.getArray();
57   if (Size && Array.size() != *Size)
58     return false;
59   for (auto &Item : Array)
60     if (!verifyNode(Item))
61       return false;
62 
63   return true;
64 }
65 
66 bool MetadataVerifier::verifyEntry(
67     msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
68     function_ref<bool(msgpack::DocNode &)> verifyNode) {
69   auto Entry = MapNode.find(Key);
70   if (Entry == MapNode.end())
71     return !Required;
72   return verifyNode(Entry->second);
73 }
74 
75 bool MetadataVerifier::verifyScalarEntry(
76     msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
77     msgpack::Type SKind,
78     function_ref<bool(msgpack::DocNode &)> verifyValue) {
79   return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
80     return verifyScalar(Node, SKind, verifyValue);
81   });
82 }
83 
84 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
85                                           StringRef Key, bool Required) {
86   return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
87     return verifyInteger(Node);
88   });
89 }
90 
91 bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
92   if (!Node.isMap())
93     return false;
94   auto &ArgsMap = Node.getMap();
95 
96   if (!verifyScalarEntry(ArgsMap, ".name", false,
97                          msgpack::Type::String))
98     return false;
99   if (!verifyScalarEntry(ArgsMap, ".type_name", false,
100                          msgpack::Type::String))
101     return false;
102   if (!verifyIntegerEntry(ArgsMap, ".size", true))
103     return false;
104   if (!verifyIntegerEntry(ArgsMap, ".offset", true))
105     return false;
106   if (!verifyScalarEntry(ArgsMap, ".value_kind", true,
107                          msgpack::Type::String,
108                          [](msgpack::DocNode &SNode) {
109                            return StringSwitch<bool>(SNode.getString())
110                                .Case("by_value", true)
111                                .Case("global_buffer", true)
112                                .Case("dynamic_shared_pointer", true)
113                                .Case("sampler", true)
114                                .Case("image", true)
115                                .Case("pipe", true)
116                                .Case("queue", true)
117                                .Case("hidden_global_offset_x", true)
118                                .Case("hidden_global_offset_y", true)
119                                .Case("hidden_global_offset_z", true)
120                                .Case("hidden_none", true)
121                                .Case("hidden_printf_buffer", true)
122                                .Case("hidden_hostcall_buffer", true)
123                                .Case("hidden_default_queue", true)
124                                .Case("hidden_completion_action", true)
125                                .Case("hidden_multigrid_sync_arg", true)
126                                .Default(false);
127                          }))
128     return false;
129   if (!verifyScalarEntry(ArgsMap, ".value_type", true,
130                          msgpack::Type::String,
131                          [](msgpack::DocNode &SNode) {
132                            return StringSwitch<bool>(SNode.getString())
133                                .Case("struct", true)
134                                .Case("i8", true)
135                                .Case("u8", true)
136                                .Case("i16", true)
137                                .Case("u16", true)
138                                .Case("f16", true)
139                                .Case("i32", true)
140                                .Case("u32", true)
141                                .Case("f32", true)
142                                .Case("i64", true)
143                                .Case("u64", true)
144                                .Case("f64", true)
145                                .Default(false);
146                          }))
147     return false;
148   if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
149     return false;
150   if (!verifyScalarEntry(ArgsMap, ".address_space", false,
151                          msgpack::Type::String,
152                          [](msgpack::DocNode &SNode) {
153                            return StringSwitch<bool>(SNode.getString())
154                                .Case("private", true)
155                                .Case("global", true)
156                                .Case("constant", true)
157                                .Case("local", true)
158                                .Case("generic", true)
159                                .Case("region", true)
160                                .Default(false);
161                          }))
162     return false;
163   if (!verifyScalarEntry(ArgsMap, ".access", false,
164                          msgpack::Type::String,
165                          [](msgpack::DocNode &SNode) {
166                            return StringSwitch<bool>(SNode.getString())
167                                .Case("read_only", true)
168                                .Case("write_only", true)
169                                .Case("read_write", true)
170                                .Default(false);
171                          }))
172     return false;
173   if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
174                          msgpack::Type::String,
175                          [](msgpack::DocNode &SNode) {
176                            return StringSwitch<bool>(SNode.getString())
177                                .Case("read_only", true)
178                                .Case("write_only", true)
179                                .Case("read_write", true)
180                                .Default(false);
181                          }))
182     return false;
183   if (!verifyScalarEntry(ArgsMap, ".is_const", false,
184                          msgpack::Type::Boolean))
185     return false;
186   if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
187                          msgpack::Type::Boolean))
188     return false;
189   if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
190                          msgpack::Type::Boolean))
191     return false;
192   if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
193                          msgpack::Type::Boolean))
194     return false;
195 
196   return true;
197 }
198 
199 bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
200   if (!Node.isMap())
201     return false;
202   auto &KernelMap = Node.getMap();
203 
204   if (!verifyScalarEntry(KernelMap, ".name", true,
205                          msgpack::Type::String))
206     return false;
207   if (!verifyScalarEntry(KernelMap, ".symbol", true,
208                          msgpack::Type::String))
209     return false;
210   if (!verifyScalarEntry(KernelMap, ".language", false,
211                          msgpack::Type::String,
212                          [](msgpack::DocNode &SNode) {
213                            return StringSwitch<bool>(SNode.getString())
214                                .Case("OpenCL C", true)
215                                .Case("OpenCL C++", true)
216                                .Case("HCC", true)
217                                .Case("HIP", true)
218                                .Case("OpenMP", true)
219                                .Case("Assembler", true)
220                                .Default(false);
221                          }))
222     return false;
223   if (!verifyEntry(
224           KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
225             return verifyArray(
226                 Node,
227                 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
228           }))
229     return false;
230   if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
231         return verifyArray(Node, [this](msgpack::DocNode &Node) {
232           return verifyKernelArgs(Node);
233         });
234       }))
235     return false;
236   if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
237                    [this](msgpack::DocNode &Node) {
238                      return verifyArray(Node,
239                                         [this](msgpack::DocNode &Node) {
240                                           return verifyInteger(Node);
241                                         },
242                                         3);
243                    }))
244     return false;
245   if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
246                    [this](msgpack::DocNode &Node) {
247                      return verifyArray(Node,
248                                         [this](msgpack::DocNode &Node) {
249                                           return verifyInteger(Node);
250                                         },
251                                         3);
252                    }))
253     return false;
254   if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
255                          msgpack::Type::String))
256     return false;
257   if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
258                          msgpack::Type::String))
259     return false;
260   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
261     return false;
262   if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
263     return false;
264   if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
265     return false;
266   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
267     return false;
268   if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
269     return false;
270   if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
271     return false;
272   if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
273     return false;
274   if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
275     return false;
276   if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
277     return false;
278   if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
279     return false;
280 
281   return true;
282 }
283 
284 bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
285   if (!HSAMetadataRoot.isMap())
286     return false;
287   auto &RootMap = HSAMetadataRoot.getMap();
288 
289   if (!verifyEntry(
290           RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
291             return verifyArray(
292                 Node,
293                 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
294           }))
295     return false;
296   if (!verifyEntry(
297           RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
298             return verifyArray(Node, [this](msgpack::DocNode &Node) {
299               return verifyScalar(Node, msgpack::Type::String);
300             });
301           }))
302     return false;
303   if (!verifyEntry(RootMap, "amdhsa.kernels", true,
304                    [this](msgpack::DocNode &Node) {
305                      return verifyArray(Node, [this](msgpack::DocNode &Node) {
306                        return verifyKernel(Node);
307                      });
308                    }))
309     return false;
310 
311   return true;
312 }
313 
314 } // end namespace V3
315 } // end namespace HSAMD
316 } // end namespace AMDGPU
317 } // end namespace llvm
318