1 // Copyright (c) 2015-2016 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/val/validation_state.h"
16 
17 #include <cassert>
18 #include <stack>
19 #include <utility>
20 
21 #include "source/opcode.h"
22 #include "source/spirv_constant.h"
23 #include "source/spirv_target_env.h"
24 #include "source/val/basic_block.h"
25 #include "source/val/construct.h"
26 #include "source/val/function.h"
27 #include "spirv-tools/libspirv.h"
28 
29 namespace spvtools {
30 namespace val {
31 namespace {
32 
InstructionLayoutSection(ModuleLayoutSection current_section,SpvOp op)33 ModuleLayoutSection InstructionLayoutSection(
34     ModuleLayoutSection current_section, SpvOp op) {
35   // See Section 2.4
36   if (spvOpcodeGeneratesType(op) || spvOpcodeIsConstant(op))
37     return kLayoutTypes;
38 
39   switch (op) {
40     case SpvOpCapability:
41       return kLayoutCapabilities;
42     case SpvOpExtension:
43       return kLayoutExtensions;
44     case SpvOpExtInstImport:
45       return kLayoutExtInstImport;
46     case SpvOpMemoryModel:
47       return kLayoutMemoryModel;
48     case SpvOpEntryPoint:
49       return kLayoutEntryPoint;
50     case SpvOpExecutionMode:
51     case SpvOpExecutionModeId:
52       return kLayoutExecutionMode;
53     case SpvOpSourceContinued:
54     case SpvOpSource:
55     case SpvOpSourceExtension:
56     case SpvOpString:
57       return kLayoutDebug1;
58     case SpvOpName:
59     case SpvOpMemberName:
60       return kLayoutDebug2;
61     case SpvOpModuleProcessed:
62       return kLayoutDebug3;
63     case SpvOpDecorate:
64     case SpvOpMemberDecorate:
65     case SpvOpGroupDecorate:
66     case SpvOpGroupMemberDecorate:
67     case SpvOpDecorationGroup:
68     case SpvOpDecorateId:
69     case SpvOpDecorateStringGOOGLE:
70     case SpvOpMemberDecorateStringGOOGLE:
71       return kLayoutAnnotations;
72     case SpvOpTypeForwardPointer:
73       return kLayoutTypes;
74     case SpvOpVariable:
75       if (current_section == kLayoutTypes) return kLayoutTypes;
76       return kLayoutFunctionDefinitions;
77     case SpvOpExtInst:
78       // SpvOpExtInst is only allowed in types section for certain extended
79       // instruction sets. This will be checked separately.
80       if (current_section == kLayoutTypes) return kLayoutTypes;
81       return kLayoutFunctionDefinitions;
82     case SpvOpLine:
83     case SpvOpNoLine:
84     case SpvOpUndef:
85       if (current_section == kLayoutTypes) return kLayoutTypes;
86       return kLayoutFunctionDefinitions;
87     case SpvOpFunction:
88     case SpvOpFunctionParameter:
89     case SpvOpFunctionEnd:
90       if (current_section == kLayoutFunctionDeclarations)
91         return kLayoutFunctionDeclarations;
92       return kLayoutFunctionDefinitions;
93     default:
94       break;
95   }
96   return kLayoutFunctionDefinitions;
97 }
98 
IsInstructionInLayoutSection(ModuleLayoutSection layout,SpvOp op)99 bool IsInstructionInLayoutSection(ModuleLayoutSection layout, SpvOp op) {
100   return layout == InstructionLayoutSection(layout, op);
101 }
102 
103 // Counts the number of instructions and functions in the file.
CountInstructions(void * user_data,const spv_parsed_instruction_t * inst)104 spv_result_t CountInstructions(void* user_data,
105                                const spv_parsed_instruction_t* inst) {
106   ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
107   if (inst->opcode == SpvOpFunction) _.increment_total_functions();
108   _.increment_total_instructions();
109 
110   return SPV_SUCCESS;
111 }
112 
setHeader(void * user_data,spv_endianness_t,uint32_t,uint32_t version,uint32_t generator,uint32_t id_bound,uint32_t)113 spv_result_t setHeader(void* user_data, spv_endianness_t, uint32_t,
114                        uint32_t version, uint32_t generator, uint32_t id_bound,
115                        uint32_t) {
116   ValidationState_t& vstate =
117       *(reinterpret_cast<ValidationState_t*>(user_data));
118   vstate.setIdBound(id_bound);
119   vstate.setGenerator(generator);
120   vstate.setVersion(version);
121 
122   return SPV_SUCCESS;
123 }
124 
125 // Add features based on SPIR-V core version number.
UpdateFeaturesBasedOnSpirvVersion(ValidationState_t::Feature * features,uint32_t version)126 void UpdateFeaturesBasedOnSpirvVersion(ValidationState_t::Feature* features,
127                                        uint32_t version) {
128   assert(features);
129   if (version >= SPV_SPIRV_VERSION_WORD(1, 4)) {
130     features->select_between_composites = true;
131     features->copy_memory_permits_two_memory_accesses = true;
132     features->uconvert_spec_constant_op = true;
133     features->nonwritable_var_in_function_or_private = true;
134   }
135 }
136 
137 }  // namespace
138 
// Builds the validation state for the module in |words| (|num_words| words).
// |ctx| supplies the target environment and message consumer. |opt| must be
// non-null. |max_warnings| caps the number of warnings diag() emits.
// If the binary is non-empty, a silent pre-parse records the header fields
// and counts instructions and functions so that storage can be reserved up
// front.
ValidationState_t::ValidationState_t(const spv_const_context ctx,
                                     const spv_const_validator_options opt,
                                     const uint32_t* words,
                                     const size_t num_words,
                                     const uint32_t max_warnings)
    : context_(ctx),
      options_(opt),
      words_(words),
      num_words_(num_words),
      unresolved_forward_ids_{},
      operand_names_{},
      current_layout_section_(kLayoutCapabilities),
      module_functions_(),
      module_capabilities_(),
      module_extensions_(),
      ordered_instructions_(),
      all_definitions_(),
      global_vars_(),
      local_vars_(),
      struct_nesting_depth_(),
      struct_has_nested_blockorbufferblock_struct_(),
      grammar_(ctx),
      addressing_model_(SpvAddressingModelMax),
      memory_model_(SpvMemoryModelMax),
      pointer_size_and_alignment_(0),
      in_function_(false),
      num_of_warnings_(0),
      max_num_of_warnings_(max_warnings) {
  assert(opt && "Validator options may not be Null.");

  const auto env = context_->target_env;

  if (spvIsVulkanEnv(env)) {
    // Vulkan 1.1 includes VK_KHR_relaxed_block_layout in core, so every
    // Vulkan environment other than 1.0 gets relaxed block layout.
    if (env != SPV_ENV_VULKAN_1_0) {
      features_.env_relaxed_block_layout = true;
    }
  }

  // Only attempt to count if we have words. Otherwise let the other
  // validation fail and generate an error.
  if (num_words > 0) {
    // Count the number of instructions in the binary.
    // This parse should not produce any error messages. Run it against a copy
    // of the context whose message consumer discards everything, so that the
    // caller's consumer never sees output from this pre-pass.
    spv_context_t hijacked_context = *ctx;
    hijacked_context.consumer = [](spv_message_level_t, const char*,
                                   const spv_position_t&, const char*) {};
    // setHeader stores id bound / generator / version; CountInstructions
    // tallies instructions and functions. The parse result is deliberately
    // ignored: the real validation pass reports any malformed input.
    spvBinaryParse(&hijacked_context, this, words, num_words, setHeader,
                   CountInstructions,
                   /* diagnostic = */ nullptr);
    // Reserve the instruction/function vectors using the counts above.
    // Presumably this keeps pointers handed out by AddOrderedInstruction and
    // RegisterFunction stable. NOTE(review): confirm.
    preallocateStorage();
  }
  // NOTE(review): when num_words == 0, setHeader never runs, so version_
  // keeps whatever default the header declares. Confirm it is initialized.
  UpdateFeaturesBasedOnSpirvVersion(&features_, version_);

  // Friendly names are used by getIdName() when formatting diagnostics.
  friendly_mapper_ = spvtools::MakeUnique<spvtools::FriendlyNameMapper>(
      context_, words_, num_words_);
  name_mapper_ = friendly_mapper_->GetNameMapper();
}
199 
preallocateStorage()200 void ValidationState_t::preallocateStorage() {
201   ordered_instructions_.reserve(total_instructions_);
202   module_functions_.reserve(total_functions_);
203 }
204 
ForwardDeclareId(uint32_t id)205 spv_result_t ValidationState_t::ForwardDeclareId(uint32_t id) {
206   unresolved_forward_ids_.insert(id);
207   return SPV_SUCCESS;
208 }
209 
RemoveIfForwardDeclared(uint32_t id)210 spv_result_t ValidationState_t::RemoveIfForwardDeclared(uint32_t id) {
211   unresolved_forward_ids_.erase(id);
212   return SPV_SUCCESS;
213 }
214 
RegisterForwardPointer(uint32_t id)215 spv_result_t ValidationState_t::RegisterForwardPointer(uint32_t id) {
216   forward_pointer_ids_.insert(id);
217   return SPV_SUCCESS;
218 }
219 
IsForwardPointer(uint32_t id) const220 bool ValidationState_t::IsForwardPointer(uint32_t id) const {
221   return (forward_pointer_ids_.find(id) != forward_pointer_ids_.end());
222 }
223 
AssignNameToId(uint32_t id,std::string name)224 void ValidationState_t::AssignNameToId(uint32_t id, std::string name) {
225   operand_names_[id] = name;
226 }
227 
getIdName(uint32_t id) const228 std::string ValidationState_t::getIdName(uint32_t id) const {
229   const std::string id_name = name_mapper_(id);
230 
231   std::stringstream out;
232   out << id << "[%" << id_name << "]";
233   return out.str();
234 }
235 
unresolved_forward_id_count() const236 size_t ValidationState_t::unresolved_forward_id_count() const {
237   return unresolved_forward_ids_.size();
238 }
239 
UnresolvedForwardIds() const240 std::vector<uint32_t> ValidationState_t::UnresolvedForwardIds() const {
241   std::vector<uint32_t> out(std::begin(unresolved_forward_ids_),
242                             std::end(unresolved_forward_ids_));
243   return out;
244 }
245 
IsDefinedId(uint32_t id) const246 bool ValidationState_t::IsDefinedId(uint32_t id) const {
247   return all_definitions_.find(id) != std::end(all_definitions_);
248 }
249 
FindDef(uint32_t id) const250 const Instruction* ValidationState_t::FindDef(uint32_t id) const {
251   auto it = all_definitions_.find(id);
252   if (it == all_definitions_.end()) return nullptr;
253   return it->second;
254 }
255 
FindDef(uint32_t id)256 Instruction* ValidationState_t::FindDef(uint32_t id) {
257   auto it = all_definitions_.find(id);
258   if (it == all_definitions_.end()) return nullptr;
259   return it->second;
260 }
261 
current_layout_section() const262 ModuleLayoutSection ValidationState_t::current_layout_section() const {
263   return current_layout_section_;
264 }
265 
ProgressToNextLayoutSectionOrder()266 void ValidationState_t::ProgressToNextLayoutSectionOrder() {
267   // Guard against going past the last element(kLayoutFunctionDefinitions)
268   if (current_layout_section_ <= kLayoutFunctionDefinitions) {
269     current_layout_section_ =
270         static_cast<ModuleLayoutSection>(current_layout_section_ + 1);
271   }
272 }
273 
IsOpcodeInPreviousLayoutSection(SpvOp op)274 bool ValidationState_t::IsOpcodeInPreviousLayoutSection(SpvOp op) {
275   ModuleLayoutSection section =
276       InstructionLayoutSection(current_layout_section_, op);
277   return section < current_layout_section_;
278 }
279 
IsOpcodeInCurrentLayoutSection(SpvOp op)280 bool ValidationState_t::IsOpcodeInCurrentLayoutSection(SpvOp op) {
281   return IsInstructionInLayoutSection(current_layout_section_, op);
282 }
283 
diag(spv_result_t error_code,const Instruction * inst)284 DiagnosticStream ValidationState_t::diag(spv_result_t error_code,
285                                          const Instruction* inst) {
286   if (error_code == SPV_WARNING) {
287     if (num_of_warnings_ == max_num_of_warnings_) {
288       DiagnosticStream({0, 0, 0}, context_->consumer, "", error_code)
289           << "Other warnings have been suppressed.\n";
290     }
291     if (num_of_warnings_ >= max_num_of_warnings_) {
292       return DiagnosticStream({0, 0, 0}, nullptr, "", error_code);
293     }
294     ++num_of_warnings_;
295   }
296 
297   std::string disassembly;
298   if (inst) disassembly = Disassemble(*inst);
299 
300   return DiagnosticStream({0, 0, inst ? inst->LineNum() : 0},
301                           context_->consumer, disassembly, error_code);
302 }
303 
functions()304 std::vector<Function>& ValidationState_t::functions() {
305   return module_functions_;
306 }
307 
current_function()308 Function& ValidationState_t::current_function() {
309   assert(in_function_body());
310   return module_functions_.back();
311 }
312 
current_function() const313 const Function& ValidationState_t::current_function() const {
314   assert(in_function_body());
315   return module_functions_.back();
316 }
317 
function(uint32_t id) const318 const Function* ValidationState_t::function(uint32_t id) const {
319   const auto it = id_to_function_.find(id);
320   if (it == id_to_function_.end()) return nullptr;
321   return it->second;
322 }
323 
function(uint32_t id)324 Function* ValidationState_t::function(uint32_t id) {
325   auto it = id_to_function_.find(id);
326   if (it == id_to_function_.end()) return nullptr;
327   return it->second;
328 }
329 
in_function_body() const330 bool ValidationState_t::in_function_body() const { return in_function_; }
331 
in_block() const332 bool ValidationState_t::in_block() const {
333   return module_functions_.empty() == false &&
334          module_functions_.back().current_block() != nullptr;
335 }
336 
RegisterCapability(SpvCapability cap)337 void ValidationState_t::RegisterCapability(SpvCapability cap) {
338   // Avoid redundant work.  Otherwise the recursion could induce work
339   // quadrdatic in the capability dependency depth. (Ok, not much, but
340   // it's something.)
341   if (module_capabilities_.Contains(cap)) return;
342 
343   module_capabilities_.Add(cap);
344   spv_operand_desc desc;
345   if (SPV_SUCCESS ==
346       grammar_.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY, cap, &desc)) {
347     CapabilitySet(desc->numCapabilities, desc->capabilities)
348         .ForEach([this](SpvCapability c) { RegisterCapability(c); });
349   }
350 
351   switch (cap) {
352     case SpvCapabilityKernel:
353       features_.group_ops_reduce_and_scans = true;
354       break;
355     case SpvCapabilityInt8:
356       features_.use_int8_type = true;
357       features_.declare_int8_type = true;
358       break;
359     case SpvCapabilityStorageBuffer8BitAccess:
360     case SpvCapabilityUniformAndStorageBuffer8BitAccess:
361     case SpvCapabilityStoragePushConstant8:
362       features_.declare_int8_type = true;
363       break;
364     case SpvCapabilityInt16:
365       features_.declare_int16_type = true;
366       break;
367     case SpvCapabilityFloat16:
368     case SpvCapabilityFloat16Buffer:
369       features_.declare_float16_type = true;
370       break;
371     case SpvCapabilityStorageUniformBufferBlock16:
372     case SpvCapabilityStorageUniform16:
373     case SpvCapabilityStoragePushConstant16:
374     case SpvCapabilityStorageInputOutput16:
375       features_.declare_int16_type = true;
376       features_.declare_float16_type = true;
377       features_.free_fp_rounding_mode = true;
378       break;
379     case SpvCapabilityVariablePointers:
380       features_.variable_pointers = true;
381       features_.variable_pointers_storage_buffer = true;
382       break;
383     case SpvCapabilityVariablePointersStorageBuffer:
384       features_.variable_pointers_storage_buffer = true;
385       break;
386     default:
387       break;
388   }
389 }
390 
RegisterExtension(Extension ext)391 void ValidationState_t::RegisterExtension(Extension ext) {
392   if (module_extensions_.Contains(ext)) return;
393 
394   module_extensions_.Add(ext);
395 
396   switch (ext) {
397     case kSPV_AMD_gpu_shader_half_float:
398     case kSPV_AMD_gpu_shader_half_float_fetch:
399       // SPV_AMD_gpu_shader_half_float enables float16 type.
400       // https://github.com/KhronosGroup/SPIRV-Tools/issues/1375
401       features_.declare_float16_type = true;
402       break;
403     case kSPV_AMD_gpu_shader_int16:
404       // This is not yet in the extension, but it's recommended for it.
405       // See https://github.com/KhronosGroup/glslang/issues/848
406       features_.uconvert_spec_constant_op = true;
407       break;
408     case kSPV_AMD_shader_ballot:
409       // The grammar doesn't encode the fact that SPV_AMD_shader_ballot
410       // enables the use of group operations Reduce, InclusiveScan,
411       // and ExclusiveScan.  Enable it manually.
412       // https://github.com/KhronosGroup/SPIRV-Tools/issues/991
413       features_.group_ops_reduce_and_scans = true;
414       break;
415     default:
416       break;
417   }
418 }
419 
HasAnyOfCapabilities(const CapabilitySet & capabilities) const420 bool ValidationState_t::HasAnyOfCapabilities(
421     const CapabilitySet& capabilities) const {
422   return module_capabilities_.HasAnyOf(capabilities);
423 }
424 
HasAnyOfExtensions(const ExtensionSet & extensions) const425 bool ValidationState_t::HasAnyOfExtensions(
426     const ExtensionSet& extensions) const {
427   return module_extensions_.HasAnyOf(extensions);
428 }
429 
set_addressing_model(SpvAddressingModel am)430 void ValidationState_t::set_addressing_model(SpvAddressingModel am) {
431   addressing_model_ = am;
432   switch (am) {
433     case SpvAddressingModelPhysical32:
434       pointer_size_and_alignment_ = 4;
435       break;
436     default:
437     // fall through
438     case SpvAddressingModelPhysical64:
439     case SpvAddressingModelPhysicalStorageBuffer64EXT:
440       pointer_size_and_alignment_ = 8;
441       break;
442   }
443 }
444 
addressing_model() const445 SpvAddressingModel ValidationState_t::addressing_model() const {
446   return addressing_model_;
447 }
448 
set_memory_model(SpvMemoryModel mm)449 void ValidationState_t::set_memory_model(SpvMemoryModel mm) {
450   memory_model_ = mm;
451 }
452 
memory_model() const453 SpvMemoryModel ValidationState_t::memory_model() const { return memory_model_; }
454 
RegisterFunction(uint32_t id,uint32_t ret_type_id,SpvFunctionControlMask function_control,uint32_t function_type_id)455 spv_result_t ValidationState_t::RegisterFunction(
456     uint32_t id, uint32_t ret_type_id, SpvFunctionControlMask function_control,
457     uint32_t function_type_id) {
458   assert(in_function_body() == false &&
459          "RegisterFunction can only be called when parsing the binary outside "
460          "of another function");
461   in_function_ = true;
462   module_functions_.emplace_back(id, ret_type_id, function_control,
463                                  function_type_id);
464   id_to_function_.emplace(id, &current_function());
465 
466   // TODO(umar): validate function type and type_id
467 
468   return SPV_SUCCESS;
469 }
470 
RegisterFunctionEnd()471 spv_result_t ValidationState_t::RegisterFunctionEnd() {
472   assert(in_function_body() == true &&
473          "RegisterFunctionEnd can only be called when parsing the binary "
474          "inside of another function");
475   assert(in_block() == false &&
476          "RegisterFunctionParameter can only be called when parsing the binary "
477          "ouside of a block");
478   current_function().RegisterFunctionEnd();
479   in_function_ = false;
480   return SPV_SUCCESS;
481 }
482 
AddOrderedInstruction(const spv_parsed_instruction_t * inst)483 Instruction* ValidationState_t::AddOrderedInstruction(
484     const spv_parsed_instruction_t* inst) {
485   ordered_instructions_.emplace_back(inst);
486   ordered_instructions_.back().SetLineNum(ordered_instructions_.size());
487   return &ordered_instructions_.back();
488 }
489 
490 // Improves diagnostic messages by collecting names of IDs
RegisterDebugInstruction(const Instruction * inst)491 void ValidationState_t::RegisterDebugInstruction(const Instruction* inst) {
492   switch (inst->opcode()) {
493     case SpvOpName: {
494       const auto target = inst->GetOperandAs<uint32_t>(0);
495       const auto* str = reinterpret_cast<const char*>(inst->words().data() +
496                                                       inst->operand(1).offset);
497       AssignNameToId(target, str);
498       break;
499     }
500     case SpvOpMemberName: {
501       const auto target = inst->GetOperandAs<uint32_t>(0);
502       const auto* str = reinterpret_cast<const char*>(inst->words().data() +
503                                                       inst->operand(2).offset);
504       AssignNameToId(target, str);
505       break;
506     }
507     case SpvOpSourceContinued:
508     case SpvOpSource:
509     case SpvOpSourceExtension:
510     case SpvOpString:
511     case SpvOpLine:
512     case SpvOpNoLine:
513     default:
514       break;
515   }
516 }
517 
RegisterInstruction(Instruction * inst)518 void ValidationState_t::RegisterInstruction(Instruction* inst) {
519   if (inst->id()) all_definitions_.insert(std::make_pair(inst->id(), inst));
520 
521   // If the instruction is using an OpTypeSampledImage as an operand, it should
522   // be recorded. The validator will ensure that all usages of an
523   // OpTypeSampledImage and its definition are in the same basic block.
524   for (uint16_t i = 0; i < inst->operands().size(); ++i) {
525     const spv_parsed_operand_t& operand = inst->operand(i);
526     if (SPV_OPERAND_TYPE_ID == operand.type) {
527       const uint32_t operand_word = inst->word(operand.offset);
528       Instruction* operand_inst = FindDef(operand_word);
529       if (operand_inst && SpvOpSampledImage == operand_inst->opcode()) {
530         RegisterSampledImageConsumer(operand_word, inst);
531       }
532     }
533   }
534 }
535 
getSampledImageConsumers(uint32_t sampled_image_id) const536 std::vector<Instruction*> ValidationState_t::getSampledImageConsumers(
537     uint32_t sampled_image_id) const {
538   std::vector<Instruction*> result;
539   auto iter = sampled_image_consumers_.find(sampled_image_id);
540   if (iter != sampled_image_consumers_.end()) {
541     result = iter->second;
542   }
543   return result;
544 }
545 
RegisterSampledImageConsumer(uint32_t sampled_image_id,Instruction * consumer)546 void ValidationState_t::RegisterSampledImageConsumer(uint32_t sampled_image_id,
547                                                      Instruction* consumer) {
548   sampled_image_consumers_[sampled_image_id].push_back(consumer);
549 }
550 
getIdBound() const551 uint32_t ValidationState_t::getIdBound() const { return id_bound_; }
552 
setIdBound(const uint32_t bound)553 void ValidationState_t::setIdBound(const uint32_t bound) { id_bound_ = bound; }
554 
RegisterUniqueTypeDeclaration(const Instruction * inst)555 bool ValidationState_t::RegisterUniqueTypeDeclaration(const Instruction* inst) {
556   std::vector<uint32_t> key;
557   key.push_back(static_cast<uint32_t>(inst->opcode()));
558   for (size_t index = 0; index < inst->operands().size(); ++index) {
559     const spv_parsed_operand_t& operand = inst->operand(index);
560 
561     if (operand.type == SPV_OPERAND_TYPE_RESULT_ID) continue;
562 
563     const int words_begin = operand.offset;
564     const int words_end = words_begin + operand.num_words;
565     assert(words_end <= static_cast<int>(inst->words().size()));
566 
567     key.insert(key.end(), inst->words().begin() + words_begin,
568                inst->words().begin() + words_end);
569   }
570 
571   return unique_type_declarations_.insert(std::move(key)).second;
572 }
573 
GetTypeId(uint32_t id) const574 uint32_t ValidationState_t::GetTypeId(uint32_t id) const {
575   const Instruction* inst = FindDef(id);
576   return inst ? inst->type_id() : 0;
577 }
578 
GetIdOpcode(uint32_t id) const579 SpvOp ValidationState_t::GetIdOpcode(uint32_t id) const {
580   const Instruction* inst = FindDef(id);
581   return inst ? inst->opcode() : SpvOpNop;
582 }
583 
GetComponentType(uint32_t id) const584 uint32_t ValidationState_t::GetComponentType(uint32_t id) const {
585   const Instruction* inst = FindDef(id);
586   assert(inst);
587 
588   switch (inst->opcode()) {
589     case SpvOpTypeFloat:
590     case SpvOpTypeInt:
591     case SpvOpTypeBool:
592       return id;
593 
594     case SpvOpTypeVector:
595       return inst->word(2);
596 
597     case SpvOpTypeMatrix:
598       return GetComponentType(inst->word(2));
599 
600     case SpvOpTypeCooperativeMatrixNV:
601       return inst->word(2);
602 
603     default:
604       break;
605   }
606 
607   if (inst->type_id()) return GetComponentType(inst->type_id());
608 
609   assert(0);
610   return 0;
611 }
612 
GetDimension(uint32_t id) const613 uint32_t ValidationState_t::GetDimension(uint32_t id) const {
614   const Instruction* inst = FindDef(id);
615   assert(inst);
616 
617   switch (inst->opcode()) {
618     case SpvOpTypeFloat:
619     case SpvOpTypeInt:
620     case SpvOpTypeBool:
621       return 1;
622 
623     case SpvOpTypeVector:
624     case SpvOpTypeMatrix:
625       return inst->word(3);
626 
627     case SpvOpTypeCooperativeMatrixNV:
628       // Actual dimension isn't known, return 0
629       return 0;
630 
631     default:
632       break;
633   }
634 
635   if (inst->type_id()) return GetDimension(inst->type_id());
636 
637   assert(0);
638   return 0;
639 }
640 
GetBitWidth(uint32_t id) const641 uint32_t ValidationState_t::GetBitWidth(uint32_t id) const {
642   const uint32_t component_type_id = GetComponentType(id);
643   const Instruction* inst = FindDef(component_type_id);
644   assert(inst);
645 
646   if (inst->opcode() == SpvOpTypeFloat || inst->opcode() == SpvOpTypeInt)
647     return inst->word(2);
648 
649   if (inst->opcode() == SpvOpTypeBool) return 1;
650 
651   assert(0);
652   return 0;
653 }
654 
IsVoidType(uint32_t id) const655 bool ValidationState_t::IsVoidType(uint32_t id) const {
656   const Instruction* inst = FindDef(id);
657   assert(inst);
658   return inst->opcode() == SpvOpTypeVoid;
659 }
660 
IsFloatScalarType(uint32_t id) const661 bool ValidationState_t::IsFloatScalarType(uint32_t id) const {
662   const Instruction* inst = FindDef(id);
663   assert(inst);
664   return inst->opcode() == SpvOpTypeFloat;
665 }
666 
IsFloatVectorType(uint32_t id) const667 bool ValidationState_t::IsFloatVectorType(uint32_t id) const {
668   const Instruction* inst = FindDef(id);
669   assert(inst);
670 
671   if (inst->opcode() == SpvOpTypeVector) {
672     return IsFloatScalarType(GetComponentType(id));
673   }
674 
675   return false;
676 }
677 
IsFloatScalarOrVectorType(uint32_t id) const678 bool ValidationState_t::IsFloatScalarOrVectorType(uint32_t id) const {
679   const Instruction* inst = FindDef(id);
680   assert(inst);
681 
682   if (inst->opcode() == SpvOpTypeFloat) {
683     return true;
684   }
685 
686   if (inst->opcode() == SpvOpTypeVector) {
687     return IsFloatScalarType(GetComponentType(id));
688   }
689 
690   return false;
691 }
692 
IsIntScalarType(uint32_t id) const693 bool ValidationState_t::IsIntScalarType(uint32_t id) const {
694   const Instruction* inst = FindDef(id);
695   assert(inst);
696   return inst->opcode() == SpvOpTypeInt;
697 }
698 
IsIntVectorType(uint32_t id) const699 bool ValidationState_t::IsIntVectorType(uint32_t id) const {
700   const Instruction* inst = FindDef(id);
701   assert(inst);
702 
703   if (inst->opcode() == SpvOpTypeVector) {
704     return IsIntScalarType(GetComponentType(id));
705   }
706 
707   return false;
708 }
709 
IsIntScalarOrVectorType(uint32_t id) const710 bool ValidationState_t::IsIntScalarOrVectorType(uint32_t id) const {
711   const Instruction* inst = FindDef(id);
712   assert(inst);
713 
714   if (inst->opcode() == SpvOpTypeInt) {
715     return true;
716   }
717 
718   if (inst->opcode() == SpvOpTypeVector) {
719     return IsIntScalarType(GetComponentType(id));
720   }
721 
722   return false;
723 }
724 
IsUnsignedIntScalarType(uint32_t id) const725 bool ValidationState_t::IsUnsignedIntScalarType(uint32_t id) const {
726   const Instruction* inst = FindDef(id);
727   assert(inst);
728   return inst->opcode() == SpvOpTypeInt && inst->word(3) == 0;
729 }
730 
IsUnsignedIntVectorType(uint32_t id) const731 bool ValidationState_t::IsUnsignedIntVectorType(uint32_t id) const {
732   const Instruction* inst = FindDef(id);
733   assert(inst);
734 
735   if (inst->opcode() == SpvOpTypeVector) {
736     return IsUnsignedIntScalarType(GetComponentType(id));
737   }
738 
739   return false;
740 }
741 
IsSignedIntScalarType(uint32_t id) const742 bool ValidationState_t::IsSignedIntScalarType(uint32_t id) const {
743   const Instruction* inst = FindDef(id);
744   assert(inst);
745   return inst->opcode() == SpvOpTypeInt && inst->word(3) == 1;
746 }
747 
IsSignedIntVectorType(uint32_t id) const748 bool ValidationState_t::IsSignedIntVectorType(uint32_t id) const {
749   const Instruction* inst = FindDef(id);
750   assert(inst);
751 
752   if (inst->opcode() == SpvOpTypeVector) {
753     return IsSignedIntScalarType(GetComponentType(id));
754   }
755 
756   return false;
757 }
758 
IsBoolScalarType(uint32_t id) const759 bool ValidationState_t::IsBoolScalarType(uint32_t id) const {
760   const Instruction* inst = FindDef(id);
761   assert(inst);
762   return inst->opcode() == SpvOpTypeBool;
763 }
764 
IsBoolVectorType(uint32_t id) const765 bool ValidationState_t::IsBoolVectorType(uint32_t id) const {
766   const Instruction* inst = FindDef(id);
767   assert(inst);
768 
769   if (inst->opcode() == SpvOpTypeVector) {
770     return IsBoolScalarType(GetComponentType(id));
771   }
772 
773   return false;
774 }
775 
IsBoolScalarOrVectorType(uint32_t id) const776 bool ValidationState_t::IsBoolScalarOrVectorType(uint32_t id) const {
777   const Instruction* inst = FindDef(id);
778   assert(inst);
779 
780   if (inst->opcode() == SpvOpTypeBool) {
781     return true;
782   }
783 
784   if (inst->opcode() == SpvOpTypeVector) {
785     return IsBoolScalarType(GetComponentType(id));
786   }
787 
788   return false;
789 }
790 
IsFloatMatrixType(uint32_t id) const791 bool ValidationState_t::IsFloatMatrixType(uint32_t id) const {
792   const Instruction* inst = FindDef(id);
793   assert(inst);
794 
795   if (inst->opcode() == SpvOpTypeMatrix) {
796     return IsFloatScalarType(GetComponentType(id));
797   }
798 
799   return false;
800 }
801 
GetMatrixTypeInfo(uint32_t id,uint32_t * num_rows,uint32_t * num_cols,uint32_t * column_type,uint32_t * component_type) const802 bool ValidationState_t::GetMatrixTypeInfo(uint32_t id, uint32_t* num_rows,
803                                           uint32_t* num_cols,
804                                           uint32_t* column_type,
805                                           uint32_t* component_type) const {
806   if (!id) return false;
807 
808   const Instruction* mat_inst = FindDef(id);
809   assert(mat_inst);
810   if (mat_inst->opcode() != SpvOpTypeMatrix) return false;
811 
812   const uint32_t vec_type = mat_inst->word(2);
813   const Instruction* vec_inst = FindDef(vec_type);
814   assert(vec_inst);
815 
816   if (vec_inst->opcode() != SpvOpTypeVector) {
817     assert(0);
818     return false;
819   }
820 
821   *num_cols = mat_inst->word(3);
822   *num_rows = vec_inst->word(3);
823   *column_type = mat_inst->word(2);
824   *component_type = vec_inst->word(2);
825 
826   return true;
827 }
828 
GetStructMemberTypes(uint32_t struct_type_id,std::vector<uint32_t> * member_types) const829 bool ValidationState_t::GetStructMemberTypes(
830     uint32_t struct_type_id, std::vector<uint32_t>* member_types) const {
831   member_types->clear();
832   if (!struct_type_id) return false;
833 
834   const Instruction* inst = FindDef(struct_type_id);
835   assert(inst);
836   if (inst->opcode() != SpvOpTypeStruct) return false;
837 
838   *member_types =
839       std::vector<uint32_t>(inst->words().cbegin() + 2, inst->words().cend());
840 
841   if (member_types->empty()) return false;
842 
843   return true;
844 }
845 
IsPointerType(uint32_t id) const846 bool ValidationState_t::IsPointerType(uint32_t id) const {
847   const Instruction* inst = FindDef(id);
848   assert(inst);
849   return inst->opcode() == SpvOpTypePointer;
850 }
851 
GetPointerTypeInfo(uint32_t id,uint32_t * data_type,uint32_t * storage_class) const852 bool ValidationState_t::GetPointerTypeInfo(uint32_t id, uint32_t* data_type,
853                                            uint32_t* storage_class) const {
854   if (!id) return false;
855 
856   const Instruction* inst = FindDef(id);
857   assert(inst);
858   if (inst->opcode() != SpvOpTypePointer) return false;
859 
860   *storage_class = inst->word(2);
861   *data_type = inst->word(3);
862   return true;
863 }
864 
IsCooperativeMatrixType(uint32_t id) const865 bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const {
866   const Instruction* inst = FindDef(id);
867   assert(inst);
868   return inst->opcode() == SpvOpTypeCooperativeMatrixNV;
869 }
870 
IsFloatCooperativeMatrixType(uint32_t id) const871 bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const {
872   if (!IsCooperativeMatrixType(id)) return false;
873   return IsFloatScalarType(FindDef(id)->word(2));
874 }
875 
IsIntCooperativeMatrixType(uint32_t id) const876 bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const {
877   if (!IsCooperativeMatrixType(id)) return false;
878   return IsIntScalarType(FindDef(id)->word(2));
879 }
880 
IsUnsignedIntCooperativeMatrixType(uint32_t id) const881 bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const {
882   if (!IsCooperativeMatrixType(id)) return false;
883   return IsUnsignedIntScalarType(FindDef(id)->word(2));
884 }
885 
CooperativeMatrixShapesMatch(const Instruction * inst,uint32_t m1,uint32_t m2)886 spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
887     const Instruction* inst, uint32_t m1, uint32_t m2) {
888   const auto m1_type = FindDef(m1);
889   const auto m2_type = FindDef(m2);
890 
891   if (m1_type->opcode() != SpvOpTypeCooperativeMatrixNV ||
892       m2_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
893     return diag(SPV_ERROR_INVALID_DATA, inst)
894            << "Expected cooperative matrix types";
895   }
896 
897   uint32_t m1_scope_id = m1_type->GetOperandAs<uint32_t>(2);
898   uint32_t m1_rows_id = m1_type->GetOperandAs<uint32_t>(3);
899   uint32_t m1_cols_id = m1_type->GetOperandAs<uint32_t>(4);
900 
901   uint32_t m2_scope_id = m2_type->GetOperandAs<uint32_t>(2);
902   uint32_t m2_rows_id = m2_type->GetOperandAs<uint32_t>(3);
903   uint32_t m2_cols_id = m2_type->GetOperandAs<uint32_t>(4);
904 
905   bool m1_is_int32 = false, m1_is_const_int32 = false, m2_is_int32 = false,
906        m2_is_const_int32 = false;
907   uint32_t m1_value = 0, m2_value = 0;
908 
909   std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
910       EvalInt32IfConst(m1_scope_id);
911   std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
912       EvalInt32IfConst(m2_scope_id);
913 
914   if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
915     return diag(SPV_ERROR_INVALID_DATA, inst)
916            << "Expected scopes of Matrix and Result Type to be "
917            << "identical";
918   }
919 
920   std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
921       EvalInt32IfConst(m1_rows_id);
922   std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
923       EvalInt32IfConst(m2_rows_id);
924 
925   if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
926     return diag(SPV_ERROR_INVALID_DATA, inst)
927            << "Expected rows of Matrix type and Result Type to be "
928            << "identical";
929   }
930 
931   std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
932       EvalInt32IfConst(m1_cols_id);
933   std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
934       EvalInt32IfConst(m2_cols_id);
935 
936   if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
937     return diag(SPV_ERROR_INVALID_DATA, inst)
938            << "Expected columns of Matrix type and Result Type to be "
939            << "identical";
940   }
941 
942   return SPV_SUCCESS;
943 }
944 
GetOperandTypeId(const Instruction * inst,size_t operand_index) const945 uint32_t ValidationState_t::GetOperandTypeId(const Instruction* inst,
946                                              size_t operand_index) const {
947   return GetTypeId(inst->GetOperandAs<uint32_t>(operand_index));
948 }
949 
GetConstantValUint64(uint32_t id,uint64_t * val) const950 bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const {
951   const Instruction* inst = FindDef(id);
952   if (!inst) {
953     assert(0 && "Instruction not found");
954     return false;
955   }
956 
957   if (inst->opcode() != SpvOpConstant && inst->opcode() != SpvOpSpecConstant)
958     return false;
959 
960   if (!IsIntScalarType(inst->type_id())) return false;
961 
962   if (inst->words().size() == 4) {
963     *val = inst->word(3);
964   } else {
965     assert(inst->words().size() == 5);
966     *val = inst->word(3);
967     *val |= uint64_t(inst->word(4)) << 32;
968   }
969   return true;
970 }
971 
EvalInt32IfConst(uint32_t id) const972 std::tuple<bool, bool, uint32_t> ValidationState_t::EvalInt32IfConst(
973     uint32_t id) const {
974   const Instruction* const inst = FindDef(id);
975   assert(inst);
976   const uint32_t type = inst->type_id();
977 
978   if (type == 0 || !IsIntScalarType(type) || GetBitWidth(type) != 32) {
979     return std::make_tuple(false, false, 0);
980   }
981 
982   // Spec constant values cannot be evaluated so don't consider constant for
983   // the purpose of this method.
984   if (!spvOpcodeIsConstant(inst->opcode()) ||
985       spvOpcodeIsSpecConstant(inst->opcode())) {
986     return std::make_tuple(true, false, 0);
987   }
988 
989   if (inst->opcode() == SpvOpConstantNull) {
990     return std::make_tuple(true, true, 0);
991   }
992 
993   assert(inst->words().size() == 4);
994   return std::make_tuple(true, true, inst->word(3));
995 }
996 
ComputeFunctionToEntryPointMapping()997 void ValidationState_t::ComputeFunctionToEntryPointMapping() {
998   for (const uint32_t entry_point : entry_points()) {
999     std::stack<uint32_t> call_stack;
1000     std::set<uint32_t> visited;
1001     call_stack.push(entry_point);
1002     while (!call_stack.empty()) {
1003       const uint32_t called_func_id = call_stack.top();
1004       call_stack.pop();
1005       if (!visited.insert(called_func_id).second) continue;
1006 
1007       function_to_entry_points_[called_func_id].push_back(entry_point);
1008 
1009       const Function* called_func = function(called_func_id);
1010       if (called_func) {
1011         // Other checks should error out on this invalid SPIR-V.
1012         for (const uint32_t new_call : called_func->function_call_targets()) {
1013           call_stack.push(new_call);
1014         }
1015       }
1016     }
1017   }
1018 }
1019 
ComputeRecursiveEntryPoints()1020 void ValidationState_t::ComputeRecursiveEntryPoints() {
1021   for (const Function& func : functions()) {
1022     std::stack<uint32_t> call_stack;
1023     std::set<uint32_t> visited;
1024 
1025     for (const uint32_t new_call : func.function_call_targets()) {
1026       call_stack.push(new_call);
1027     }
1028 
1029     while (!call_stack.empty()) {
1030       const uint32_t called_func_id = call_stack.top();
1031       call_stack.pop();
1032 
1033       if (!visited.insert(called_func_id).second) continue;
1034 
1035       if (called_func_id == func.id()) {
1036         for (const uint32_t entry_point :
1037              function_to_entry_points_[called_func_id])
1038           recursive_entry_points_.insert(entry_point);
1039         break;
1040       }
1041 
1042       const Function* called_func = function(called_func_id);
1043       if (called_func) {
1044         // Other checks should error out on this invalid SPIR-V.
1045         for (const uint32_t new_call : called_func->function_call_targets()) {
1046           call_stack.push(new_call);
1047         }
1048       }
1049     }
1050   }
1051 }
1052 
FunctionEntryPoints(uint32_t func) const1053 const std::vector<uint32_t>& ValidationState_t::FunctionEntryPoints(
1054     uint32_t func) const {
1055   auto iter = function_to_entry_points_.find(func);
1056   if (iter == function_to_entry_points_.end()) {
1057     return empty_ids_;
1058   } else {
1059     return iter->second;
1060   }
1061 }
1062 
EntryPointReferences(uint32_t id) const1063 std::set<uint32_t> ValidationState_t::EntryPointReferences(uint32_t id) const {
1064   std::set<uint32_t> referenced_entry_points;
1065   const auto inst = FindDef(id);
1066   if (!inst) return referenced_entry_points;
1067 
1068   std::vector<const Instruction*> stack;
1069   stack.push_back(inst);
1070   while (!stack.empty()) {
1071     const auto current_inst = stack.back();
1072     stack.pop_back();
1073 
1074     if (const auto func = current_inst->function()) {
1075       // Instruction lives in a function, we can stop searching.
1076       const auto function_entry_points = FunctionEntryPoints(func->id());
1077       referenced_entry_points.insert(function_entry_points.begin(),
1078                                      function_entry_points.end());
1079     } else {
1080       // Instruction is in the global scope, keep searching its uses.
1081       for (auto pair : current_inst->uses()) {
1082         const auto next_inst = pair.first;
1083         stack.push_back(next_inst);
1084       }
1085     }
1086   }
1087 
1088   return referenced_entry_points;
1089 }
1090 
Disassemble(const Instruction & inst) const1091 std::string ValidationState_t::Disassemble(const Instruction& inst) const {
1092   const spv_parsed_instruction_t& c_inst(inst.c_inst());
1093   return Disassemble(c_inst.words, c_inst.num_words);
1094 }
1095 
Disassemble(const uint32_t * words,uint16_t num_words) const1096 std::string ValidationState_t::Disassemble(const uint32_t* words,
1097                                            uint16_t num_words) const {
1098   uint32_t disassembly_options = SPV_BINARY_TO_TEXT_OPTION_NO_HEADER |
1099                                  SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES;
1100 
1101   return spvInstructionBinaryToText(context()->target_env, words, num_words,
1102                                     words_, num_words_, disassembly_options);
1103 }
1104 
LogicallyMatch(const Instruction * lhs,const Instruction * rhs,bool check_decorations)1105 bool ValidationState_t::LogicallyMatch(const Instruction* lhs,
1106                                        const Instruction* rhs,
1107                                        bool check_decorations) {
1108   if (lhs->opcode() != rhs->opcode()) {
1109     return false;
1110   }
1111 
1112   if (check_decorations) {
1113     const auto& dec_a = id_decorations(lhs->id());
1114     const auto& dec_b = id_decorations(rhs->id());
1115 
1116     for (const auto& dec : dec_b) {
1117       if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
1118         return false;
1119       }
1120     }
1121   }
1122 
1123   if (lhs->opcode() == SpvOpTypeArray) {
1124     // Size operands must match.
1125     if (lhs->GetOperandAs<uint32_t>(2u) != rhs->GetOperandAs<uint32_t>(2u)) {
1126       return false;
1127     }
1128 
1129     // Elements must match or logically match.
1130     const auto lhs_ele_id = lhs->GetOperandAs<uint32_t>(1u);
1131     const auto rhs_ele_id = rhs->GetOperandAs<uint32_t>(1u);
1132     if (lhs_ele_id == rhs_ele_id) {
1133       return true;
1134     }
1135 
1136     const auto lhs_ele = FindDef(lhs_ele_id);
1137     const auto rhs_ele = FindDef(rhs_ele_id);
1138     if (!lhs_ele || !rhs_ele) {
1139       return false;
1140     }
1141     return LogicallyMatch(lhs_ele, rhs_ele, check_decorations);
1142   } else if (lhs->opcode() == SpvOpTypeStruct) {
1143     // Number of elements must match.
1144     if (lhs->operands().size() != rhs->operands().size()) {
1145       return false;
1146     }
1147 
1148     for (size_t i = 1u; i < lhs->operands().size(); ++i) {
1149       const auto lhs_ele_id = lhs->GetOperandAs<uint32_t>(i);
1150       const auto rhs_ele_id = rhs->GetOperandAs<uint32_t>(i);
1151       // Elements must match or logically match.
1152       if (lhs_ele_id == rhs_ele_id) {
1153         continue;
1154       }
1155 
1156       const auto lhs_ele = FindDef(lhs_ele_id);
1157       const auto rhs_ele = FindDef(rhs_ele_id);
1158       if (!lhs_ele || !rhs_ele) {
1159         return false;
1160       }
1161 
1162       if (!LogicallyMatch(lhs_ele, rhs_ele, check_decorations)) {
1163         return false;
1164       }
1165     }
1166 
1167     // All checks passed.
1168     return true;
1169   }
1170 
1171   // No other opcodes are acceptable at this point. Arrays and structs are
1172   // caught above and if they're elements are not arrays or structs they are
1173   // required to match exactly.
1174   return false;
1175 }
1176 
// Follows |inst| back through access chains and OpCopyObject to the pointer
// they were derived from, and returns that base instruction.
const Instruction* ValidationState_t::TracePointer(
    const Instruction* inst) const {
  // Opcodes whose operand 2 is the pointer they derive from.
  const auto derives_from_operand_2 = [](SpvOp op) -> bool {
    switch (op) {
      case SpvOpAccessChain:
      case SpvOpInBoundsAccessChain:
      case SpvOpPtrAccessChain:
      case SpvOpInBoundsPtrAccessChain:
      case SpvOpCopyObject:
        return true;
      default:
        return false;
    }
  };

  const Instruction* base_ptr = inst;
  while (derives_from_operand_2(base_ptr->opcode())) {
    base_ptr = FindDef(base_ptr->GetOperandAs<uint32_t>(2u));
  }
  return base_ptr;
}
1189 
ContainsSizedIntOrFloatType(uint32_t id,SpvOp type,uint32_t width) const1190 bool ValidationState_t::ContainsSizedIntOrFloatType(uint32_t id, SpvOp type,
1191                                                     uint32_t width) const {
1192   if (type != SpvOpTypeInt && type != SpvOpTypeFloat) return false;
1193 
1194   const auto inst = FindDef(id);
1195   if (!inst) return false;
1196 
1197   if (inst->opcode() == type) {
1198     return inst->GetOperandAs<uint32_t>(1u) == width;
1199   }
1200 
1201   switch (inst->opcode()) {
1202     case SpvOpTypeArray:
1203     case SpvOpTypeRuntimeArray:
1204     case SpvOpTypeVector:
1205     case SpvOpTypeMatrix:
1206     case SpvOpTypeImage:
1207     case SpvOpTypeSampledImage:
1208     case SpvOpTypeCooperativeMatrixNV:
1209       return ContainsSizedIntOrFloatType(inst->GetOperandAs<uint32_t>(1u), type,
1210                                          width);
1211     case SpvOpTypePointer:
1212       if (IsForwardPointer(id)) return false;
1213       return ContainsSizedIntOrFloatType(inst->GetOperandAs<uint32_t>(2u), type,
1214                                          width);
1215     case SpvOpTypeFunction:
1216     case SpvOpTypeStruct: {
1217       for (uint32_t i = 1; i < inst->operands().size(); ++i) {
1218         if (ContainsSizedIntOrFloatType(inst->GetOperandAs<uint32_t>(i), type,
1219                                         width))
1220           return true;
1221       }
1222       return false;
1223     }
1224     default:
1225       return false;
1226   }
1227 }
1228 
ContainsLimitedUseIntOrFloatType(uint32_t id) const1229 bool ValidationState_t::ContainsLimitedUseIntOrFloatType(uint32_t id) const {
1230   if ((!HasCapability(SpvCapabilityInt16) &&
1231        ContainsSizedIntOrFloatType(id, SpvOpTypeInt, 16)) ||
1232       (!HasCapability(SpvCapabilityInt8) &&
1233        ContainsSizedIntOrFloatType(id, SpvOpTypeInt, 8)) ||
1234       (!HasCapability(SpvCapabilityFloat16) &&
1235        ContainsSizedIntOrFloatType(id, SpvOpTypeFloat, 16))) {
1236     return true;
1237   }
1238   return false;
1239 }
1240 
IsValidStorageClass(SpvStorageClass storage_class) const1241 bool ValidationState_t::IsValidStorageClass(
1242     SpvStorageClass storage_class) const {
1243   if (spvIsWebGPUEnv(context()->target_env)) {
1244     switch (storage_class) {
1245       case SpvStorageClassUniformConstant:
1246       case SpvStorageClassUniform:
1247       case SpvStorageClassStorageBuffer:
1248       case SpvStorageClassInput:
1249       case SpvStorageClassOutput:
1250       case SpvStorageClassImage:
1251       case SpvStorageClassWorkgroup:
1252       case SpvStorageClassPrivate:
1253       case SpvStorageClassFunction:
1254         return true;
1255       default:
1256         return false;
1257     }
1258   }
1259 
1260   if (spvIsVulkanEnv(context()->target_env)) {
1261     switch (storage_class) {
1262       case SpvStorageClassUniformConstant:
1263       case SpvStorageClassUniform:
1264       case SpvStorageClassStorageBuffer:
1265       case SpvStorageClassInput:
1266       case SpvStorageClassOutput:
1267       case SpvStorageClassImage:
1268       case SpvStorageClassWorkgroup:
1269       case SpvStorageClassPrivate:
1270       case SpvStorageClassFunction:
1271       case SpvStorageClassPushConstant:
1272       case SpvStorageClassPhysicalStorageBuffer:
1273       case SpvStorageClassRayPayloadNV:
1274       case SpvStorageClassIncomingRayPayloadNV:
1275       case SpvStorageClassHitAttributeNV:
1276       case SpvStorageClassCallableDataNV:
1277       case SpvStorageClassIncomingCallableDataNV:
1278       case SpvStorageClassShaderRecordBufferNV:
1279         return true;
1280       default:
1281         return false;
1282     }
1283   }
1284 
1285   return true;
1286 }
1287 
// Expands to a string literal of the form "[<vuid>] ". The argument is
// stringized with '#', so any whitespace between its tokens would appear in
// the result; that is why the switch below is excluded from clang-format.
#define VUID_WRAP(vuid) "[" #vuid "] "

