1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <algorithm>
16 #include <vector>
17 
18 #include "source/diagnostic.h"
19 #include "source/spirv_constant.h"
20 #include "source/spirv_target_env.h"
21 #include "source/val/function.h"
22 #include "source/val/instruction.h"
23 #include "source/val/validate.h"
24 #include "source/val/validation_state.h"
25 
26 namespace spvtools {
27 namespace val {
28 namespace {
29 
30 // Returns true if \c inst is an input or output variable.
is_interface_variable(const Instruction * inst,bool is_spv_1_4)31 bool is_interface_variable(const Instruction* inst, bool is_spv_1_4) {
32   if (is_spv_1_4) {
33     // Starting in SPIR-V 1.4, all global variables are interface variables.
34     return inst->opcode() == SpvOpVariable &&
35            inst->word(3u) != SpvStorageClassFunction;
36   } else {
37     return inst->opcode() == SpvOpVariable &&
38            (inst->word(3u) == SpvStorageClassInput ||
39             inst->word(3u) == SpvStorageClassOutput);
40   }
41 }
42 
43 // Checks that \c var is listed as an interface in all the entry points that use
44 // it.
check_interface_variable(ValidationState_t & _,const Instruction * var)45 spv_result_t check_interface_variable(ValidationState_t& _,
46                                       const Instruction* var) {
47   std::vector<const Function*> functions;
48   std::vector<const Instruction*> uses;
49   for (auto use : var->uses()) {
50     uses.push_back(use.first);
51   }
52   for (uint32_t i = 0; i < uses.size(); ++i) {
53     const auto user = uses[i];
54     if (const Function* func = user->function()) {
55       functions.push_back(func);
56     } else {
57       // In the rare case that the variable is used by another instruction in
58       // the global scope, continue searching for an instruction used in a
59       // function.
60       for (auto use : user->uses()) {
61         uses.push_back(use.first);
62       }
63     }
64   }
65 
66   std::sort(functions.begin(), functions.end(),
67             [](const Function* lhs, const Function* rhs) {
68               return lhs->id() < rhs->id();
69             });
70   functions.erase(std::unique(functions.begin(), functions.end()),
71                   functions.end());
72 
73   std::vector<uint32_t> entry_points;
74   for (const auto func : functions) {
75     for (auto id : _.FunctionEntryPoints(func->id())) {
76       entry_points.push_back(id);
77     }
78   }
79 
80   std::sort(entry_points.begin(), entry_points.end());
81   entry_points.erase(std::unique(entry_points.begin(), entry_points.end()),
82                      entry_points.end());
83 
84   for (auto id : entry_points) {
85     for (const auto& desc : _.entry_point_descriptions(id)) {
86       bool found = false;
87       for (auto interface : desc.interfaces) {
88         if (var->id() == interface) {
89           found = true;
90           break;
91         }
92       }
93       if (!found) {
94         return _.diag(SPV_ERROR_INVALID_ID, var)
95                << "Interface variable id <" << var->id()
96                << "> is used by entry point '" << desc.name << "' id <" << id
97                << ">, but is not listed as an interface";
98       }
99     }
100   }
101 
102   return SPV_SUCCESS;
103 }
104 
105 // This function assumes a base location has been determined already. As such
106 // any further location decorations are invalid.
107 // TODO: if this code turns out to be slow, there is an opportunity to cache
108 // the result for a given type id.
NumConsumedLocations(ValidationState_t & _,const Instruction * type,uint32_t * num_locations)109 spv_result_t NumConsumedLocations(ValidationState_t& _, const Instruction* type,
110                                   uint32_t* num_locations) {
111   *num_locations = 0;
112   switch (type->opcode()) {
113     case SpvOpTypeInt:
114     case SpvOpTypeFloat:
115       // Scalars always consume a single location.
116       *num_locations = 1;
117       break;
118     case SpvOpTypeVector:
119       // 3- and 4-component 64-bit vectors consume two locations.
120       if ((_.ContainsSizedIntOrFloatType(type->id(), SpvOpTypeInt, 64) ||
121            _.ContainsSizedIntOrFloatType(type->id(), SpvOpTypeFloat, 64)) &&
122           (type->GetOperandAs<uint32_t>(2) > 2)) {
123         *num_locations = 2;
124       } else {
125         *num_locations = 1;
126       }
127       break;
128     case SpvOpTypeMatrix:
129       // Matrices consume locations equal to the underlying vector type for
130       // each column.
131       NumConsumedLocations(_, _.FindDef(type->GetOperandAs<uint32_t>(1)),
132                            num_locations);
133       *num_locations *= type->GetOperandAs<uint32_t>(2);
134       break;
135     case SpvOpTypeArray: {
136       // Arrays consume locations equal to the underlying type times the number
137       // of elements in the vector.
138       NumConsumedLocations(_, _.FindDef(type->GetOperandAs<uint32_t>(1)),
139                            num_locations);
140       bool is_int = false;
141       bool is_const = false;
142       uint32_t value = 0;
143       // Attempt to evaluate the number of array elements.
144       std::tie(is_int, is_const, value) =
145           _.EvalInt32IfConst(type->GetOperandAs<uint32_t>(2));
146       if (is_int && is_const) *num_locations *= value;
147       break;
148     }
149     case SpvOpTypeStruct: {
150       // Members cannot have location decorations at this point.
151       if (_.HasDecoration(type->id(), SpvDecorationLocation)) {
152         return _.diag(SPV_ERROR_INVALID_DATA, type)
153                << "Members cannot be assigned a location";
154       }
155 
156       // Structs consume locations equal to the sum of the locations consumed
157       // by the members.
158       for (uint32_t i = 1; i < type->operands().size(); ++i) {
159         uint32_t member_locations = 0;
160         if (auto error = NumConsumedLocations(
161                 _, _.FindDef(type->GetOperandAs<uint32_t>(i)),
162                 &member_locations)) {
163           return error;
164         }
165         *num_locations += member_locations;
166       }
167       break;
168     }
169     default:
170       break;
171   }
172 
173   return SPV_SUCCESS;
174 }
175 
176 // Returns the number of components consumed by types that support a component
177 // decoration.
NumConsumedComponents(ValidationState_t & _,const Instruction * type)178 uint32_t NumConsumedComponents(ValidationState_t& _, const Instruction* type) {
179   uint32_t num_components = 0;
180   switch (type->opcode()) {
181     case SpvOpTypeInt:
182     case SpvOpTypeFloat:
183       // 64-bit types consume two components.
184       if (type->GetOperandAs<uint32_t>(1) == 64) {
185         num_components = 2;
186       } else {
187         num_components = 1;
188       }
189       break;
190     case SpvOpTypeVector:
191       // Vectors consume components equal to the underlying type's consumption
192       // times the number of elements in the vector. Note that 3- and 4-element
193       // vectors cannot have a component decoration (i.e. assumed to be zero).
194       num_components =
195           NumConsumedComponents(_, _.FindDef(type->GetOperandAs<uint32_t>(1)));
196       num_components *= type->GetOperandAs<uint32_t>(2);
197       break;
198     default:
199       // This is an error that is validated elsewhere.
200       break;
201   }
202 
203   return num_components;
204 }
205 
206 // Populates |locations| (and/or |output_index1_locations|) with the use
207 // location and component coordinates for |variable|. Indices are calculated as
208 // 4 * location + component.
GetLocationsForVariable(ValidationState_t & _,const Instruction * entry_point,const Instruction * variable,std::unordered_set<uint32_t> * locations,std::unordered_set<uint32_t> * output_index1_locations)209 spv_result_t GetLocationsForVariable(
210     ValidationState_t& _, const Instruction* entry_point,
211     const Instruction* variable, std::unordered_set<uint32_t>* locations,
212     std::unordered_set<uint32_t>* output_index1_locations) {
213   const bool is_fragment = entry_point->GetOperandAs<SpvExecutionModel>(0) ==
214                            SpvExecutionModelFragment;
215   const bool is_output =
216       variable->GetOperandAs<SpvStorageClass>(2) == SpvStorageClassOutput;
217   auto ptr_type_id = variable->GetOperandAs<uint32_t>(0);
218   auto ptr_type = _.FindDef(ptr_type_id);
219   auto type_id = ptr_type->GetOperandAs<uint32_t>(2);
220   auto type = _.FindDef(type_id);
221 
222   // Check for Location, Component and Index decorations on the variable. The
223   // validator allows duplicate decorations if the location/component/index are
224   // equal. Also track Patch and PerTaskNV decorations.
225   bool has_location = false;
226   uint32_t location = 0;
227   bool has_component = false;
228   uint32_t component = 0;
229   bool has_index = false;
230   uint32_t index = 0;
231   bool has_patch = false;
232   bool has_per_task_nv = false;
233   bool has_per_vertex_nv = false;
234   for (auto& dec : _.id_decorations(variable->id())) {
235     if (dec.dec_type() == SpvDecorationLocation) {
236       if (has_location && dec.params()[0] != location) {
237         return _.diag(SPV_ERROR_INVALID_DATA, variable)
238                << "Variable has conflicting location decorations";
239       }
240       has_location = true;
241       location = dec.params()[0];
242     } else if (dec.dec_type() == SpvDecorationComponent) {
243       if (has_component && dec.params()[0] != component) {
244         return _.diag(SPV_ERROR_INVALID_DATA, variable)
245                << "Variable has conflicting component decorations";
246       }
247       has_component = true;
248       component = dec.params()[0];
249     } else if (dec.dec_type() == SpvDecorationIndex) {
250       if (!is_output || !is_fragment) {
251         return _.diag(SPV_ERROR_INVALID_DATA, variable)
252                << "Index can only be applied to Fragment output variables";
253       }
254       if (has_index && dec.params()[0] != index) {
255         return _.diag(SPV_ERROR_INVALID_DATA, variable)
256                << "Variable has conflicting index decorations";
257       }
258       has_index = true;
259       index = dec.params()[0];
260     } else if (dec.dec_type() == SpvDecorationBuiltIn) {
261       // Don't check built-ins.
262       return SPV_SUCCESS;
263     } else if (dec.dec_type() == SpvDecorationPatch) {
264       has_patch = true;
265     } else if (dec.dec_type() == SpvDecorationPerTaskNV) {
266       has_per_task_nv = true;
267     } else if (dec.dec_type() == SpvDecorationPerVertexNV) {
268       has_per_vertex_nv = true;
269     }
270   }
271 
272   // Vulkan 14.1.3: Tessellation control and mesh per-vertex outputs and
273   // tessellation control, evaluation and geometry per-vertex inputs have a
274   // layer of arraying that is not included in interface matching.
275   bool is_arrayed = false;
276   switch (entry_point->GetOperandAs<SpvExecutionModel>(0)) {
277     case SpvExecutionModelTessellationControl:
278       if (!has_patch) {
279         is_arrayed = true;
280       }
281       break;
282     case SpvExecutionModelTessellationEvaluation:
283       if (!is_output && !has_patch) {
284         is_arrayed = true;
285       }
286       break;
287     case SpvExecutionModelGeometry:
288       if (!is_output) {
289         is_arrayed = true;
290       }
291       break;
292     case SpvExecutionModelFragment:
293       if (!is_output && has_per_vertex_nv) {
294         is_arrayed = true;
295       }
296       break;
297     case SpvExecutionModelMeshNV:
298       if (is_output && !has_per_task_nv) {
299         is_arrayed = true;
300       }
301       break;
302     default:
303       break;
304   }
305 
306   // Unpack arrayness.
307   if (is_arrayed && (type->opcode() == SpvOpTypeArray ||
308                      type->opcode() == SpvOpTypeRuntimeArray)) {
309     type_id = type->GetOperandAs<uint32_t>(1);
310     type = _.FindDef(type_id);
311   }
312 
313   if (type->opcode() == SpvOpTypeStruct) {
314     // Don't check built-ins.
315     if (_.HasDecoration(type_id, SpvDecorationBuiltIn)) return SPV_SUCCESS;
316   }
317 
318   // Only block-decorated structs don't need a location on the variable.
319   const bool is_block = _.HasDecoration(type_id, SpvDecorationBlock);
320   if (!has_location && !is_block) {
321     return _.diag(SPV_ERROR_INVALID_DATA, variable)
322            << "Variable must be decorated with a location";
323   }
324 
325   const std::string storage_class = is_output ? "output" : "input";
326   if (has_location) {
327     auto sub_type = type;
328     bool is_int = false;
329     bool is_const = false;
330     uint32_t array_size = 1;
331     // If the variable is still arrayed, mark the locations/components per
332     // index.
333     if (type->opcode() == SpvOpTypeArray) {
334       // Determine the array size if possible and get the element type.
335       std::tie(is_int, is_const, array_size) =
336           _.EvalInt32IfConst(type->GetOperandAs<uint32_t>(2));
337       if (!is_int || !is_const) array_size = 1;
338       auto sub_type_id = type->GetOperandAs<uint32_t>(1);
339       sub_type = _.FindDef(sub_type_id);
340     }
341 
342     for (uint32_t array_idx = 0; array_idx < array_size; ++array_idx) {
343       uint32_t num_locations = 0;
344       if (auto error = NumConsumedLocations(_, sub_type, &num_locations))
345         return error;
346 
347       uint32_t num_components = NumConsumedComponents(_, sub_type);
348       uint32_t array_location = location + (num_locations * array_idx);
349       uint32_t start = array_location * 4;
350       uint32_t end = (array_location + num_locations) * 4;
351       if (num_components != 0) {
352         start += component;
353         end = array_location * 4 + component + num_components;
354       }
355 
356       auto locs = locations;
357       if (has_index && index == 1) locs = output_index1_locations;
358 
359       for (uint32_t i = start; i < end; ++i) {
360         if (!locs->insert(i).second) {
361           return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
362                  << "Entry-point has conflicting " << storage_class
363                  << " location assignment at location " << i / 4
364                  << ", component " << i % 4;
365         }
366       }
367     }
368   } else {
369     // For Block-decorated structs with no location assigned to the variable,
370     // each member of the block must be assigned a location. Also record any
371     // member component assignments. The validator allows duplicate decorations
372     // if they agree on the location/component.
373     std::unordered_map<uint32_t, uint32_t> member_locations;
374     std::unordered_map<uint32_t, uint32_t> member_components;
375     for (auto& dec : _.id_decorations(type_id)) {
376       if (dec.dec_type() == SpvDecorationLocation) {
377         auto where = member_locations.find(dec.struct_member_index());
378         if (where == member_locations.end()) {
379           member_locations[dec.struct_member_index()] = dec.params()[0];
380         } else if (where->second != dec.params()[0]) {
381           return _.diag(SPV_ERROR_INVALID_DATA, type)
382                  << "Member index " << dec.struct_member_index()
383                  << " has conflicting location assignments";
384         }
385       } else if (dec.dec_type() == SpvDecorationComponent) {
386         auto where = member_components.find(dec.struct_member_index());
387         if (where == member_components.end()) {
388           member_components[dec.struct_member_index()] = dec.params()[0];
389         } else if (where->second != dec.params()[0]) {
390           return _.diag(SPV_ERROR_INVALID_DATA, type)
391                  << "Member index " << dec.struct_member_index()
392                  << " has conflicting component assignments";
393         }
394       }
395     }
396 
397     for (uint32_t i = 1; i < type->operands().size(); ++i) {
398       auto where = member_locations.find(i - 1);
399       if (where == member_locations.end()) {
400         return _.diag(SPV_ERROR_INVALID_DATA, type)
401                << "Member index " << i - 1
402                << " is missing a location assignment";
403       }
404 
405       location = where->second;
406       auto member = _.FindDef(type->GetOperandAs<uint32_t>(i));
407       uint32_t num_locations = 0;
408       if (auto error = NumConsumedLocations(_, member, &num_locations))
409         return error;
410 
411       // If the component is not specified, it is assumed to be zero.
412       uint32_t num_components = NumConsumedComponents(_, member);
413       component = 0;
414       if (member_components.count(i - 1)) {
415         component = member_components[i - 1];
416       }
417 
418       uint32_t start = location * 4;
419       uint32_t end = (location + num_locations) * 4;
420       if (num_components != 0) {
421         start += component;
422         end = location * 4 + component + num_components;
423       }
424       for (uint32_t l = start; l < end; ++l) {
425         if (!locations->insert(l).second) {
426           return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
427                  << "Entry-point has conflicting " << storage_class
428                  << " location assignment at location " << l / 4
429                  << ", component " << l % 4;
430         }
431       }
432     }
433   }
434 
435   return SPV_SUCCESS;
436 }
437 
ValidateLocations(ValidationState_t & _,const Instruction * entry_point)438 spv_result_t ValidateLocations(ValidationState_t& _,
439                                const Instruction* entry_point) {
440   // According to Vulkan 14.1 only the following execution models have
441   // locations assigned.
442   switch (entry_point->GetOperandAs<SpvExecutionModel>(0)) {
443     case SpvExecutionModelVertex:
444     case SpvExecutionModelTessellationControl:
445     case SpvExecutionModelTessellationEvaluation:
446     case SpvExecutionModelGeometry:
447     case SpvExecutionModelFragment:
448       break;
449     default:
450       return SPV_SUCCESS;
451   }
452 
453   // Locations are stored as a combined location and component values.
454   std::unordered_set<uint32_t> input_locations;
455   std::unordered_set<uint32_t> output_locations_index0;
456   std::unordered_set<uint32_t> output_locations_index1;
457   for (uint32_t i = 3; i < entry_point->operands().size(); ++i) {
458     auto interface_id = entry_point->GetOperandAs<uint32_t>(i);
459     auto interface_var = _.FindDef(interface_id);
460     auto storage_class = interface_var->GetOperandAs<SpvStorageClass>(2);
461     if (storage_class != SpvStorageClassInput &&
462         storage_class != SpvStorageClassOutput) {
463       continue;
464     }
465 
466     auto locations = (storage_class == SpvStorageClassInput)
467                          ? &input_locations
468                          : &output_locations_index0;
469     if (auto error = GetLocationsForVariable(
470             _, entry_point, interface_var, locations, &output_locations_index1))
471       return error;
472   }
473 
474   return SPV_SUCCESS;
475 }
476 
477 }  // namespace
478 
ValidateInterfaces(ValidationState_t & _)479 spv_result_t ValidateInterfaces(ValidationState_t& _) {
480   bool is_spv_1_4 = _.version() >= SPV_SPIRV_VERSION_WORD(1, 4);
481   for (auto& inst : _.ordered_instructions()) {
482     if (is_interface_variable(&inst, is_spv_1_4)) {
483       if (auto error = check_interface_variable(_, &inst)) {
484         return error;
485       }
486     }
487   }
488 
489   if (spvIsVulkanEnv(_.context()->target_env)) {
490     for (auto& inst : _.ordered_instructions()) {
491       if (inst.opcode() == SpvOpEntryPoint) {
492         if (auto error = ValidateLocations(_, &inst)) {
493           return error;
494         }
495       }
496       if (inst.opcode() == SpvOpTypeVoid) break;
497     }
498   }
499 
500   return SPV_SUCCESS;
501 }
502 
503 }  // namespace val
504 }  // namespace spvtools
505