1 //===- AMDGPUMetadataVerifier.cpp - MsgPack Types ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file 10 /// Implements a verifier for AMDGPU HSA metadata. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h" 15 16 #include "llvm/ADT/STLExtras.h" 17 #include "llvm/ADT/StringSwitch.h" 18 #include "llvm/BinaryFormat/MsgPackDocument.h" 19 20 #include <utility> 21 22 namespace llvm { 23 namespace AMDGPU { 24 namespace HSAMD { 25 namespace V3 { 26 27 bool MetadataVerifier::verifyScalar( 28 msgpack::DocNode &Node, msgpack::Type SKind, 29 function_ref<bool(msgpack::DocNode &)> verifyValue) { 30 if (!Node.isScalar()) 31 return false; 32 if (Node.getKind() != SKind) { 33 if (Strict) 34 return false; 35 // If we are not strict, we interpret string values as "implicitly typed" 36 // and attempt to coerce them to the expected type here. 37 if (Node.getKind() != msgpack::Type::String) 38 return false; 39 StringRef StringValue = Node.getString(); 40 Node.fromString(StringValue); 41 if (Node.getKind() != SKind) 42 return false; 43 } 44 if (verifyValue) 45 return verifyValue(Node); 46 return true; 47 } 48 49 bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) { 50 if (!verifyScalar(Node, msgpack::Type::UInt)) 51 if (!verifyScalar(Node, msgpack::Type::Int)) 52 return false; 53 return true; 54 } 55 56 bool MetadataVerifier::verifyArray( 57 msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode, 58 std::optional<size_t> Size) { 59 if (!Node.isArray()) 60 return false; 61 auto &Array = Node.getArray(); 62 if (Size && Array.size() != *Size) 63 return false; 64 return llvm::all_of(Array, verifyNode); 65 } 66 67 bool MetadataVerifier::verifyEntry( 68 msgpack::MapDocNode &MapNode, StringRef Key, bool Required, 69 function_ref<bool(msgpack::DocNode &)> verifyNode) { 70 auto Entry = MapNode.find(Key); 71 if (Entry == MapNode.end()) 72 return !Required; 73 return verifyNode(Entry->second); 74 } 75 76 bool MetadataVerifier::verifyScalarEntry( 77 msgpack::MapDocNode &MapNode, StringRef Key, bool Required, 78 msgpack::Type SKind, 79 function_ref<bool(msgpack::DocNode &)> verifyValue) { 80 return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) { 81 return verifyScalar(Node, SKind, verifyValue); 82 }); 83 } 84 85 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode, 86 StringRef Key, bool Required) { 87 return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) { 88 return verifyInteger(Node); 89 }); 90 } 91 92 bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) { 93 if (!Node.isMap()) 94 return false; 95 auto &ArgsMap = Node.getMap(); 96 97 if (!verifyScalarEntry(ArgsMap, ".name", false, 98 msgpack::Type::String)) 99 return false; 100 if (!verifyScalarEntry(ArgsMap, ".type_name", false, 101 msgpack::Type::String)) 102 return false; 103 if (!verifyIntegerEntry(ArgsMap, ".size", true)) 104 return false; 105 if (!verifyIntegerEntry(ArgsMap, ".offset", true)) 106 return false; 107 if (!verifyScalarEntry(ArgsMap, ".value_kind", true, msgpack::Type::String, 108 [](msgpack::DocNode &SNode) { 109 return StringSwitch<bool>(SNode.getString()) 110 .Case("by_value", true) 111 .Case("global_buffer", true) 112 .Case("dynamic_shared_pointer", true) 113 .Case("sampler", true) 114 .Case("image", true) 115 .Case("pipe", true) 116 .Case("queue", true) 117 .Case("hidden_block_count_x", true) 118 .Case("hidden_block_count_y", true) 119 .Case("hidden_block_count_z", true) 120 .Case("hidden_group_size_x", true) 121 .Case("hidden_group_size_y", true) 122 .Case("hidden_group_size_z", true) 123 .Case("hidden_remainder_x", true) 124 .Case("hidden_remainder_y", true) 125 .Case("hidden_remainder_z", true) 126 .Case("hidden_global_offset_x", true) 127 .Case("hidden_global_offset_y", true) 128 .Case("hidden_global_offset_z", true) 129 .Case("hidden_grid_dims", true) 130 .Case("hidden_none", true) 131 .Case("hidden_printf_buffer", true) 132 .Case("hidden_hostcall_buffer", true) 133 .Case("hidden_heap_v1", true) 134 .Case("hidden_default_queue", true) 135 .Case("hidden_completion_action", true) 136 .Case("hidden_multigrid_sync_arg", true) 137 .Case("hidden_private_base", true) 138 .Case("hidden_shared_base", true) 139 .Case("hidden_queue_ptr", true) 140 .Default(false); 141 })) 142 return false; 143 if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false)) 144 return false; 145 if (!verifyScalarEntry(ArgsMap, ".address_space", false, 146 msgpack::Type::String, 147 [](msgpack::DocNode &SNode) { 148 return StringSwitch<bool>(SNode.getString()) 149 .Case("private", true) 150 .Case("global", true) 151 .Case("constant", true) 152 .Case("local", true) 153 .Case("generic", true) 154 .Case("region", true) 155 .Default(false); 156 })) 157 return false; 158 if (!verifyScalarEntry(ArgsMap, ".access", false, 159 msgpack::Type::String, 160 [](msgpack::DocNode &SNode) { 161 return StringSwitch<bool>(SNode.getString()) 162 .Case("read_only", true) 163 .Case("write_only", true) 164 .Case("read_write", true) 165 .Default(false); 166 })) 167 return false; 168 if (!verifyScalarEntry(ArgsMap, ".actual_access", false, 169 msgpack::Type::String, 170 [](msgpack::DocNode &SNode) { 171 return StringSwitch<bool>(SNode.getString()) 172 .Case("read_only", true) 173 .Case("write_only", true) 174 .Case("read_write", true) 175 .Default(false); 176 })) 177 return false; 178 if (!verifyScalarEntry(ArgsMap, ".is_const", false, 179 msgpack::Type::Boolean)) 180 return false; 181 if (!verifyScalarEntry(ArgsMap, ".is_restrict", false, 182 msgpack::Type::Boolean)) 183 return false; 184 if (!verifyScalarEntry(ArgsMap, ".is_volatile", false, 185 msgpack::Type::Boolean)) 186 return false; 187 if (!verifyScalarEntry(ArgsMap, ".is_pipe", false, 188 msgpack::Type::Boolean)) 189 return false; 190 191 return true; 192 } 193 194 bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) { 195 if (!Node.isMap()) 196 return false; 197 auto &KernelMap = Node.getMap(); 198 199 if (!verifyScalarEntry(KernelMap, ".name", true, 200 msgpack::Type::String)) 201 return false; 202 if (!verifyScalarEntry(KernelMap, ".symbol", true, 203 msgpack::Type::String)) 204 return false; 205 if (!verifyScalarEntry(KernelMap, ".language", false, 206 msgpack::Type::String, 207 [](msgpack::DocNode &SNode) { 208 return StringSwitch<bool>(SNode.getString()) 209 .Case("OpenCL C", true) 210 .Case("OpenCL C++", true) 211 .Case("HCC", true) 212 .Case("HIP", true) 213 .Case("OpenMP", true) 214 .Case("Assembler", true) 215 .Default(false); 216 })) 217 return false; 218 if (!verifyEntry( 219 KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) { 220 return verifyArray( 221 Node, 222 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2); 223 })) 224 return false; 225 if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) { 226 return verifyArray(Node, [this](msgpack::DocNode &Node) { 227 return verifyKernelArgs(Node); 228 }); 229 })) 230 return false; 231 if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false, 232 [this](msgpack::DocNode &Node) { 233 return verifyArray(Node, 234 [this](msgpack::DocNode &Node) { 235 return verifyInteger(Node); 236 }, 237 3); 238 })) 239 return false; 240 if (!verifyEntry(KernelMap, ".workgroup_size_hint", false, 241 [this](msgpack::DocNode &Node) { 242 return verifyArray(Node, 243 [this](msgpack::DocNode &Node) { 244 return verifyInteger(Node); 245 }, 246 3); 247 })) 248 return false; 249 if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false, 250 msgpack::Type::String)) 251 return false; 252 if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false, 253 msgpack::Type::String)) 254 return false; 255 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true)) 256 return false; 257 if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true)) 258 return false; 259 if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true)) 260 return false; 261 if (!verifyScalarEntry(KernelMap, ".uses_dynamic_stack", false, 262 msgpack::Type::Boolean)) 263 return false; 264 if (!verifyIntegerEntry(KernelMap, ".workgroup_processor_mode", false)) 265 return false; 266 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true)) 267 return false; 268 if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true)) 269 return false; 270 if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true)) 271 return false; 272 if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true)) 273 return false; 274 if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true)) 275 return false; 276 if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false)) 277 return false; 278 if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false)) 279 return false; 280 if (!verifyIntegerEntry(KernelMap, ".uniform_work_group_size", false)) 281 return false; 282 283 284 return true; 285 } 286 287 bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) { 288 if (!HSAMetadataRoot.isMap()) 289 return false; 290 auto &RootMap = HSAMetadataRoot.getMap(); 291 292 if (!verifyEntry( 293 RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) { 294 return verifyArray( 295 Node, 296 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2); 297 })) 298 return false; 299 if (!verifyEntry( 300 RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) { 301 return verifyArray(Node, [this](msgpack::DocNode &Node) { 302 return verifyScalar(Node, msgpack::Type::String); 303 }); 304 })) 305 return false; 306 if (!verifyEntry(RootMap, "amdhsa.kernels", true, 307 [this](msgpack::DocNode &Node) { 308 return verifyArray(Node, [this](msgpack::DocNode &Node) { 309 return verifyKernel(Node); 310 }); 311 })) 312 return false; 313 314 return true; 315 } 316 317 } // end namespace V3 318 } // end namespace HSAMD 319 } // end namespace AMDGPU 320 } // end namespace llvm 321