// Currently no 2 VUID share the same id, so no need for |reference|
// Maps the numeric suffix |id| of a Vulkan VUID to its bracketed VUID string,
// e.g. 4181 -> "[VUID-BaseInstance-BaseInstance-04181] ". Returns "" when the
// target environment is not Vulkan or when |id| has no case below.
std::string ValidationState_t::VkErrorID(uint32_t id,
                                         const char* /*reference*/) const {
  if (!spvIsVulkanEnv(context_->target_env)) {
    return "";
  }

  // This large switch case is only searched when an error has occurred.
  // If an id is changed, the old case must be modified or removed. Each string
  // here is interpreted as being "implemented"

  // Clang format adds spaces between hyphens
  // clang-format off
  switch (id) {
    case 4181:
      return VUID_WRAP(VUID-BaseInstance-BaseInstance-04181);
    case 4182:
      return VUID_WRAP(VUID-BaseInstance-BaseInstance-04182);
    case 4183:
      return VUID_WRAP(VUID-BaseInstance-BaseInstance-04183);
    case 4184:
      return VUID_WRAP(VUID-BaseVertex-BaseVertex-04184);
    case 4185:
      return VUID_WRAP(VUID-BaseVertex-BaseVertex-04185);
    case 4186:
      return VUID_WRAP(VUID-BaseVertex-BaseVertex-04186);
    case 4187:
      return VUID_WRAP(VUID-ClipDistance-ClipDistance-04187);
    case 4191:
      return VUID_WRAP(VUID-ClipDistance-ClipDistance-04191);
    case 4196:
      return VUID_WRAP(VUID-CullDistance-CullDistance-04196);
    case 4200:
      return VUID_WRAP(VUID-CullDistance-CullDistance-04200);
    case 4205:
      return VUID_WRAP(VUID-DeviceIndex-DeviceIndex-04205);
    case 4206:
      return VUID_WRAP(VUID-DeviceIndex-DeviceIndex-04206);
    case 4207:
      return VUID_WRAP(VUID-DrawIndex-DrawIndex-04207);
    case 4208:
      return VUID_WRAP(VUID-DrawIndex-DrawIndex-04208);
    case 4209:
      return VUID_WRAP(VUID-DrawIndex-DrawIndex-04209);
    case 4210:
      return VUID_WRAP(VUID-FragCoord-FragCoord-04210);
    case 4211:
      return VUID_WRAP(VUID-FragCoord-FragCoord-04211);
    case 4212:
      return VUID_WRAP(VUID-FragCoord-FragCoord-04212);
    case 4213:
      return VUID_WRAP(VUID-FragDepth-FragDepth-04213);
    case 4214:
      return VUID_WRAP(VUID-FragDepth-FragDepth-04214);
    case 4215:
      return VUID_WRAP(VUID-FragDepth-FragDepth-04215);
    case 4216:
      return VUID_WRAP(VUID-FragDepth-FragDepth-04216);
    case 4229:
      return VUID_WRAP(VUID-FrontFacing-FrontFacing-04229);
    case 4230:
      return VUID_WRAP(VUID-FrontFacing-FrontFacing-04230);
    case 4231:
      return VUID_WRAP(VUID-FrontFacing-FrontFacing-04231);
    case 4236:
      return VUID_WRAP(VUID-GlobalInvocationId-GlobalInvocationId-04236);
    case 4237:
      return VUID_WRAP(VUID-GlobalInvocationId-GlobalInvocationId-04237);
    case 4238:
      return VUID_WRAP(VUID-GlobalInvocationId-GlobalInvocationId-04238);
    case 4239:
      return VUID_WRAP(VUID-HelperInvocation-HelperInvocation-04239);
    case 4240:
      return VUID_WRAP(VUID-HelperInvocation-HelperInvocation-04240);
    case 4241:
      return VUID_WRAP(VUID-HelperInvocation-HelperInvocation-04241);
    case 4257:
      return VUID_WRAP(VUID-InvocationId-InvocationId-04257);
    case 4258:
      return VUID_WRAP(VUID-InvocationId-InvocationId-04258);
    case 4259:
      return VUID_WRAP(VUID-InvocationId-InvocationId-04259);
    case 4263:
      return VUID_WRAP(VUID-InstanceIndex-InstanceIndex-04263);
    case 4264:
      return VUID_WRAP(VUID-InstanceIndex-InstanceIndex-04264);
    case 4265:
      return VUID_WRAP(VUID-InstanceIndex-InstanceIndex-04265);
    case 4272:
      return VUID_WRAP(VUID-Layer-Layer-04272);
    case 4274:
      return VUID_WRAP(VUID-Layer-Layer-04274);
    case 4275:
      return VUID_WRAP(VUID-Layer-Layer-04275);
    case 4276:
      return VUID_WRAP(VUID-Layer-Layer-04276);
    case 4281:
      return VUID_WRAP(VUID-LocalInvocationId-LocalInvocationId-04281);
    case 4282:
      return VUID_WRAP(VUID-LocalInvocationId-LocalInvocationId-04282);
    case 4283:
      return VUID_WRAP(VUID-LocalInvocationId-LocalInvocationId-04283);
    case 4296:
      return VUID_WRAP(VUID-NumWorkgroups-NumWorkgroups-04296);
    case 4297:
      return VUID_WRAP(VUID-NumWorkgroups-NumWorkgroups-04297);
    case 4298:
      return VUID_WRAP(VUID-NumWorkgroups-NumWorkgroups-04298);
    case 4308:
      return VUID_WRAP(VUID-PatchVertices-PatchVertices-04308);
    case 4309:
      return VUID_WRAP(VUID-PatchVertices-PatchVertices-04309);
    case 4310:
      return VUID_WRAP(VUID-PatchVertices-PatchVertices-04310);
    case 4311:
      return VUID_WRAP(VUID-PointCoord-PointCoord-04311);
    case 4312:
      return VUID_WRAP(VUID-PointCoord-PointCoord-04312);
    case 4313:
      return VUID_WRAP(VUID-PointCoord-PointCoord-04313);
    case 4314:
      return VUID_WRAP(VUID-PointSize-PointSize-04314);
    case 4315:
      return VUID_WRAP(VUID-PointSize-PointSize-04315);
    case 4316:
      return VUID_WRAP(VUID-PointSize-PointSize-04316);
    case 4317:
      return VUID_WRAP(VUID-PointSize-PointSize-04317);
    case 4318:
      return VUID_WRAP(VUID-Position-Position-04318);
    case 4320:
      return VUID_WRAP(VUID-Position-Position-04320);
    case 4321:
      return VUID_WRAP(VUID-Position-Position-04321);
    case 4330:
      return VUID_WRAP(VUID-PrimitiveId-PrimitiveId-04330);
    case 4334:
      return VUID_WRAP(VUID-PrimitiveId-PrimitiveId-04334);
    case 4337:
      return VUID_WRAP(VUID-PrimitiveId-PrimitiveId-04337);
    case 4354:
      return VUID_WRAP(VUID-SampleId-SampleId-04354);
    case 4355:
      return VUID_WRAP(VUID-SampleId-SampleId-04355);
    case 4356:
      return VUID_WRAP(VUID-SampleId-SampleId-04356);
    case 4357:
      return VUID_WRAP(VUID-SampleMask-SampleMask-04357);
    case 4358:
      return VUID_WRAP(VUID-SampleMask-SampleMask-04358);
    case 4359:
      return VUID_WRAP(VUID-SampleMask-SampleMask-04359);
    case 4360:
      return VUID_WRAP(VUID-SamplePosition-SamplePosition-04360);
    case 4361:
      return VUID_WRAP(VUID-SamplePosition-SamplePosition-04361);
    case 4362:
      return VUID_WRAP(VUID-SamplePosition-SamplePosition-04362);
    case 4387:
      return VUID_WRAP(VUID-TessCoord-TessCoord-04387);
    case 4388:
      return VUID_WRAP(VUID-TessCoord-TessCoord-04388);
    case 4389:
      return VUID_WRAP(VUID-TessCoord-TessCoord-04389);
    case 4390:
      return VUID_WRAP(VUID-TessLevelOuter-TessLevelOuter-04390);
    case 4393:
      return VUID_WRAP(VUID-TessLevelOuter-TessLevelOuter-04393);
    case 4394:
      return VUID_WRAP(VUID-TessLevelInner-TessLevelInner-04394);
    case 4397:
      return VUID_WRAP(VUID-TessLevelInner-TessLevelInner-04397);
    case 4398:
      return VUID_WRAP(VUID-VertexIndex-VertexIndex-04398);
    case 4399:
      return VUID_WRAP(VUID-VertexIndex-VertexIndex-04399);
    case 4400:
      return VUID_WRAP(VUID-VertexIndex-VertexIndex-04400);
    case 4401:
      return VUID_WRAP(VUID-ViewIndex-ViewIndex-04401);
    case 4402:
      return VUID_WRAP(VUID-ViewIndex-ViewIndex-04402);
    case 4403:
      return VUID_WRAP(VUID-ViewIndex-ViewIndex-04403);
    case 4404:
      return VUID_WRAP(VUID-ViewportIndex-ViewportIndex-04404);
    case 4406:
      return VUID_WRAP(VUID-ViewportIndex-ViewportIndex-04406);
    case 4407:
      return VUID_WRAP(VUID-ViewportIndex-ViewportIndex-04407);
    case 4408:
      return VUID_WRAP(VUID-ViewportIndex-ViewportIndex-04408);
    case 4422:
      return VUID_WRAP(VUID-WorkgroupId-WorkgroupId-04422);
    case 4423:
      return VUID_WRAP(VUID-WorkgroupId-WorkgroupId-04423);
    case 4424:
      return VUID_WRAP(VUID-WorkgroupId-WorkgroupId-04424);
    case 4425:
      return VUID_WRAP(VUID-WorkgroupSize-WorkgroupSize-04425);
    case 4426:
      return VUID_WRAP(VUID-WorkgroupSize-WorkgroupSize-04426);
    case 4427:
      return VUID_WRAP(VUID-WorkgroupSize-WorkgroupSize-04427);
    case 4484:
      return VUID_WRAP(VUID-PrimitiveShadingRateKHR-PrimitiveShadingRateKHR-04484);
    case 4485:
      return VUID_WRAP(VUID-PrimitiveShadingRateKHR-PrimitiveShadingRateKHR-04485);
    case 4486:
      return VUID_WRAP(VUID-PrimitiveShadingRateKHR-PrimitiveShadingRateKHR-04486);
    case 4490:
      return VUID_WRAP(VUID-ShadingRateKHR-ShadingRateKHR-04490);
    case 4491:
      return VUID_WRAP(VUID-ShadingRateKHR-ShadingRateKHR-04491);
    case 4492:
      return VUID_WRAP(VUID-ShadingRateKHR-ShadingRateKHR-04492);
    default:
      return "";  // unknown id
  };
  // clang-format on
}
1511 
1512 }  // namespace val
1513 }  // namespace spvtools
1514