1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opcode.h"
16 #include "source/val/instruction.h"
17 #include "source/val/validate.h"
18 #include "source/val/validation_state.h"
19 
20 namespace spvtools {
21 namespace val {
22 namespace {
23 
ValidateConstantBool(ValidationState_t & _,const Instruction * inst)24 spv_result_t ValidateConstantBool(ValidationState_t& _,
25                                   const Instruction* inst) {
26   auto type = _.FindDef(inst->type_id());
27   if (!type || type->opcode() != SpvOpTypeBool) {
28     return _.diag(SPV_ERROR_INVALID_ID, inst)
29            << "Op" << spvOpcodeString(inst->opcode()) << " Result Type <id> '"
30            << _.getIdName(inst->type_id()) << "' is not a boolean type.";
31   }
32 
33   return SPV_SUCCESS;
34 }
35 
ValidateConstantComposite(ValidationState_t & _,const Instruction * inst)36 spv_result_t ValidateConstantComposite(ValidationState_t& _,
37                                        const Instruction* inst) {
38   std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode());
39 
40   const auto result_type = _.FindDef(inst->type_id());
41   if (!result_type || !spvOpcodeIsComposite(result_type->opcode())) {
42     return _.diag(SPV_ERROR_INVALID_ID, inst)
43            << opcode_name << " Result Type <id> '"
44            << _.getIdName(inst->type_id()) << "' is not a composite type.";
45   }
46 
47   const auto constituent_count = inst->words().size() - 3;
48   switch (result_type->opcode()) {
49     case SpvOpTypeVector: {
50       const auto component_count = result_type->GetOperandAs<uint32_t>(2);
51       if (component_count != constituent_count) {
52         // TODO: Output ID's on diagnostic
53         return _.diag(SPV_ERROR_INVALID_ID, inst)
54                << opcode_name
55                << " Constituent <id> count does not match "
56                   "Result Type <id> '"
57                << _.getIdName(result_type->id())
58                << "'s vector component count.";
59       }
60       const auto component_type =
61           _.FindDef(result_type->GetOperandAs<uint32_t>(1));
62       if (!component_type) {
63         return _.diag(SPV_ERROR_INVALID_ID, result_type)
64                << "Component type is not defined.";
65       }
66       for (size_t constituent_index = 2;
67            constituent_index < inst->operands().size(); constituent_index++) {
68         const auto constituent_id =
69             inst->GetOperandAs<uint32_t>(constituent_index);
70         const auto constituent = _.FindDef(constituent_id);
71         if (!constituent ||
72             !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
73           return _.diag(SPV_ERROR_INVALID_ID, inst)
74                  << opcode_name << " Constituent <id> '"
75                  << _.getIdName(constituent_id)
76                  << "' is not a constant or undef.";
77         }
78         const auto constituent_result_type = _.FindDef(constituent->type_id());
79         if (!constituent_result_type ||
80             component_type->opcode() != constituent_result_type->opcode()) {
81           return _.diag(SPV_ERROR_INVALID_ID, inst)
82                  << opcode_name << " Constituent <id> '"
83                  << _.getIdName(constituent_id)
84                  << "'s type does not match Result Type <id> '"
85                  << _.getIdName(result_type->id()) << "'s vector element type.";
86         }
87       }
88     } break;
89     case SpvOpTypeMatrix: {
90       const auto column_count = result_type->GetOperandAs<uint32_t>(2);
91       if (column_count != constituent_count) {
92         // TODO: Output ID's on diagnostic
93         return _.diag(SPV_ERROR_INVALID_ID, inst)
94                << opcode_name
95                << " Constituent <id> count does not match "
96                   "Result Type <id> '"
97                << _.getIdName(result_type->id()) << "'s matrix column count.";
98       }
99 
100       const auto column_type = _.FindDef(result_type->words()[2]);
101       if (!column_type) {
102         return _.diag(SPV_ERROR_INVALID_ID, result_type)
103                << "Column type is not defined.";
104       }
105       const auto component_count = column_type->GetOperandAs<uint32_t>(2);
106       const auto component_type =
107           _.FindDef(column_type->GetOperandAs<uint32_t>(1));
108       if (!component_type) {
109         return _.diag(SPV_ERROR_INVALID_ID, column_type)
110                << "Component type is not defined.";
111       }
112 
113       for (size_t constituent_index = 2;
114            constituent_index < inst->operands().size(); constituent_index++) {
115         const auto constituent_id =
116             inst->GetOperandAs<uint32_t>(constituent_index);
117         const auto constituent = _.FindDef(constituent_id);
118         if (!constituent ||
119             !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
120           // The message says "... or undef" because the spec does not say
121           // undef is a constant.
122           return _.diag(SPV_ERROR_INVALID_ID, inst)
123                  << opcode_name << " Constituent <id> '"
124                  << _.getIdName(constituent_id)
125                  << "' is not a constant or undef.";
126         }
127         const auto vector = _.FindDef(constituent->type_id());
128         if (!vector) {
129           return _.diag(SPV_ERROR_INVALID_ID, constituent)
130                  << "Result type is not defined.";
131         }
132         if (column_type->opcode() != vector->opcode()) {
133           return _.diag(SPV_ERROR_INVALID_ID, inst)
134                  << opcode_name << " Constituent <id> '"
135                  << _.getIdName(constituent_id)
136                  << "' type does not match Result Type <id> '"
137                  << _.getIdName(result_type->id()) << "'s matrix column type.";
138         }
139         const auto vector_component_type =
140             _.FindDef(vector->GetOperandAs<uint32_t>(1));
141         if (component_type->id() != vector_component_type->id()) {
142           return _.diag(SPV_ERROR_INVALID_ID, inst)
143                  << opcode_name << " Constituent <id> '"
144                  << _.getIdName(constituent_id)
145                  << "' component type does not match Result Type <id> '"
146                  << _.getIdName(result_type->id())
147                  << "'s matrix column component type.";
148         }
149         if (component_count != vector->words()[3]) {
150           return _.diag(SPV_ERROR_INVALID_ID, inst)
151                  << opcode_name << " Constituent <id> '"
152                  << _.getIdName(constituent_id)
153                  << "' vector component count does not match Result Type <id> '"
154                  << _.getIdName(result_type->id())
155                  << "'s vector component count.";
156         }
157       }
158     } break;
159     case SpvOpTypeArray: {
160       auto element_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1));
161       if (!element_type) {
162         return _.diag(SPV_ERROR_INVALID_ID, result_type)
163                << "Element type is not defined.";
164       }
165       const auto length = _.FindDef(result_type->GetOperandAs<uint32_t>(2));
166       if (!length) {
167         return _.diag(SPV_ERROR_INVALID_ID, result_type)
168                << "Length is not defined.";
169       }
170       bool is_int32;
171       bool is_const;
172       uint32_t value;
173       std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id());
174       if (is_int32 && is_const && value != constituent_count) {
175         return _.diag(SPV_ERROR_INVALID_ID, inst)
176                << opcode_name
177                << " Constituent count does not match "
178                   "Result Type <id> '"
179                << _.getIdName(result_type->id()) << "'s array length.";
180       }
181       for (size_t constituent_index = 2;
182            constituent_index < inst->operands().size(); constituent_index++) {
183         const auto constituent_id =
184             inst->GetOperandAs<uint32_t>(constituent_index);
185         const auto constituent = _.FindDef(constituent_id);
186         if (!constituent ||
187             !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
188           return _.diag(SPV_ERROR_INVALID_ID, inst)
189                  << opcode_name << " Constituent <id> '"
190                  << _.getIdName(constituent_id)
191                  << "' is not a constant or undef.";
192         }
193         const auto constituent_type = _.FindDef(constituent->type_id());
194         if (!constituent_type) {
195           return _.diag(SPV_ERROR_INVALID_ID, constituent)
196                  << "Result type is not defined.";
197         }
198         if (element_type->id() != constituent_type->id()) {
199           return _.diag(SPV_ERROR_INVALID_ID, inst)
200                  << opcode_name << " Constituent <id> '"
201                  << _.getIdName(constituent_id)
202                  << "'s type does not match Result Type <id> '"
203                  << _.getIdName(result_type->id()) << "'s array element type.";
204         }
205       }
206     } break;
207     case SpvOpTypeStruct: {
208       const auto member_count = result_type->words().size() - 2;
209       if (member_count != constituent_count) {
210         return _.diag(SPV_ERROR_INVALID_ID, inst)
211                << opcode_name << " Constituent <id> '"
212                << _.getIdName(inst->type_id())
213                << "' count does not match Result Type <id> '"
214                << _.getIdName(result_type->id()) << "'s struct member count.";
215       }
216       for (uint32_t constituent_index = 2, member_index = 1;
217            constituent_index < inst->operands().size();
218            constituent_index++, member_index++) {
219         const auto constituent_id =
220             inst->GetOperandAs<uint32_t>(constituent_index);
221         const auto constituent = _.FindDef(constituent_id);
222         if (!constituent ||
223             !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
224           return _.diag(SPV_ERROR_INVALID_ID, inst)
225                  << opcode_name << " Constituent <id> '"
226                  << _.getIdName(constituent_id)
227                  << "' is not a constant or undef.";
228         }
229         const auto constituent_type = _.FindDef(constituent->type_id());
230         if (!constituent_type) {
231           return _.diag(SPV_ERROR_INVALID_ID, constituent)
232                  << "Result type is not defined.";
233         }
234 
235         const auto member_type_id =
236             result_type->GetOperandAs<uint32_t>(member_index);
237         const auto member_type = _.FindDef(member_type_id);
238         if (!member_type || member_type->id() != constituent_type->id()) {
239           return _.diag(SPV_ERROR_INVALID_ID, inst)
240                  << opcode_name << " Constituent <id> '"
241                  << _.getIdName(constituent_id)
242                  << "' type does not match the Result Type <id> '"
243                  << _.getIdName(result_type->id()) << "'s member type.";
244         }
245       }
246     } break;
247     case SpvOpTypeCooperativeMatrixNV: {
248       if (1 != constituent_count) {
249         return _.diag(SPV_ERROR_INVALID_ID, inst)
250                << opcode_name << " Constituent <id> '"
251                << _.getIdName(inst->type_id()) << "' count must be one.";
252       }
253       const auto constituent_id = inst->GetOperandAs<uint32_t>(2);
254       const auto constituent = _.FindDef(constituent_id);
255       if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
256         return _.diag(SPV_ERROR_INVALID_ID, inst)
257                << opcode_name << " Constituent <id> '"
258                << _.getIdName(constituent_id)
259                << "' is not a constant or undef.";
260       }
261       const auto constituent_type = _.FindDef(constituent->type_id());
262       if (!constituent_type) {
263         return _.diag(SPV_ERROR_INVALID_ID, constituent)
264                << "Result type is not defined.";
265       }
266 
267       const auto component_type_id = result_type->GetOperandAs<uint32_t>(1);
268       const auto component_type = _.FindDef(component_type_id);
269       if (!component_type || component_type->id() != constituent_type->id()) {
270         return _.diag(SPV_ERROR_INVALID_ID, inst)
271                << opcode_name << " Constituent <id> '"
272                << _.getIdName(constituent_id)
273                << "' type does not match the Result Type <id> '"
274                << _.getIdName(result_type->id()) << "'s component type.";
275       }
276     } break;
277     default:
278       break;
279   }
280   return SPV_SUCCESS;
281 }
282 
ValidateConstantSampler(ValidationState_t & _,const Instruction * inst)283 spv_result_t ValidateConstantSampler(ValidationState_t& _,
284                                      const Instruction* inst) {
285   const auto result_type = _.FindDef(inst->type_id());
286   if (!result_type || result_type->opcode() != SpvOpTypeSampler) {
287     return _.diag(SPV_ERROR_INVALID_ID, result_type)
288            << "OpConstantSampler Result Type <id> '"
289            << _.getIdName(inst->type_id()) << "' is not a sampler type.";
290   }
291 
292   return SPV_SUCCESS;
293 }
294 
295 // True if instruction defines a type that can have a null value, as defined by
296 // the SPIR-V spec.  Tracks composite-type components through module to check
297 // nullability transitively.
IsTypeNullable(const std::vector<uint32_t> & instruction,const ValidationState_t & _)298 bool IsTypeNullable(const std::vector<uint32_t>& instruction,
299                     const ValidationState_t& _) {
300   uint16_t opcode;
301   uint16_t word_count;
302   spvOpcodeSplit(instruction[0], &word_count, &opcode);
303   switch (static_cast<SpvOp>(opcode)) {
304     case SpvOpTypeBool:
305     case SpvOpTypeInt:
306     case SpvOpTypeFloat:
307     case SpvOpTypeEvent:
308     case SpvOpTypeDeviceEvent:
309     case SpvOpTypeReserveId:
310     case SpvOpTypeQueue:
311       return true;
312     case SpvOpTypeArray:
313     case SpvOpTypeMatrix:
314     case SpvOpTypeCooperativeMatrixNV:
315     case SpvOpTypeVector: {
316       auto base_type = _.FindDef(instruction[2]);
317       return base_type && IsTypeNullable(base_type->words(), _);
318     }
319     case SpvOpTypeStruct: {
320       for (size_t elementIndex = 2; elementIndex < instruction.size();
321            ++elementIndex) {
322         auto element = _.FindDef(instruction[elementIndex]);
323         if (!element || !IsTypeNullable(element->words(), _)) return false;
324       }
325       return true;
326     }
327     case SpvOpTypePointer:
328       if (instruction[2] == SpvStorageClassPhysicalStorageBuffer) {
329         return false;
330       }
331       return true;
332     default:
333       return false;
334   }
335 }
336 
ValidateConstantNull(ValidationState_t & _,const Instruction * inst)337 spv_result_t ValidateConstantNull(ValidationState_t& _,
338                                   const Instruction* inst) {
339   const auto result_type = _.FindDef(inst->type_id());
340   if (!result_type || !IsTypeNullable(result_type->words(), _)) {
341     return _.diag(SPV_ERROR_INVALID_ID, inst)
342            << "OpConstantNull Result Type <id> '"
343            << _.getIdName(inst->type_id()) << "' cannot have a null value.";
344   }
345 
346   return SPV_SUCCESS;
347 }
348 
349 // Validates that OpSpecConstant specializes to either int or float type.
ValidateSpecConstant(ValidationState_t & _,const Instruction * inst)350 spv_result_t ValidateSpecConstant(ValidationState_t& _,
351                                   const Instruction* inst) {
352   // Operand 0 is the <id> of the type that we're specializing to.
353   auto type_id = inst->GetOperandAs<const uint32_t>(0);
354   auto type_instruction = _.FindDef(type_id);
355   auto type_opcode = type_instruction->opcode();
356   if (type_opcode != SpvOpTypeInt && type_opcode != SpvOpTypeFloat) {
357     return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant "
358                                                    "must be an integer or "
359                                                    "floating-point number.";
360   }
361   return SPV_SUCCESS;
362 }
363 
ValidateSpecConstantOp(ValidationState_t & _,const Instruction * inst)364 spv_result_t ValidateSpecConstantOp(ValidationState_t& _,
365                                     const Instruction* inst) {
366   const auto op = inst->GetOperandAs<SpvOp>(2);
367 
368   // The binary parser already ensures that the op is valid for *some*
369   // environment.  Here we check restrictions.
370   switch (op) {
371     case SpvOpQuantizeToF16:
372       if (!_.HasCapability(SpvCapabilityShader)) {
373         return _.diag(SPV_ERROR_INVALID_ID, inst)
374                << "Specialization constant operation " << spvOpcodeString(op)
375                << " requires Shader capability";
376       }
377       break;
378 
379     case SpvOpUConvert:
380       if (!_.features().uconvert_spec_constant_op &&
381           !_.HasCapability(SpvCapabilityKernel)) {
382         return _.diag(SPV_ERROR_INVALID_ID, inst)
383                << "Prior to SPIR-V 1.4, specialization constant operation "
384                   "UConvert requires Kernel capability or extension "
385                   "SPV_AMD_gpu_shader_int16";
386       }
387       break;
388 
389     case SpvOpConvertFToS:
390     case SpvOpConvertSToF:
391     case SpvOpConvertFToU:
392     case SpvOpConvertUToF:
393     case SpvOpConvertPtrToU:
394     case SpvOpConvertUToPtr:
395     case SpvOpGenericCastToPtr:
396     case SpvOpPtrCastToGeneric:
397     case SpvOpBitcast:
398     case SpvOpFNegate:
399     case SpvOpFAdd:
400     case SpvOpFSub:
401     case SpvOpFMul:
402     case SpvOpFDiv:
403     case SpvOpFRem:
404     case SpvOpFMod:
405     case SpvOpAccessChain:
406     case SpvOpInBoundsAccessChain:
407     case SpvOpPtrAccessChain:
408     case SpvOpInBoundsPtrAccessChain:
409       if (!_.HasCapability(SpvCapabilityKernel)) {
410         return _.diag(SPV_ERROR_INVALID_ID, inst)
411                << "Specialization constant operation " << spvOpcodeString(op)
412                << " requires Kernel capability";
413       }
414       break;
415 
416     default:
417       break;
418   }
419 
420   // TODO(dneto): Validate result type and arguments to the various operations.
421   return SPV_SUCCESS;
422 }
423 
424 }  // namespace
425 
ConstantPass(ValidationState_t & _,const Instruction * inst)426 spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) {
427   switch (inst->opcode()) {
428     case SpvOpConstantTrue:
429     case SpvOpConstantFalse:
430     case SpvOpSpecConstantTrue:
431     case SpvOpSpecConstantFalse:
432       if (auto error = ValidateConstantBool(_, inst)) return error;
433       break;
434     case SpvOpConstantComposite:
435     case SpvOpSpecConstantComposite:
436       if (auto error = ValidateConstantComposite(_, inst)) return error;
437       break;
438     case SpvOpConstantSampler:
439       if (auto error = ValidateConstantSampler(_, inst)) return error;
440       break;
441     case SpvOpConstantNull:
442       if (auto error = ValidateConstantNull(_, inst)) return error;
443       break;
444     case SpvOpSpecConstant:
445       if (auto error = ValidateSpecConstant(_, inst)) return error;
446       break;
447     case SpvOpSpecConstantOp:
448       if (auto error = ValidateSpecConstantOp(_, inst)) return error;
449       break;
450     default:
451       break;
452   }
453 
454   // Generally disallow creating 8- or 16-bit constants unless the full
455   // capabilities are present.
456   if (spvOpcodeIsConstant(inst->opcode()) &&
457       _.HasCapability(SpvCapabilityShader) &&
458       !_.IsPointerType(inst->type_id()) &&
459       _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
460     return _.diag(SPV_ERROR_INVALID_ID, inst)
461            << "Cannot form constants of 8- or 16-bit types";
462   }
463 
464   return SPV_SUCCESS;
465 }
466 
467 }  // namespace val
468 }  // namespace spvtools
469