1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Ensures type declarations are unique unless allowed by the specification.
16 
17 #include "source/opcode.h"
18 #include "source/spirv_target_env.h"
19 #include "source/val/instruction.h"
20 #include "source/val/validate.h"
21 #include "source/val/validation_state.h"
22 #include "spirv/unified1/spirv.h"
23 
24 namespace spvtools {
25 namespace val {
26 namespace {
27 
// Interprets the literal operand of an OpConstant (or the default value of an
// OpSpecConstant) whose result type is an integer type of the given |width|,
// and returns it as a signed 64-bit value.
//
// |const_words| is the full word list of the constant instruction; the
// literal starts at word 3.  Widths up to 32 bits occupy a single word, and
// signed values narrower than 32 bits are expected to already be
// sign-extended to a full word, so reinterpreting that word as int32_t yields
// the correct value.  Widths 33..64 occupy two words, low-order word first.
int64_t ConstantLiteralAsInt64(uint32_t width,
                               const std::vector<uint32_t>& const_words) {
  const uint32_t low = const_words[3];
  if (width > 32) {
    assert(width <= 64);
    assert(const_words.size() > 4);
    // The high-order word must be present for wide literals, per spec.
    const uint64_t high = const_words[4];
    return static_cast<int64_t>((high << 32) | low);
  }
  return static_cast<int32_t>(low);
}
42 
// Interprets the literal operand of an OpConstant (or the default value of an
// OpSpecConstant) whose result type is an integer type of the given |width|,
// and returns its bits as an unsigned 64-bit value.
//
// |const_words| is the full word list of the constant instruction; the
// literal starts at word 3.  Widths up to 32 bits occupy a single word, which
// is zero-extended; widths 33..64 occupy two words, low-order word first.
// No sign extension is performed here; use ConstantLiteralAsInt64 for a
// signed interpretation.
//
// The return type is uint64_t so that 64-bit values with the top bit set are
// returned exactly, without a round trip through a signed type.
uint64_t ConstantLiteralAsUint64(uint32_t width,
                                 const std::vector<uint32_t>& const_words) {
  const uint32_t lo_word = const_words[3];
  if (width <= 32) return lo_word;
  assert(width <= 64);
  assert(const_words.size() > 4);
  const uint32_t hi_word = const_words[4];  // Must exist, per spec.
  return uint64_t(lo_word) | (uint64_t(hi_word) << 32);
}
57 
58 // Validates that type declarations are unique, unless multiple declarations
59 // of the same data type are allowed by the specification.
60 // (see section 2.8 Types and Variables)
61 // Doesn't do anything if SPV_VAL_ignore_type_decl_unique was declared in the
62 // module.
ValidateUniqueness(ValidationState_t & _,const Instruction * inst)63 spv_result_t ValidateUniqueness(ValidationState_t& _, const Instruction* inst) {
64   if (_.HasExtension(Extension::kSPV_VALIDATOR_ignore_type_decl_unique))
65     return SPV_SUCCESS;
66 
67   const auto opcode = inst->opcode();
68   if (opcode != SpvOpTypeArray && opcode != SpvOpTypeRuntimeArray &&
69       opcode != SpvOpTypeStruct && opcode != SpvOpTypePointer &&
70       !_.RegisterUniqueTypeDeclaration(inst)) {
71     return _.diag(SPV_ERROR_INVALID_DATA, inst)
72            << "Duplicate non-aggregate type declarations are not allowed. "
73               "Opcode: "
74            << spvOpcodeString(opcode) << " id: " << inst->id();
75   }
76 
77   return SPV_SUCCESS;
78 }
79 
ValidateTypeInt(ValidationState_t & _,const Instruction * inst)80 spv_result_t ValidateTypeInt(ValidationState_t& _, const Instruction* inst) {
81   // Validates that the number of bits specified for an Int type is valid.
82   // Scalar integer types can be parameterized only with 32-bits.
83   // Int8, Int16, and Int64 capabilities allow using 8-bit, 16-bit, and 64-bit
84   // integers, respectively.
85   auto num_bits = inst->GetOperandAs<const uint32_t>(1);
86   if (num_bits != 32) {
87     if (num_bits == 8) {
88       if (_.features().declare_int8_type) {
89         return SPV_SUCCESS;
90       }
91       return _.diag(SPV_ERROR_INVALID_DATA, inst)
92              << "Using an 8-bit integer type requires the Int8 capability,"
93                 " or an extension that explicitly enables 8-bit integers.";
94     } else if (num_bits == 16) {
95       if (_.features().declare_int16_type) {
96         return SPV_SUCCESS;
97       }
98       return _.diag(SPV_ERROR_INVALID_DATA, inst)
99              << "Using a 16-bit integer type requires the Int16 capability,"
100                 " or an extension that explicitly enables 16-bit integers.";
101     } else if (num_bits == 64) {
102       if (_.HasCapability(SpvCapabilityInt64)) {
103         return SPV_SUCCESS;
104       }
105       return _.diag(SPV_ERROR_INVALID_DATA, inst)
106              << "Using a 64-bit integer type requires the Int64 capability.";
107     } else {
108       return _.diag(SPV_ERROR_INVALID_DATA, inst)
109              << "Invalid number of bits (" << num_bits
110              << ") used for OpTypeInt.";
111     }
112   }
113 
114   const auto signedness_index = 2;
115   const auto signedness = inst->GetOperandAs<uint32_t>(signedness_index);
116   if (signedness != 0 && signedness != 1) {
117     return _.diag(SPV_ERROR_INVALID_VALUE, inst)
118            << "OpTypeInt has invalid signedness:";
119   }
120 
121   // SPIR-V Spec 2.16.3: Validation Rules for Kernel Capabilities: The
122   // Signedness in OpTypeInt must always be 0.
123   if (SpvOpTypeInt == inst->opcode() && _.HasCapability(SpvCapabilityKernel) &&
124       inst->GetOperandAs<uint32_t>(2) != 0u) {
125     return _.diag(SPV_ERROR_INVALID_BINARY, inst)
126            << "The Signedness in OpTypeInt "
127               "must always be 0 when Kernel "
128               "capability is used.";
129   }
130 
131   return SPV_SUCCESS;
132 }
133 
ValidateTypeFloat(ValidationState_t & _,const Instruction * inst)134 spv_result_t ValidateTypeFloat(ValidationState_t& _, const Instruction* inst) {
135   // Validates that the number of bits specified for an Int type is valid.
136   // Scalar integer types can be parameterized only with 32-bits.
137   // Int8, Int16, and Int64 capabilities allow using 8-bit, 16-bit, and 64-bit
138   // integers, respectively.
139   auto num_bits = inst->GetOperandAs<const uint32_t>(1);
140   if (num_bits == 32) {
141     return SPV_SUCCESS;
142   }
143   if (num_bits == 16) {
144     if (_.features().declare_float16_type) {
145       return SPV_SUCCESS;
146     }
147     return _.diag(SPV_ERROR_INVALID_DATA, inst)
148            << "Using a 16-bit floating point "
149            << "type requires the Float16 or Float16Buffer capability,"
150               " or an extension that explicitly enables 16-bit floating point.";
151   }
152   if (num_bits == 64) {
153     if (_.HasCapability(SpvCapabilityFloat64)) {
154       return SPV_SUCCESS;
155     }
156     return _.diag(SPV_ERROR_INVALID_DATA, inst)
157            << "Using a 64-bit floating point "
158            << "type requires the Float64 capability.";
159   }
160   return _.diag(SPV_ERROR_INVALID_DATA, inst)
161          << "Invalid number of bits (" << num_bits << ") used for OpTypeFloat.";
162 }
163 
ValidateTypeVector(ValidationState_t & _,const Instruction * inst)164 spv_result_t ValidateTypeVector(ValidationState_t& _, const Instruction* inst) {
165   const auto component_index = 1;
166   const auto component_id = inst->GetOperandAs<uint32_t>(component_index);
167   const auto component_type = _.FindDef(component_id);
168   if (!component_type || !spvOpcodeIsScalarType(component_type->opcode())) {
169     return _.diag(SPV_ERROR_INVALID_ID, inst)
170            << "OpTypeVector Component Type <id> '" << _.getIdName(component_id)
171            << "' is not a scalar type.";
172   }
173 
174   // Validates that the number of components in the vector is valid.
175   // Vector types can only be parameterized as having 2, 3, or 4 components.
176   // If the Vector16 capability is added, 8 and 16 components are also allowed.
177   auto num_components = inst->GetOperandAs<const uint32_t>(2);
178   if (num_components == 2 || num_components == 3 || num_components == 4) {
179     return SPV_SUCCESS;
180   } else if (num_components == 8 || num_components == 16) {
181     if (_.HasCapability(SpvCapabilityVector16)) {
182       return SPV_SUCCESS;
183     }
184     return _.diag(SPV_ERROR_INVALID_DATA, inst)
185            << "Having " << num_components << " components for "
186            << spvOpcodeString(inst->opcode())
187            << " requires the Vector16 capability";
188   } else {
189     return _.diag(SPV_ERROR_INVALID_DATA, inst)
190            << "Illegal number of components (" << num_components << ") for "
191            << spvOpcodeString(inst->opcode());
192   }
193 
194   return SPV_SUCCESS;
195 }
196 
ValidateTypeMatrix(ValidationState_t & _,const Instruction * inst)197 spv_result_t ValidateTypeMatrix(ValidationState_t& _, const Instruction* inst) {
198   const auto column_type_index = 1;
199   const auto column_type_id = inst->GetOperandAs<uint32_t>(column_type_index);
200   const auto column_type = _.FindDef(column_type_id);
201   if (!column_type || SpvOpTypeVector != column_type->opcode()) {
202     return _.diag(SPV_ERROR_INVALID_ID, inst)
203            << "Columns in a matrix must be of type vector.";
204   }
205 
206   // Trace back once more to find out the type of components in the vector.
207   // Operand 1 is the <id> of the type of data in the vector.
208   const auto comp_type_id = column_type->GetOperandAs<uint32_t>(1);
209   auto comp_type_instruction = _.FindDef(comp_type_id);
210   if (comp_type_instruction->opcode() != SpvOpTypeFloat) {
211     return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be "
212                                                    "parameterized with "
213                                                    "floating-point types.";
214   }
215 
216   // Validates that the matrix has 2,3, or 4 columns.
217   auto num_cols = inst->GetOperandAs<const uint32_t>(2);
218   if (num_cols != 2 && num_cols != 3 && num_cols != 4) {
219     return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be "
220                                                    "parameterized as having "
221                                                    "only 2, 3, or 4 columns.";
222   }
223 
224   return SPV_SUCCESS;
225 }
226 
ValidateTypeArray(ValidationState_t & _,const Instruction * inst)227 spv_result_t ValidateTypeArray(ValidationState_t& _, const Instruction* inst) {
228   const auto element_type_index = 1;
229   const auto element_type_id = inst->GetOperandAs<uint32_t>(element_type_index);
230   const auto element_type = _.FindDef(element_type_id);
231   if (!element_type || !spvOpcodeGeneratesType(element_type->opcode())) {
232     return _.diag(SPV_ERROR_INVALID_ID, inst)
233            << "OpTypeArray Element Type <id> '" << _.getIdName(element_type_id)
234            << "' is not a type.";
235   }
236 
237   if (element_type->opcode() == SpvOpTypeVoid) {
238     return _.diag(SPV_ERROR_INVALID_ID, inst)
239            << "OpTypeArray Element Type <id> '" << _.getIdName(element_type_id)
240            << "' is a void type.";
241   }
242 
243   if (spvIsVulkanOrWebGPUEnv(_.context()->target_env) &&
244       element_type->opcode() == SpvOpTypeRuntimeArray) {
245     return _.diag(SPV_ERROR_INVALID_ID, inst)
246            << "OpTypeArray Element Type <id> '" << _.getIdName(element_type_id)
247            << "' is not valid in "
248            << spvLogStringForEnv(_.context()->target_env) << " environments.";
249   }
250 
251   const auto length_index = 2;
252   const auto length_id = inst->GetOperandAs<uint32_t>(length_index);
253   const auto length = _.FindDef(length_id);
254   if (!length || !spvOpcodeIsConstant(length->opcode())) {
255     return _.diag(SPV_ERROR_INVALID_ID, inst)
256            << "OpTypeArray Length <id> '" << _.getIdName(length_id)
257            << "' is not a scalar constant type.";
258   }
259 
260   // NOTE: Check the initialiser value of the constant
261   const auto const_inst = length->words();
262   const auto const_result_type_index = 1;
263   const auto const_result_type = _.FindDef(const_inst[const_result_type_index]);
264   if (!const_result_type || SpvOpTypeInt != const_result_type->opcode()) {
265     return _.diag(SPV_ERROR_INVALID_ID, inst)
266            << "OpTypeArray Length <id> '" << _.getIdName(length_id)
267            << "' is not a constant integer type.";
268   }
269 
270   switch (length->opcode()) {
271     case SpvOpSpecConstant:
272     case SpvOpConstant: {
273       auto& type_words = const_result_type->words();
274       const bool is_signed = type_words[3] > 0;
275       const uint32_t width = type_words[2];
276       const int64_t ivalue = ConstantLiteralAsInt64(width, length->words());
277       if (ivalue == 0 || (ivalue < 0 && is_signed)) {
278         return _.diag(SPV_ERROR_INVALID_ID, inst)
279                << "OpTypeArray Length <id> '" << _.getIdName(length_id)
280                << "' default value must be at least 1: found " << ivalue;
281       }
282       if (spvIsWebGPUEnv(_.context()->target_env)) {
283         // WebGPU has maximum integer width of 32 bits, and max array size
284         // is one more than the max signed integer representation.
285         const uint64_t max_permitted = (uint64_t(1) << 31);
286         const uint64_t uvalue = ConstantLiteralAsUint64(width, length->words());
287         if (uvalue > max_permitted) {
288           return _.diag(SPV_ERROR_INVALID_ID, inst)
289                  << "OpTypeArray Length <id> '" << _.getIdName(length_id)
290                  << "' size exceeds max value " << max_permitted
291                  << " permitted by WebGPU: got " << uvalue;
292         }
293       }
294     } break;
295     case SpvOpConstantNull:
296       return _.diag(SPV_ERROR_INVALID_ID, inst)
297              << "OpTypeArray Length <id> '" << _.getIdName(length_id)
298              << "' default value must be at least 1.";
299     case SpvOpSpecConstantOp:
300       // Assume it's OK, rather than try to evaluate the operation.
301       break;
302     default:
303       assert(0 && "bug in spvOpcodeIsConstant() or result type isn't int");
304   }
305   return SPV_SUCCESS;
306 }
307 
ValidateTypeRuntimeArray(ValidationState_t & _,const Instruction * inst)308 spv_result_t ValidateTypeRuntimeArray(ValidationState_t& _,
309                                       const Instruction* inst) {
310   const auto element_type_index = 1;
311   const auto element_id = inst->GetOperandAs<uint32_t>(element_type_index);
312   const auto element_type = _.FindDef(element_id);
313   if (!element_type || !spvOpcodeGeneratesType(element_type->opcode())) {
314     return _.diag(SPV_ERROR_INVALID_ID, inst)
315            << "OpTypeRuntimeArray Element Type <id> '"
316            << _.getIdName(element_id) << "' is not a type.";
317   }
318 
319   if (element_type->opcode() == SpvOpTypeVoid) {
320     return _.diag(SPV_ERROR_INVALID_ID, inst)
321            << "OpTypeRuntimeArray Element Type <id> '"
322            << _.getIdName(element_id) << "' is a void type.";
323   }
324 
325   if (spvIsVulkanOrWebGPUEnv(_.context()->target_env) &&
326       element_type->opcode() == SpvOpTypeRuntimeArray) {
327     return _.diag(SPV_ERROR_INVALID_ID, inst)
328            << "OpTypeRuntimeArray Element Type <id> '"
329            << _.getIdName(element_id) << "' is not valid in "
330            << spvLogStringForEnv(_.context()->target_env) << " environments.";
331   }
332 
333   return SPV_SUCCESS;
334 }
335 
ContainsOpaqueType(ValidationState_t & _,const Instruction * str)336 bool ContainsOpaqueType(ValidationState_t& _, const Instruction* str) {
337   const size_t elem_type_index = 1;
338   uint32_t elem_type_id;
339   Instruction* elem_type;
340 
341   if (spvOpcodeIsBaseOpaqueType(str->opcode())) {
342     return true;
343   }
344 
345   switch (str->opcode()) {
346     case SpvOpTypeArray:
347     case SpvOpTypeRuntimeArray:
348       elem_type_id = str->GetOperandAs<uint32_t>(elem_type_index);
349       elem_type = _.FindDef(elem_type_id);
350       return ContainsOpaqueType(_, elem_type);
351     case SpvOpTypeStruct:
352       for (size_t member_type_index = 1;
353            member_type_index < str->operands().size(); ++member_type_index) {
354         auto member_type_id = str->GetOperandAs<uint32_t>(member_type_index);
355         auto member_type = _.FindDef(member_type_id);
356         if (ContainsOpaqueType(_, member_type)) return true;
357       }
358       break;
359     default:
360       break;
361   }
362   return false;
363 }
364 
ValidateTypeStruct(ValidationState_t & _,const Instruction * inst)365 spv_result_t ValidateTypeStruct(ValidationState_t& _, const Instruction* inst) {
366   const uint32_t struct_id = inst->GetOperandAs<uint32_t>(0);
367   for (size_t member_type_index = 1;
368        member_type_index < inst->operands().size(); ++member_type_index) {
369     auto member_type_id = inst->GetOperandAs<uint32_t>(member_type_index);
370     if (member_type_id == inst->id()) {
371       return _.diag(SPV_ERROR_INVALID_ID, inst)
372              << "Structure members may not be self references";
373     }
374 
375     auto member_type = _.FindDef(member_type_id);
376     if (!member_type || !spvOpcodeGeneratesType(member_type->opcode())) {
377       return _.diag(SPV_ERROR_INVALID_ID, inst)
378              << "OpTypeStruct Member Type <id> '" << _.getIdName(member_type_id)
379              << "' is not a type.";
380     }
381     if (member_type->opcode() == SpvOpTypeVoid) {
382       return _.diag(SPV_ERROR_INVALID_ID, inst)
383              << "Structures cannot contain a void type.";
384     }
385     if (SpvOpTypeStruct == member_type->opcode() &&
386         _.IsStructTypeWithBuiltInMember(member_type_id)) {
387       return _.diag(SPV_ERROR_INVALID_ID, inst)
388              << "Structure <id> " << _.getIdName(member_type_id)
389              << " contains members with BuiltIn decoration. Therefore this "
390              << "structure may not be contained as a member of another "
391              << "structure "
392              << "type. Structure <id> " << _.getIdName(struct_id)
393              << " contains structure <id> " << _.getIdName(member_type_id)
394              << ".";
395     }
396 
397     if (spvIsVulkanOrWebGPUEnv(_.context()->target_env) &&
398         member_type->opcode() == SpvOpTypeRuntimeArray) {
399       const bool is_last_member =
400           member_type_index == inst->operands().size() - 1;
401       if (!is_last_member) {
402         return _.diag(SPV_ERROR_INVALID_ID, inst)
403                << "In " << spvLogStringForEnv(_.context()->target_env)
404                << ", OpTypeRuntimeArray must only be used for the last member "
405                   "of an OpTypeStruct";
406       }
407     }
408   }
409 
410   bool has_nested_blockOrBufferBlock_struct = false;
411   // Struct members start at word 2 of OpTypeStruct instruction.
412   for (size_t word_i = 2; word_i < inst->words().size(); ++word_i) {
413     auto member = inst->word(word_i);
414     auto memberTypeInstr = _.FindDef(member);
415     if (memberTypeInstr && SpvOpTypeStruct == memberTypeInstr->opcode()) {
416       if (_.HasDecoration(memberTypeInstr->id(), SpvDecorationBlock) ||
417           _.HasDecoration(memberTypeInstr->id(), SpvDecorationBufferBlock) ||
418           _.GetHasNestedBlockOrBufferBlockStruct(memberTypeInstr->id()))
419         has_nested_blockOrBufferBlock_struct = true;
420     }
421   }
422 
423   _.SetHasNestedBlockOrBufferBlockStruct(inst->id(),
424                                          has_nested_blockOrBufferBlock_struct);
425   if (_.GetHasNestedBlockOrBufferBlockStruct(inst->id()) &&
426       (_.HasDecoration(inst->id(), SpvDecorationBufferBlock) ||
427        _.HasDecoration(inst->id(), SpvDecorationBlock))) {
428     return _.diag(SPV_ERROR_INVALID_ID, inst)
429            << "rules: A Block or BufferBlock cannot be nested within another "
430               "Block or BufferBlock. ";
431   }
432 
433   std::unordered_set<uint32_t> built_in_members;
434   for (auto decoration : _.id_decorations(struct_id)) {
435     if (decoration.dec_type() == SpvDecorationBuiltIn &&
436         decoration.struct_member_index() != Decoration::kInvalidMember) {
437       built_in_members.insert(decoration.struct_member_index());
438     }
439   }
440   int num_struct_members = static_cast<int>(inst->operands().size() - 1);
441   int num_builtin_members = static_cast<int>(built_in_members.size());
442   if (num_builtin_members > 0 && num_builtin_members != num_struct_members) {
443     return _.diag(SPV_ERROR_INVALID_ID, inst)
444            << "When BuiltIn decoration is applied to a structure-type member, "
445            << "all members of that structure type must also be decorated with "
446            << "BuiltIn (No allowed mixing of built-in variables and "
447            << "non-built-in variables within a single structure). Structure id "
448            << struct_id << " does not meet this requirement.";
449   }
450   if (num_builtin_members > 0) {
451     _.RegisterStructTypeWithBuiltInMember(struct_id);
452   }
453 
454   if (spvIsVulkanEnv(_.context()->target_env) &&
455       !_.options()->before_hlsl_legalization && ContainsOpaqueType(_, inst)) {
456     return _.diag(SPV_ERROR_INVALID_ID, inst)
457            << "In " << spvLogStringForEnv(_.context()->target_env)
458            << ", OpTypeStruct must not contain an opaque type.";
459   }
460 
461   return SPV_SUCCESS;
462 }
463 
ValidateTypePointer(ValidationState_t & _,const Instruction * inst)464 spv_result_t ValidateTypePointer(ValidationState_t& _,
465                                  const Instruction* inst) {
466   auto type_id = inst->GetOperandAs<uint32_t>(2);
467   auto type = _.FindDef(type_id);
468   if (!type || !spvOpcodeGeneratesType(type->opcode())) {
469     return _.diag(SPV_ERROR_INVALID_ID, inst)
470            << "OpTypePointer Type <id> '" << _.getIdName(type_id)
471            << "' is not a type.";
472   }
473   // See if this points to a storage image.
474   const auto storage_class = inst->GetOperandAs<SpvStorageClass>(1);
475   if (storage_class == SpvStorageClassUniformConstant) {
476     // Unpack an optional level of arraying.
477     if (type->opcode() == SpvOpTypeArray ||
478         type->opcode() == SpvOpTypeRuntimeArray) {
479       type_id = type->GetOperandAs<uint32_t>(1);
480       type = _.FindDef(type_id);
481     }
482     if (type->opcode() == SpvOpTypeImage) {
483       const auto sampled = type->GetOperandAs<uint32_t>(6);
484       // 2 indicates this image is known to be be used without a sampler, i.e.
485       // a storage image.
486       if (sampled == 2) _.RegisterPointerToStorageImage(inst->id());
487     }
488   }
489 
490   if (!_.IsValidStorageClass(storage_class)) {
491     return _.diag(SPV_ERROR_INVALID_BINARY, inst)
492            << "Invalid storage class for target environment";
493   }
494 
495   return SPV_SUCCESS;
496 }
497 
ValidateTypeFunction(ValidationState_t & _,const Instruction * inst)498 spv_result_t ValidateTypeFunction(ValidationState_t& _,
499                                   const Instruction* inst) {
500   const auto return_type_id = inst->GetOperandAs<uint32_t>(1);
501   const auto return_type = _.FindDef(return_type_id);
502   if (!return_type || !spvOpcodeGeneratesType(return_type->opcode())) {
503     return _.diag(SPV_ERROR_INVALID_ID, inst)
504            << "OpTypeFunction Return Type <id> '" << _.getIdName(return_type_id)
505            << "' is not a type.";
506   }
507   size_t num_args = 0;
508   for (size_t param_type_index = 2; param_type_index < inst->operands().size();
509        ++param_type_index, ++num_args) {
510     const auto param_id = inst->GetOperandAs<uint32_t>(param_type_index);
511     const auto param_type = _.FindDef(param_id);
512     if (!param_type || !spvOpcodeGeneratesType(param_type->opcode())) {
513       return _.diag(SPV_ERROR_INVALID_ID, inst)
514              << "OpTypeFunction Parameter Type <id> '" << _.getIdName(param_id)
515              << "' is not a type.";
516     }
517 
518     if (param_type->opcode() == SpvOpTypeVoid) {
519       return _.diag(SPV_ERROR_INVALID_ID, inst)
520              << "OpTypeFunction Parameter Type <id> '" << _.getIdName(param_id)
521              << "' cannot be OpTypeVoid.";
522     }
523   }
524   const uint32_t num_function_args_limit =
525       _.options()->universal_limits_.max_function_args;
526   if (num_args > num_function_args_limit) {
527     return _.diag(SPV_ERROR_INVALID_ID, inst)
528            << "OpTypeFunction may not take more than "
529            << num_function_args_limit << " arguments. OpTypeFunction <id> '"
530            << _.getIdName(inst->GetOperandAs<uint32_t>(0)) << "' has "
531            << num_args << " arguments.";
532   }
533 
534   // The only valid uses of OpTypeFunction are in an OpFunction, debugging, or
535   // decoration instruction.
536   for (auto& pair : inst->uses()) {
537     const auto* use = pair.first;
538     if (use->opcode() != SpvOpFunction && !spvOpcodeIsDebug(use->opcode()) &&
539         !use->IsNonSemantic() && !spvOpcodeIsDecoration(use->opcode())) {
540       return _.diag(SPV_ERROR_INVALID_ID, use)
541              << "Invalid use of function type result id "
542              << _.getIdName(inst->id()) << ".";
543     }
544   }
545 
546   return SPV_SUCCESS;
547 }
548 
ValidateTypeForwardPointer(ValidationState_t & _,const Instruction * inst)549 spv_result_t ValidateTypeForwardPointer(ValidationState_t& _,
550                                         const Instruction* inst) {
551   const auto pointer_type_id = inst->GetOperandAs<uint32_t>(0);
552   const auto pointer_type_inst = _.FindDef(pointer_type_id);
553   if (pointer_type_inst->opcode() != SpvOpTypePointer) {
554     return _.diag(SPV_ERROR_INVALID_ID, inst)
555            << "Pointer type in OpTypeForwardPointer is not a pointer type.";
556   }
557 
558   if (inst->GetOperandAs<uint32_t>(1) !=
559       pointer_type_inst->GetOperandAs<uint32_t>(1)) {
560     return _.diag(SPV_ERROR_INVALID_ID, inst)
561            << "Storage class in OpTypeForwardPointer does not match the "
562            << "pointer definition.";
563   }
564 
565   const auto pointee_type_id = pointer_type_inst->GetOperandAs<uint32_t>(2);
566   const auto pointee_type = _.FindDef(pointee_type_id);
567   if (!pointee_type || pointee_type->opcode() != SpvOpTypeStruct) {
568     return _.diag(SPV_ERROR_INVALID_ID, inst)
569            << "Forward pointers must point to a structure";
570   }
571 
572   return SPV_SUCCESS;
573 }
574 
ValidateTypeCooperativeMatrixNV(ValidationState_t & _,const Instruction * inst)575 spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
576                                              const Instruction* inst) {
577   const auto component_type_index = 1;
578   const auto component_type_id =
579       inst->GetOperandAs<uint32_t>(component_type_index);
580   const auto component_type = _.FindDef(component_type_id);
581   if (!component_type || (SpvOpTypeFloat != component_type->opcode() &&
582                           SpvOpTypeInt != component_type->opcode())) {
583     return _.diag(SPV_ERROR_INVALID_ID, inst)
584            << "OpTypeCooperativeMatrixNV Component Type <id> '"
585            << _.getIdName(component_type_id)
586            << "' is not a scalar numerical type.";
587   }
588 
589   const auto scope_index = 2;
590   const auto scope_id = inst->GetOperandAs<uint32_t>(scope_index);
591   const auto scope = _.FindDef(scope_id);
592   if (!scope || !_.IsIntScalarType(scope->type_id()) ||
593       !spvOpcodeIsConstant(scope->opcode())) {
594     return _.diag(SPV_ERROR_INVALID_ID, inst)
595            << "OpTypeCooperativeMatrixNV Scope <id> '" << _.getIdName(scope_id)
596            << "' is not a constant instruction with scalar integer type.";
597   }
598 
599   const auto rows_index = 3;
600   const auto rows_id = inst->GetOperandAs<uint32_t>(rows_index);
601   const auto rows = _.FindDef(rows_id);
602   if (!rows || !_.IsIntScalarType(rows->type_id()) ||
603       !spvOpcodeIsConstant(rows->opcode())) {
604     return _.diag(SPV_ERROR_INVALID_ID, inst)
605            << "OpTypeCooperativeMatrixNV Rows <id> '" << _.getIdName(rows_id)
606            << "' is not a constant instruction with scalar integer type.";
607   }
608 
609   const auto cols_index = 4;
610   const auto cols_id = inst->GetOperandAs<uint32_t>(cols_index);
611   const auto cols = _.FindDef(cols_id);
612   if (!cols || !_.IsIntScalarType(cols->type_id()) ||
613       !spvOpcodeIsConstant(cols->opcode())) {
614     return _.diag(SPV_ERROR_INVALID_ID, inst)
615            << "OpTypeCooperativeMatrixNV Cols <id> '" << _.getIdName(rows_id)
616            << "' is not a constant instruction with scalar integer type.";
617   }
618 
619   return SPV_SUCCESS;
620 }
621 }  // namespace
622 
TypePass(ValidationState_t & _,const Instruction * inst)623 spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
624   if (!spvOpcodeGeneratesType(inst->opcode()) &&
625       inst->opcode() != SpvOpTypeForwardPointer) {
626     return SPV_SUCCESS;
627   }
628 
629   if (auto error = ValidateUniqueness(_, inst)) return error;
630 
631   switch (inst->opcode()) {
632     case SpvOpTypeInt:
633       if (auto error = ValidateTypeInt(_, inst)) return error;
634       break;
635     case SpvOpTypeFloat:
636       if (auto error = ValidateTypeFloat(_, inst)) return error;
637       break;
638     case SpvOpTypeVector:
639       if (auto error = ValidateTypeVector(_, inst)) return error;
640       break;
641     case SpvOpTypeMatrix:
642       if (auto error = ValidateTypeMatrix(_, inst)) return error;
643       break;
644     case SpvOpTypeArray:
645       if (auto error = ValidateTypeArray(_, inst)) return error;
646       break;
647     case SpvOpTypeRuntimeArray:
648       if (auto error = ValidateTypeRuntimeArray(_, inst)) return error;
649       break;
650     case SpvOpTypeStruct:
651       if (auto error = ValidateTypeStruct(_, inst)) return error;
652       break;
653     case SpvOpTypePointer:
654       if (auto error = ValidateTypePointer(_, inst)) return error;
655       break;
656     case SpvOpTypeFunction:
657       if (auto error = ValidateTypeFunction(_, inst)) return error;
658       break;
659     case SpvOpTypeForwardPointer:
660       if (auto error = ValidateTypeForwardPointer(_, inst)) return error;
661       break;
662     case SpvOpTypeCooperativeMatrixNV:
663       if (auto error = ValidateTypeCooperativeMatrixNV(_, inst)) return error;
664       break;
665     default:
666       break;
667   }
668 
669   return SPV_SUCCESS;
670 }
671 
672 }  // namespace val
673 }  // namespace spvtools
674