1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
// Validates type declaration instructions, including the rule that type
// declarations are unique unless duplicates are allowed by the specification.
16
17 #include "source/opcode.h"
18 #include "source/spirv_target_env.h"
19 #include "source/val/instruction.h"
20 #include "source/val/validate.h"
21 #include "source/val/validation_state.h"
22 #include "spirv/unified1/spirv.h"
23
24 namespace spvtools {
25 namespace val {
26 namespace {
27
// Interprets the literal of an OpConstant (or the default value of an
// OpSpecConstant) whose result type is a |width|-bit integer, and returns it as
// a signed 64-bit value. The literal begins at word 3 of |const_words|; a
// literal wider than 32 bits also occupies word 4 (low-order word first).
// Literals of 32 bits or fewer are read as a signed 32-bit word regardless of
// the type's signedness. For signed types this relies on the rule that a
// literal narrower than a word is sign-extended to fill the word.
int64_t ConstantLiteralAsInt64(uint32_t width,
                               const std::vector<uint32_t>& const_words) {
  const uint32_t low_bits = const_words[3];
  if (width > 32) {
    assert(width <= 64);
    assert(const_words.size() > 4);
    const uint64_t high_bits = const_words[4];  // Must exist, per spec.
    return static_cast<int64_t>((high_bits << 32) | low_bits);
  }
  return static_cast<int32_t>(low_bits);
}
42
// Interprets the literal of an OpConstant (or the default value of an
// OpSpecConstant) whose result type is a |width|-bit integer, and returns its
// bit pattern as an unsigned 64-bit value. The literal begins at word 3 of
// |const_words|; a literal wider than 32 bits also occupies word 4 (low-order
// word first). Literals of 32 bits or fewer are zero-extended from their single
// word. A narrow signed literal is therefore returned as its (sign-extended)
// 32-bit word pattern.
// The return type is uint64_t, matching the name and contract, so values with
// bit 63 set are not reinterpreted as negative numbers.
uint64_t ConstantLiteralAsUint64(uint32_t width,
                                 const std::vector<uint32_t>& const_words) {
  const uint32_t lo_word = const_words[3];
  if (width <= 32) return lo_word;
  assert(width <= 64);
  assert(const_words.size() > 4);
  const uint32_t hi_word = const_words[4];  // Must exist, per spec.
  return static_cast<uint64_t>(lo_word) | (static_cast<uint64_t>(hi_word) << 32);
}
57
58 // Validates that type declarations are unique, unless multiple declarations
59 // of the same data type are allowed by the specification.
60 // (see section 2.8 Types and Variables)
61 // Doesn't do anything if SPV_VAL_ignore_type_decl_unique was declared in the
62 // module.
ValidateUniqueness(ValidationState_t & _,const Instruction * inst)63 spv_result_t ValidateUniqueness(ValidationState_t& _, const Instruction* inst) {
64 if (_.HasExtension(Extension::kSPV_VALIDATOR_ignore_type_decl_unique))
65 return SPV_SUCCESS;
66
67 const auto opcode = inst->opcode();
68 if (opcode != SpvOpTypeArray && opcode != SpvOpTypeRuntimeArray &&
69 opcode != SpvOpTypeStruct && opcode != SpvOpTypePointer &&
70 !_.RegisterUniqueTypeDeclaration(inst)) {
71 return _.diag(SPV_ERROR_INVALID_DATA, inst)
72 << "Duplicate non-aggregate type declarations are not allowed. "
73 "Opcode: "
74 << spvOpcodeString(opcode) << " id: " << inst->id();
75 }
76
77 return SPV_SUCCESS;
78 }
79
ValidateTypeInt(ValidationState_t & _,const Instruction * inst)80 spv_result_t ValidateTypeInt(ValidationState_t& _, const Instruction* inst) {
81 // Validates that the number of bits specified for an Int type is valid.
82 // Scalar integer types can be parameterized only with 32-bits.
83 // Int8, Int16, and Int64 capabilities allow using 8-bit, 16-bit, and 64-bit
84 // integers, respectively.
85 auto num_bits = inst->GetOperandAs<const uint32_t>(1);
86 if (num_bits != 32) {
87 if (num_bits == 8) {
88 if (_.features().declare_int8_type) {
89 return SPV_SUCCESS;
90 }
91 return _.diag(SPV_ERROR_INVALID_DATA, inst)
92 << "Using an 8-bit integer type requires the Int8 capability,"
93 " or an extension that explicitly enables 8-bit integers.";
94 } else if (num_bits == 16) {
95 if (_.features().declare_int16_type) {
96 return SPV_SUCCESS;
97 }
98 return _.diag(SPV_ERROR_INVALID_DATA, inst)
99 << "Using a 16-bit integer type requires the Int16 capability,"
100 " or an extension that explicitly enables 16-bit integers.";
101 } else if (num_bits == 64) {
102 if (_.HasCapability(SpvCapabilityInt64)) {
103 return SPV_SUCCESS;
104 }
105 return _.diag(SPV_ERROR_INVALID_DATA, inst)
106 << "Using a 64-bit integer type requires the Int64 capability.";
107 } else {
108 return _.diag(SPV_ERROR_INVALID_DATA, inst)
109 << "Invalid number of bits (" << num_bits
110 << ") used for OpTypeInt.";
111 }
112 }
113
114 const auto signedness_index = 2;
115 const auto signedness = inst->GetOperandAs<uint32_t>(signedness_index);
116 if (signedness != 0 && signedness != 1) {
117 return _.diag(SPV_ERROR_INVALID_VALUE, inst)
118 << "OpTypeInt has invalid signedness:";
119 }
120
121 // SPIR-V Spec 2.16.3: Validation Rules for Kernel Capabilities: The
122 // Signedness in OpTypeInt must always be 0.
123 if (SpvOpTypeInt == inst->opcode() && _.HasCapability(SpvCapabilityKernel) &&
124 inst->GetOperandAs<uint32_t>(2) != 0u) {
125 return _.diag(SPV_ERROR_INVALID_BINARY, inst)
126 << "The Signedness in OpTypeInt "
127 "must always be 0 when Kernel "
128 "capability is used.";
129 }
130
131 return SPV_SUCCESS;
132 }
133
ValidateTypeFloat(ValidationState_t & _,const Instruction * inst)134 spv_result_t ValidateTypeFloat(ValidationState_t& _, const Instruction* inst) {
135 // Validates that the number of bits specified for an Int type is valid.
136 // Scalar integer types can be parameterized only with 32-bits.
137 // Int8, Int16, and Int64 capabilities allow using 8-bit, 16-bit, and 64-bit
138 // integers, respectively.
139 auto num_bits = inst->GetOperandAs<const uint32_t>(1);
140 if (num_bits == 32) {
141 return SPV_SUCCESS;
142 }
143 if (num_bits == 16) {
144 if (_.features().declare_float16_type) {
145 return SPV_SUCCESS;
146 }
147 return _.diag(SPV_ERROR_INVALID_DATA, inst)
148 << "Using a 16-bit floating point "
149 << "type requires the Float16 or Float16Buffer capability,"
150 " or an extension that explicitly enables 16-bit floating point.";
151 }
152 if (num_bits == 64) {
153 if (_.HasCapability(SpvCapabilityFloat64)) {
154 return SPV_SUCCESS;
155 }
156 return _.diag(SPV_ERROR_INVALID_DATA, inst)
157 << "Using a 64-bit floating point "
158 << "type requires the Float64 capability.";
159 }
160 return _.diag(SPV_ERROR_INVALID_DATA, inst)
161 << "Invalid number of bits (" << num_bits << ") used for OpTypeFloat.";
162 }
163
ValidateTypeVector(ValidationState_t & _,const Instruction * inst)164 spv_result_t ValidateTypeVector(ValidationState_t& _, const Instruction* inst) {
165 const auto component_index = 1;
166 const auto component_id = inst->GetOperandAs<uint32_t>(component_index);
167 const auto component_type = _.FindDef(component_id);
168 if (!component_type || !spvOpcodeIsScalarType(component_type->opcode())) {
169 return _.diag(SPV_ERROR_INVALID_ID, inst)
170 << "OpTypeVector Component Type <id> '" << _.getIdName(component_id)
171 << "' is not a scalar type.";
172 }
173
174 // Validates that the number of components in the vector is valid.
175 // Vector types can only be parameterized as having 2, 3, or 4 components.
176 // If the Vector16 capability is added, 8 and 16 components are also allowed.
177 auto num_components = inst->GetOperandAs<const uint32_t>(2);
178 if (num_components == 2 || num_components == 3 || num_components == 4) {
179 return SPV_SUCCESS;
180 } else if (num_components == 8 || num_components == 16) {
181 if (_.HasCapability(SpvCapabilityVector16)) {
182 return SPV_SUCCESS;
183 }
184 return _.diag(SPV_ERROR_INVALID_DATA, inst)
185 << "Having " << num_components << " components for "
186 << spvOpcodeString(inst->opcode())
187 << " requires the Vector16 capability";
188 } else {
189 return _.diag(SPV_ERROR_INVALID_DATA, inst)
190 << "Illegal number of components (" << num_components << ") for "
191 << spvOpcodeString(inst->opcode());
192 }
193
194 return SPV_SUCCESS;
195 }
196
ValidateTypeMatrix(ValidationState_t & _,const Instruction * inst)197 spv_result_t ValidateTypeMatrix(ValidationState_t& _, const Instruction* inst) {
198 const auto column_type_index = 1;
199 const auto column_type_id = inst->GetOperandAs<uint32_t>(column_type_index);
200 const auto column_type = _.FindDef(column_type_id);
201 if (!column_type || SpvOpTypeVector != column_type->opcode()) {
202 return _.diag(SPV_ERROR_INVALID_ID, inst)
203 << "Columns in a matrix must be of type vector.";
204 }
205
206 // Trace back once more to find out the type of components in the vector.
207 // Operand 1 is the <id> of the type of data in the vector.
208 const auto comp_type_id = column_type->GetOperandAs<uint32_t>(1);
209 auto comp_type_instruction = _.FindDef(comp_type_id);
210 if (comp_type_instruction->opcode() != SpvOpTypeFloat) {
211 return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be "
212 "parameterized with "
213 "floating-point types.";
214 }
215
216 // Validates that the matrix has 2,3, or 4 columns.
217 auto num_cols = inst->GetOperandAs<const uint32_t>(2);
218 if (num_cols != 2 && num_cols != 3 && num_cols != 4) {
219 return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be "
220 "parameterized as having "
221 "only 2, 3, or 4 columns.";
222 }
223
224 return SPV_SUCCESS;
225 }
226
ValidateTypeArray(ValidationState_t & _,const Instruction * inst)227 spv_result_t ValidateTypeArray(ValidationState_t& _, const Instruction* inst) {
228 const auto element_type_index = 1;
229 const auto element_type_id = inst->GetOperandAs<uint32_t>(element_type_index);
230 const auto element_type = _.FindDef(element_type_id);
231 if (!element_type || !spvOpcodeGeneratesType(element_type->opcode())) {
232 return _.diag(SPV_ERROR_INVALID_ID, inst)
233 << "OpTypeArray Element Type <id> '" << _.getIdName(element_type_id)
234 << "' is not a type.";
235 }
236
237 if (element_type->opcode() == SpvOpTypeVoid) {
238 return _.diag(SPV_ERROR_INVALID_ID, inst)
239 << "OpTypeArray Element Type <id> '" << _.getIdName(element_type_id)
240 << "' is a void type.";
241 }
242
243 if (spvIsVulkanOrWebGPUEnv(_.context()->target_env) &&
244 element_type->opcode() == SpvOpTypeRuntimeArray) {
245 return _.diag(SPV_ERROR_INVALID_ID, inst)
246 << "OpTypeArray Element Type <id> '" << _.getIdName(element_type_id)
247 << "' is not valid in "
248 << spvLogStringForEnv(_.context()->target_env) << " environments.";
249 }
250
251 const auto length_index = 2;
252 const auto length_id = inst->GetOperandAs<uint32_t>(length_index);
253 const auto length = _.FindDef(length_id);
254 if (!length || !spvOpcodeIsConstant(length->opcode())) {
255 return _.diag(SPV_ERROR_INVALID_ID, inst)
256 << "OpTypeArray Length <id> '" << _.getIdName(length_id)
257 << "' is not a scalar constant type.";
258 }
259
260 // NOTE: Check the initialiser value of the constant
261 const auto const_inst = length->words();
262 const auto const_result_type_index = 1;
263 const auto const_result_type = _.FindDef(const_inst[const_result_type_index]);
264 if (!const_result_type || SpvOpTypeInt != const_result_type->opcode()) {
265 return _.diag(SPV_ERROR_INVALID_ID, inst)
266 << "OpTypeArray Length <id> '" << _.getIdName(length_id)
267 << "' is not a constant integer type.";
268 }
269
270 switch (length->opcode()) {
271 case SpvOpSpecConstant:
272 case SpvOpConstant: {
273 auto& type_words = const_result_type->words();
274 const bool is_signed = type_words[3] > 0;
275 const uint32_t width = type_words[2];
276 const int64_t ivalue = ConstantLiteralAsInt64(width, length->words());
277 if (ivalue == 0 || (ivalue < 0 && is_signed)) {
278 return _.diag(SPV_ERROR_INVALID_ID, inst)
279 << "OpTypeArray Length <id> '" << _.getIdName(length_id)
280 << "' default value must be at least 1: found " << ivalue;
281 }
282 if (spvIsWebGPUEnv(_.context()->target_env)) {
283 // WebGPU has maximum integer width of 32 bits, and max array size
284 // is one more than the max signed integer representation.
285 const uint64_t max_permitted = (uint64_t(1) << 31);
286 const uint64_t uvalue = ConstantLiteralAsUint64(width, length->words());
287 if (uvalue > max_permitted) {
288 return _.diag(SPV_ERROR_INVALID_ID, inst)
289 << "OpTypeArray Length <id> '" << _.getIdName(length_id)
290 << "' size exceeds max value " << max_permitted
291 << " permitted by WebGPU: got " << uvalue;
292 }
293 }
294 } break;
295 case SpvOpConstantNull:
296 return _.diag(SPV_ERROR_INVALID_ID, inst)
297 << "OpTypeArray Length <id> '" << _.getIdName(length_id)
298 << "' default value must be at least 1.";
299 case SpvOpSpecConstantOp:
300 // Assume it's OK, rather than try to evaluate the operation.
301 break;
302 default:
303 assert(0 && "bug in spvOpcodeIsConstant() or result type isn't int");
304 }
305 return SPV_SUCCESS;
306 }
307
ValidateTypeRuntimeArray(ValidationState_t & _,const Instruction * inst)308 spv_result_t ValidateTypeRuntimeArray(ValidationState_t& _,
309 const Instruction* inst) {
310 const auto element_type_index = 1;
311 const auto element_id = inst->GetOperandAs<uint32_t>(element_type_index);
312 const auto element_type = _.FindDef(element_id);
313 if (!element_type || !spvOpcodeGeneratesType(element_type->opcode())) {
314 return _.diag(SPV_ERROR_INVALID_ID, inst)
315 << "OpTypeRuntimeArray Element Type <id> '"
316 << _.getIdName(element_id) << "' is not a type.";
317 }
318
319 if (element_type->opcode() == SpvOpTypeVoid) {
320 return _.diag(SPV_ERROR_INVALID_ID, inst)
321 << "OpTypeRuntimeArray Element Type <id> '"
322 << _.getIdName(element_id) << "' is a void type.";
323 }
324
325 if (spvIsVulkanOrWebGPUEnv(_.context()->target_env) &&
326 element_type->opcode() == SpvOpTypeRuntimeArray) {
327 return _.diag(SPV_ERROR_INVALID_ID, inst)
328 << "OpTypeRuntimeArray Element Type <id> '"
329 << _.getIdName(element_id) << "' is not valid in "
330 << spvLogStringForEnv(_.context()->target_env) << " environments.";
331 }
332
333 return SPV_SUCCESS;
334 }
335
ContainsOpaqueType(ValidationState_t & _,const Instruction * str)336 bool ContainsOpaqueType(ValidationState_t& _, const Instruction* str) {
337 const size_t elem_type_index = 1;
338 uint32_t elem_type_id;
339 Instruction* elem_type;
340
341 if (spvOpcodeIsBaseOpaqueType(str->opcode())) {
342 return true;
343 }
344
345 switch (str->opcode()) {
346 case SpvOpTypeArray:
347 case SpvOpTypeRuntimeArray:
348 elem_type_id = str->GetOperandAs<uint32_t>(elem_type_index);
349 elem_type = _.FindDef(elem_type_id);
350 return ContainsOpaqueType(_, elem_type);
351 case SpvOpTypeStruct:
352 for (size_t member_type_index = 1;
353 member_type_index < str->operands().size(); ++member_type_index) {
354 auto member_type_id = str->GetOperandAs<uint32_t>(member_type_index);
355 auto member_type = _.FindDef(member_type_id);
356 if (ContainsOpaqueType(_, member_type)) return true;
357 }
358 break;
359 default:
360 break;
361 }
362 return false;
363 }
364
ValidateTypeStruct(ValidationState_t & _,const Instruction * inst)365 spv_result_t ValidateTypeStruct(ValidationState_t& _, const Instruction* inst) {
366 const uint32_t struct_id = inst->GetOperandAs<uint32_t>(0);
367 for (size_t member_type_index = 1;
368 member_type_index < inst->operands().size(); ++member_type_index) {
369 auto member_type_id = inst->GetOperandAs<uint32_t>(member_type_index);
370 if (member_type_id == inst->id()) {
371 return _.diag(SPV_ERROR_INVALID_ID, inst)
372 << "Structure members may not be self references";
373 }
374
375 auto member_type = _.FindDef(member_type_id);
376 if (!member_type || !spvOpcodeGeneratesType(member_type->opcode())) {
377 return _.diag(SPV_ERROR_INVALID_ID, inst)
378 << "OpTypeStruct Member Type <id> '" << _.getIdName(member_type_id)
379 << "' is not a type.";
380 }
381 if (member_type->opcode() == SpvOpTypeVoid) {
382 return _.diag(SPV_ERROR_INVALID_ID, inst)
383 << "Structures cannot contain a void type.";
384 }
385 if (SpvOpTypeStruct == member_type->opcode() &&
386 _.IsStructTypeWithBuiltInMember(member_type_id)) {
387 return _.diag(SPV_ERROR_INVALID_ID, inst)
388 << "Structure <id> " << _.getIdName(member_type_id)
389 << " contains members with BuiltIn decoration. Therefore this "
390 << "structure may not be contained as a member of another "
391 << "structure "
392 << "type. Structure <id> " << _.getIdName(struct_id)
393 << " contains structure <id> " << _.getIdName(member_type_id)
394 << ".";
395 }
396
397 if (spvIsVulkanOrWebGPUEnv(_.context()->target_env) &&
398 member_type->opcode() == SpvOpTypeRuntimeArray) {
399 const bool is_last_member =
400 member_type_index == inst->operands().size() - 1;
401 if (!is_last_member) {
402 return _.diag(SPV_ERROR_INVALID_ID, inst)
403 << "In " << spvLogStringForEnv(_.context()->target_env)
404 << ", OpTypeRuntimeArray must only be used for the last member "
405 "of an OpTypeStruct";
406 }
407 }
408 }
409
410 bool has_nested_blockOrBufferBlock_struct = false;
411 // Struct members start at word 2 of OpTypeStruct instruction.
412 for (size_t word_i = 2; word_i < inst->words().size(); ++word_i) {
413 auto member = inst->word(word_i);
414 auto memberTypeInstr = _.FindDef(member);
415 if (memberTypeInstr && SpvOpTypeStruct == memberTypeInstr->opcode()) {
416 if (_.HasDecoration(memberTypeInstr->id(), SpvDecorationBlock) ||
417 _.HasDecoration(memberTypeInstr->id(), SpvDecorationBufferBlock) ||
418 _.GetHasNestedBlockOrBufferBlockStruct(memberTypeInstr->id()))
419 has_nested_blockOrBufferBlock_struct = true;
420 }
421 }
422
423 _.SetHasNestedBlockOrBufferBlockStruct(inst->id(),
424 has_nested_blockOrBufferBlock_struct);
425 if (_.GetHasNestedBlockOrBufferBlockStruct(inst->id()) &&
426 (_.HasDecoration(inst->id(), SpvDecorationBufferBlock) ||
427 _.HasDecoration(inst->id(), SpvDecorationBlock))) {
428 return _.diag(SPV_ERROR_INVALID_ID, inst)
429 << "rules: A Block or BufferBlock cannot be nested within another "
430 "Block or BufferBlock. ";
431 }
432
433 std::unordered_set<uint32_t> built_in_members;
434 for (auto decoration : _.id_decorations(struct_id)) {
435 if (decoration.dec_type() == SpvDecorationBuiltIn &&
436 decoration.struct_member_index() != Decoration::kInvalidMember) {
437 built_in_members.insert(decoration.struct_member_index());
438 }
439 }
440 int num_struct_members = static_cast<int>(inst->operands().size() - 1);
441 int num_builtin_members = static_cast<int>(built_in_members.size());
442 if (num_builtin_members > 0 && num_builtin_members != num_struct_members) {
443 return _.diag(SPV_ERROR_INVALID_ID, inst)
444 << "When BuiltIn decoration is applied to a structure-type member, "
445 << "all members of that structure type must also be decorated with "
446 << "BuiltIn (No allowed mixing of built-in variables and "
447 << "non-built-in variables within a single structure). Structure id "
448 << struct_id << " does not meet this requirement.";
449 }
450 if (num_builtin_members > 0) {
451 _.RegisterStructTypeWithBuiltInMember(struct_id);
452 }
453
454 if (spvIsVulkanEnv(_.context()->target_env) &&
455 !_.options()->before_hlsl_legalization && ContainsOpaqueType(_, inst)) {
456 return _.diag(SPV_ERROR_INVALID_ID, inst)
457 << "In " << spvLogStringForEnv(_.context()->target_env)
458 << ", OpTypeStruct must not contain an opaque type.";
459 }
460
461 return SPV_SUCCESS;
462 }
463
ValidateTypePointer(ValidationState_t & _,const Instruction * inst)464 spv_result_t ValidateTypePointer(ValidationState_t& _,
465 const Instruction* inst) {
466 auto type_id = inst->GetOperandAs<uint32_t>(2);
467 auto type = _.FindDef(type_id);
468 if (!type || !spvOpcodeGeneratesType(type->opcode())) {
469 return _.diag(SPV_ERROR_INVALID_ID, inst)
470 << "OpTypePointer Type <id> '" << _.getIdName(type_id)
471 << "' is not a type.";
472 }
473 // See if this points to a storage image.
474 const auto storage_class = inst->GetOperandAs<SpvStorageClass>(1);
475 if (storage_class == SpvStorageClassUniformConstant) {
476 // Unpack an optional level of arraying.
477 if (type->opcode() == SpvOpTypeArray ||
478 type->opcode() == SpvOpTypeRuntimeArray) {
479 type_id = type->GetOperandAs<uint32_t>(1);
480 type = _.FindDef(type_id);
481 }
482 if (type->opcode() == SpvOpTypeImage) {
483 const auto sampled = type->GetOperandAs<uint32_t>(6);
484 // 2 indicates this image is known to be be used without a sampler, i.e.
485 // a storage image.
486 if (sampled == 2) _.RegisterPointerToStorageImage(inst->id());
487 }
488 }
489
490 if (!_.IsValidStorageClass(storage_class)) {
491 return _.diag(SPV_ERROR_INVALID_BINARY, inst)
492 << "Invalid storage class for target environment";
493 }
494
495 return SPV_SUCCESS;
496 }
497
ValidateTypeFunction(ValidationState_t & _,const Instruction * inst)498 spv_result_t ValidateTypeFunction(ValidationState_t& _,
499 const Instruction* inst) {
500 const auto return_type_id = inst->GetOperandAs<uint32_t>(1);
501 const auto return_type = _.FindDef(return_type_id);
502 if (!return_type || !spvOpcodeGeneratesType(return_type->opcode())) {
503 return _.diag(SPV_ERROR_INVALID_ID, inst)
504 << "OpTypeFunction Return Type <id> '" << _.getIdName(return_type_id)
505 << "' is not a type.";
506 }
507 size_t num_args = 0;
508 for (size_t param_type_index = 2; param_type_index < inst->operands().size();
509 ++param_type_index, ++num_args) {
510 const auto param_id = inst->GetOperandAs<uint32_t>(param_type_index);
511 const auto param_type = _.FindDef(param_id);
512 if (!param_type || !spvOpcodeGeneratesType(param_type->opcode())) {
513 return _.diag(SPV_ERROR_INVALID_ID, inst)
514 << "OpTypeFunction Parameter Type <id> '" << _.getIdName(param_id)
515 << "' is not a type.";
516 }
517
518 if (param_type->opcode() == SpvOpTypeVoid) {
519 return _.diag(SPV_ERROR_INVALID_ID, inst)
520 << "OpTypeFunction Parameter Type <id> '" << _.getIdName(param_id)
521 << "' cannot be OpTypeVoid.";
522 }
523 }
524 const uint32_t num_function_args_limit =
525 _.options()->universal_limits_.max_function_args;
526 if (num_args > num_function_args_limit) {
527 return _.diag(SPV_ERROR_INVALID_ID, inst)
528 << "OpTypeFunction may not take more than "
529 << num_function_args_limit << " arguments. OpTypeFunction <id> '"
530 << _.getIdName(inst->GetOperandAs<uint32_t>(0)) << "' has "
531 << num_args << " arguments.";
532 }
533
534 // The only valid uses of OpTypeFunction are in an OpFunction, debugging, or
535 // decoration instruction.
536 for (auto& pair : inst->uses()) {
537 const auto* use = pair.first;
538 if (use->opcode() != SpvOpFunction && !spvOpcodeIsDebug(use->opcode()) &&
539 !use->IsNonSemantic() && !spvOpcodeIsDecoration(use->opcode())) {
540 return _.diag(SPV_ERROR_INVALID_ID, use)
541 << "Invalid use of function type result id "
542 << _.getIdName(inst->id()) << ".";
543 }
544 }
545
546 return SPV_SUCCESS;
547 }
548
ValidateTypeForwardPointer(ValidationState_t & _,const Instruction * inst)549 spv_result_t ValidateTypeForwardPointer(ValidationState_t& _,
550 const Instruction* inst) {
551 const auto pointer_type_id = inst->GetOperandAs<uint32_t>(0);
552 const auto pointer_type_inst = _.FindDef(pointer_type_id);
553 if (pointer_type_inst->opcode() != SpvOpTypePointer) {
554 return _.diag(SPV_ERROR_INVALID_ID, inst)
555 << "Pointer type in OpTypeForwardPointer is not a pointer type.";
556 }
557
558 if (inst->GetOperandAs<uint32_t>(1) !=
559 pointer_type_inst->GetOperandAs<uint32_t>(1)) {
560 return _.diag(SPV_ERROR_INVALID_ID, inst)
561 << "Storage class in OpTypeForwardPointer does not match the "
562 << "pointer definition.";
563 }
564
565 const auto pointee_type_id = pointer_type_inst->GetOperandAs<uint32_t>(2);
566 const auto pointee_type = _.FindDef(pointee_type_id);
567 if (!pointee_type || pointee_type->opcode() != SpvOpTypeStruct) {
568 return _.diag(SPV_ERROR_INVALID_ID, inst)
569 << "Forward pointers must point to a structure";
570 }
571
572 return SPV_SUCCESS;
573 }
574
ValidateTypeCooperativeMatrixNV(ValidationState_t & _,const Instruction * inst)575 spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
576 const Instruction* inst) {
577 const auto component_type_index = 1;
578 const auto component_type_id =
579 inst->GetOperandAs<uint32_t>(component_type_index);
580 const auto component_type = _.FindDef(component_type_id);
581 if (!component_type || (SpvOpTypeFloat != component_type->opcode() &&
582 SpvOpTypeInt != component_type->opcode())) {
583 return _.diag(SPV_ERROR_INVALID_ID, inst)
584 << "OpTypeCooperativeMatrixNV Component Type <id> '"
585 << _.getIdName(component_type_id)
586 << "' is not a scalar numerical type.";
587 }
588
589 const auto scope_index = 2;
590 const auto scope_id = inst->GetOperandAs<uint32_t>(scope_index);
591 const auto scope = _.FindDef(scope_id);
592 if (!scope || !_.IsIntScalarType(scope->type_id()) ||
593 !spvOpcodeIsConstant(scope->opcode())) {
594 return _.diag(SPV_ERROR_INVALID_ID, inst)
595 << "OpTypeCooperativeMatrixNV Scope <id> '" << _.getIdName(scope_id)
596 << "' is not a constant instruction with scalar integer type.";
597 }
598
599 const auto rows_index = 3;
600 const auto rows_id = inst->GetOperandAs<uint32_t>(rows_index);
601 const auto rows = _.FindDef(rows_id);
602 if (!rows || !_.IsIntScalarType(rows->type_id()) ||
603 !spvOpcodeIsConstant(rows->opcode())) {
604 return _.diag(SPV_ERROR_INVALID_ID, inst)
605 << "OpTypeCooperativeMatrixNV Rows <id> '" << _.getIdName(rows_id)
606 << "' is not a constant instruction with scalar integer type.";
607 }
608
609 const auto cols_index = 4;
610 const auto cols_id = inst->GetOperandAs<uint32_t>(cols_index);
611 const auto cols = _.FindDef(cols_id);
612 if (!cols || !_.IsIntScalarType(cols->type_id()) ||
613 !spvOpcodeIsConstant(cols->opcode())) {
614 return _.diag(SPV_ERROR_INVALID_ID, inst)
615 << "OpTypeCooperativeMatrixNV Cols <id> '" << _.getIdName(rows_id)
616 << "' is not a constant instruction with scalar integer type.";
617 }
618
619 return SPV_SUCCESS;
620 }
621 } // namespace
622
TypePass(ValidationState_t & _,const Instruction * inst)623 spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
624 if (!spvOpcodeGeneratesType(inst->opcode()) &&
625 inst->opcode() != SpvOpTypeForwardPointer) {
626 return SPV_SUCCESS;
627 }
628
629 if (auto error = ValidateUniqueness(_, inst)) return error;
630
631 switch (inst->opcode()) {
632 case SpvOpTypeInt:
633 if (auto error = ValidateTypeInt(_, inst)) return error;
634 break;
635 case SpvOpTypeFloat:
636 if (auto error = ValidateTypeFloat(_, inst)) return error;
637 break;
638 case SpvOpTypeVector:
639 if (auto error = ValidateTypeVector(_, inst)) return error;
640 break;
641 case SpvOpTypeMatrix:
642 if (auto error = ValidateTypeMatrix(_, inst)) return error;
643 break;
644 case SpvOpTypeArray:
645 if (auto error = ValidateTypeArray(_, inst)) return error;
646 break;
647 case SpvOpTypeRuntimeArray:
648 if (auto error = ValidateTypeRuntimeArray(_, inst)) return error;
649 break;
650 case SpvOpTypeStruct:
651 if (auto error = ValidateTypeStruct(_, inst)) return error;
652 break;
653 case SpvOpTypePointer:
654 if (auto error = ValidateTypePointer(_, inst)) return error;
655 break;
656 case SpvOpTypeFunction:
657 if (auto error = ValidateTypeFunction(_, inst)) return error;
658 break;
659 case SpvOpTypeForwardPointer:
660 if (auto error = ValidateTypeForwardPointer(_, inst)) return error;
661 break;
662 case SpvOpTypeCooperativeMatrixNV:
663 if (auto error = ValidateTypeCooperativeMatrixNV(_, inst)) return error;
664 break;
665 default:
666 break;
667 }
668
669 return SPV_SUCCESS;
670 }
671
672 } // namespace val
673 } // namespace spvtools
674