1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <algorithm>
16 #include <vector>
17 
18 #include "source/diagnostic.h"
19 #include "source/spirv_constant.h"
20 #include "source/spirv_target_env.h"
21 #include "source/val/function.h"
22 #include "source/val/instruction.h"
23 #include "source/val/validate.h"
24 #include "source/val/validation_state.h"
25 
26 namespace spvtools {
27 namespace val {
28 namespace {
29 
// Upper bound on the number of location/component slots tracked when looking
// for overlapping assignments: 4096 locations, each with 4 components. This is
// set well beyond practical use cases; slots at or past it are not checked.
constexpr uint32_t kMaxLocations = 4096 * 4;
33 
34 // Returns true if \c inst is an input or output variable.
is_interface_variable(const Instruction * inst,bool is_spv_1_4)35 bool is_interface_variable(const Instruction* inst, bool is_spv_1_4) {
36   if (is_spv_1_4) {
37     // Starting in SPIR-V 1.4, all global variables are interface variables.
38     return inst->opcode() == SpvOpVariable &&
39            inst->word(3u) != SpvStorageClassFunction;
40   } else {
41     return inst->opcode() == SpvOpVariable &&
42            (inst->word(3u) == SpvStorageClassInput ||
43             inst->word(3u) == SpvStorageClassOutput);
44   }
45 }
46 
47 // Checks that \c var is listed as an interface in all the entry points that use
48 // it.
check_interface_variable(ValidationState_t & _,const Instruction * var)49 spv_result_t check_interface_variable(ValidationState_t& _,
50                                       const Instruction* var) {
51   std::vector<const Function*> functions;
52   std::vector<const Instruction*> uses;
53   for (auto use : var->uses()) {
54     uses.push_back(use.first);
55   }
56   for (uint32_t i = 0; i < uses.size(); ++i) {
57     const auto user = uses[i];
58     if (const Function* func = user->function()) {
59       functions.push_back(func);
60     } else {
61       // In the rare case that the variable is used by another instruction in
62       // the global scope, continue searching for an instruction used in a
63       // function.
64       for (auto use : user->uses()) {
65         uses.push_back(use.first);
66       }
67     }
68   }
69 
70   std::sort(functions.begin(), functions.end(),
71             [](const Function* lhs, const Function* rhs) {
72               return lhs->id() < rhs->id();
73             });
74   functions.erase(std::unique(functions.begin(), functions.end()),
75                   functions.end());
76 
77   std::vector<uint32_t> entry_points;
78   for (const auto func : functions) {
79     for (auto id : _.FunctionEntryPoints(func->id())) {
80       entry_points.push_back(id);
81     }
82   }
83 
84   std::sort(entry_points.begin(), entry_points.end());
85   entry_points.erase(std::unique(entry_points.begin(), entry_points.end()),
86                      entry_points.end());
87 
88   for (auto id : entry_points) {
89     for (const auto& desc : _.entry_point_descriptions(id)) {
90       bool found = false;
91       for (auto interface : desc.interfaces) {
92         if (var->id() == interface) {
93           found = true;
94           break;
95         }
96       }
97       if (!found) {
98         return _.diag(SPV_ERROR_INVALID_ID, var)
99                << "Interface variable id <" << var->id()
100                << "> is used by entry point '" << desc.name << "' id <" << id
101                << ">, but is not listed as an interface";
102       }
103     }
104   }
105 
106   return SPV_SUCCESS;
107 }
108 
109 // This function assumes a base location has been determined already. As such
110 // any further location decorations are invalid.
111 // TODO: if this code turns out to be slow, there is an opportunity to cache
112 // the result for a given type id.
NumConsumedLocations(ValidationState_t & _,const Instruction * type,uint32_t * num_locations)113 spv_result_t NumConsumedLocations(ValidationState_t& _, const Instruction* type,
114                                   uint32_t* num_locations) {
115   *num_locations = 0;
116   switch (type->opcode()) {
117     case SpvOpTypeInt:
118     case SpvOpTypeFloat:
119       // Scalars always consume a single location.
120       *num_locations = 1;
121       break;
122     case SpvOpTypeVector:
123       // 3- and 4-component 64-bit vectors consume two locations.
124       if ((_.ContainsSizedIntOrFloatType(type->id(), SpvOpTypeInt, 64) ||
125            _.ContainsSizedIntOrFloatType(type->id(), SpvOpTypeFloat, 64)) &&
126           (type->GetOperandAs<uint32_t>(2) > 2)) {
127         *num_locations = 2;
128       } else {
129         *num_locations = 1;
130       }
131       break;
132     case SpvOpTypeMatrix:
133       // Matrices consume locations equal to the underlying vector type for
134       // each column.
135       NumConsumedLocations(_, _.FindDef(type->GetOperandAs<uint32_t>(1)),
136                            num_locations);
137       *num_locations *= type->GetOperandAs<uint32_t>(2);
138       break;
139     case SpvOpTypeArray: {
140       // Arrays consume locations equal to the underlying type times the number
141       // of elements in the vector.
142       NumConsumedLocations(_, _.FindDef(type->GetOperandAs<uint32_t>(1)),
143                            num_locations);
144       bool is_int = false;
145       bool is_const = false;
146       uint32_t value = 0;
147       // Attempt to evaluate the number of array elements.
148       std::tie(is_int, is_const, value) =
149           _.EvalInt32IfConst(type->GetOperandAs<uint32_t>(2));
150       if (is_int && is_const) *num_locations *= value;
151       break;
152     }
153     case SpvOpTypeStruct: {
154       // Members cannot have location decorations at this point.
155       if (_.HasDecoration(type->id(), SpvDecorationLocation)) {
156         return _.diag(SPV_ERROR_INVALID_DATA, type)
157                << "Members cannot be assigned a location";
158       }
159 
160       // Structs consume locations equal to the sum of the locations consumed
161       // by the members.
162       for (uint32_t i = 1; i < type->operands().size(); ++i) {
163         uint32_t member_locations = 0;
164         if (auto error = NumConsumedLocations(
165                 _, _.FindDef(type->GetOperandAs<uint32_t>(i)),
166                 &member_locations)) {
167           return error;
168         }
169         *num_locations += member_locations;
170       }
171       break;
172     }
173     default:
174       break;
175   }
176 
177   return SPV_SUCCESS;
178 }
179 
180 // Returns the number of components consumed by types that support a component
181 // decoration.
NumConsumedComponents(ValidationState_t & _,const Instruction * type)182 uint32_t NumConsumedComponents(ValidationState_t& _, const Instruction* type) {
183   uint32_t num_components = 0;
184   switch (type->opcode()) {
185     case SpvOpTypeInt:
186     case SpvOpTypeFloat:
187       // 64-bit types consume two components.
188       if (type->GetOperandAs<uint32_t>(1) == 64) {
189         num_components = 2;
190       } else {
191         num_components = 1;
192       }
193       break;
194     case SpvOpTypeVector:
195       // Vectors consume components equal to the underlying type's consumption
196       // times the number of elements in the vector. Note that 3- and 4-element
197       // vectors cannot have a component decoration (i.e. assumed to be zero).
198       num_components =
199           NumConsumedComponents(_, _.FindDef(type->GetOperandAs<uint32_t>(1)));
200       num_components *= type->GetOperandAs<uint32_t>(2);
201       break;
202     case SpvOpTypeArray:
203       // Skip the array.
204       return NumConsumedComponents(_,
205                                    _.FindDef(type->GetOperandAs<uint32_t>(1)));
206     default:
207       // This is an error that is validated elsewhere.
208       break;
209   }
210 
211   return num_components;
212 }
213 
214 // Populates |locations| (and/or |output_index1_locations|) with the use
215 // location and component coordinates for |variable|. Indices are calculated as
216 // 4 * location + component.
GetLocationsForVariable(ValidationState_t & _,const Instruction * entry_point,const Instruction * variable,std::unordered_set<uint32_t> * locations,std::unordered_set<uint32_t> * output_index1_locations)217 spv_result_t GetLocationsForVariable(
218     ValidationState_t& _, const Instruction* entry_point,
219     const Instruction* variable, std::unordered_set<uint32_t>* locations,
220     std::unordered_set<uint32_t>* output_index1_locations) {
221   const bool is_fragment = entry_point->GetOperandAs<SpvExecutionModel>(0) ==
222                            SpvExecutionModelFragment;
223   const bool is_output =
224       variable->GetOperandAs<SpvStorageClass>(2) == SpvStorageClassOutput;
225   auto ptr_type_id = variable->GetOperandAs<uint32_t>(0);
226   auto ptr_type = _.FindDef(ptr_type_id);
227   auto type_id = ptr_type->GetOperandAs<uint32_t>(2);
228   auto type = _.FindDef(type_id);
229 
230   // Check for Location, Component and Index decorations on the variable. The
231   // validator allows duplicate decorations if the location/component/index are
232   // equal. Also track Patch and PerTaskNV decorations.
233   bool has_location = false;
234   uint32_t location = 0;
235   bool has_component = false;
236   uint32_t component = 0;
237   bool has_index = false;
238   uint32_t index = 0;
239   bool has_patch = false;
240   bool has_per_task_nv = false;
241   bool has_per_vertex_nv = false;
242   for (auto& dec : _.id_decorations(variable->id())) {
243     if (dec.dec_type() == SpvDecorationLocation) {
244       if (has_location && dec.params()[0] != location) {
245         return _.diag(SPV_ERROR_INVALID_DATA, variable)
246                << "Variable has conflicting location decorations";
247       }
248       has_location = true;
249       location = dec.params()[0];
250     } else if (dec.dec_type() == SpvDecorationComponent) {
251       if (has_component && dec.params()[0] != component) {
252         return _.diag(SPV_ERROR_INVALID_DATA, variable)
253                << "Variable has conflicting component decorations";
254       }
255       has_component = true;
256       component = dec.params()[0];
257     } else if (dec.dec_type() == SpvDecorationIndex) {
258       if (!is_output || !is_fragment) {
259         return _.diag(SPV_ERROR_INVALID_DATA, variable)
260                << "Index can only be applied to Fragment output variables";
261       }
262       if (has_index && dec.params()[0] != index) {
263         return _.diag(SPV_ERROR_INVALID_DATA, variable)
264                << "Variable has conflicting index decorations";
265       }
266       has_index = true;
267       index = dec.params()[0];
268     } else if (dec.dec_type() == SpvDecorationBuiltIn) {
269       // Don't check built-ins.
270       return SPV_SUCCESS;
271     } else if (dec.dec_type() == SpvDecorationPatch) {
272       has_patch = true;
273     } else if (dec.dec_type() == SpvDecorationPerTaskNV) {
274       has_per_task_nv = true;
275     } else if (dec.dec_type() == SpvDecorationPerVertexNV) {
276       has_per_vertex_nv = true;
277     }
278   }
279 
280   // Vulkan 14.1.3: Tessellation control and mesh per-vertex outputs and
281   // tessellation control, evaluation and geometry per-vertex inputs have a
282   // layer of arraying that is not included in interface matching.
283   bool is_arrayed = false;
284   switch (entry_point->GetOperandAs<SpvExecutionModel>(0)) {
285     case SpvExecutionModelTessellationControl:
286       if (!has_patch) {
287         is_arrayed = true;
288       }
289       break;
290     case SpvExecutionModelTessellationEvaluation:
291       if (!is_output && !has_patch) {
292         is_arrayed = true;
293       }
294       break;
295     case SpvExecutionModelGeometry:
296       if (!is_output) {
297         is_arrayed = true;
298       }
299       break;
300     case SpvExecutionModelFragment:
301       if (!is_output && has_per_vertex_nv) {
302         is_arrayed = true;
303       }
304       break;
305     case SpvExecutionModelMeshNV:
306       if (is_output && !has_per_task_nv) {
307         is_arrayed = true;
308       }
309       break;
310     default:
311       break;
312   }
313 
314   // Unpack arrayness.
315   if (is_arrayed && (type->opcode() == SpvOpTypeArray ||
316                      type->opcode() == SpvOpTypeRuntimeArray)) {
317     type_id = type->GetOperandAs<uint32_t>(1);
318     type = _.FindDef(type_id);
319   }
320 
321   if (type->opcode() == SpvOpTypeStruct) {
322     // Don't check built-ins.
323     if (_.HasDecoration(type_id, SpvDecorationBuiltIn)) return SPV_SUCCESS;
324   }
325 
326   // Only block-decorated structs don't need a location on the variable.
327   const bool is_block = _.HasDecoration(type_id, SpvDecorationBlock);
328   if (!has_location && !is_block) {
329     return _.diag(SPV_ERROR_INVALID_DATA, variable)
330            << "Variable must be decorated with a location";
331   }
332 
333   const std::string storage_class = is_output ? "output" : "input";
334   if (has_location) {
335     auto sub_type = type;
336     bool is_int = false;
337     bool is_const = false;
338     uint32_t array_size = 1;
339     // If the variable is still arrayed, mark the locations/components per
340     // index.
341     if (type->opcode() == SpvOpTypeArray) {
342       // Determine the array size if possible and get the element type.
343       std::tie(is_int, is_const, array_size) =
344           _.EvalInt32IfConst(type->GetOperandAs<uint32_t>(2));
345       if (!is_int || !is_const) array_size = 1;
346       auto sub_type_id = type->GetOperandAs<uint32_t>(1);
347       sub_type = _.FindDef(sub_type_id);
348     }
349 
350     for (uint32_t array_idx = 0; array_idx < array_size; ++array_idx) {
351       uint32_t num_locations = 0;
352       if (auto error = NumConsumedLocations(_, sub_type, &num_locations))
353         return error;
354 
355       uint32_t num_components = NumConsumedComponents(_, sub_type);
356       uint32_t array_location = location + (num_locations * array_idx);
357       uint32_t start = array_location * 4;
358       if (kMaxLocations <= start) {
359         // Too many locations, give up.
360         break;
361       }
362 
363       uint32_t end = (array_location + num_locations) * 4;
364       if (num_components != 0) {
365         start += component;
366         end = array_location * 4 + component + num_components;
367       }
368 
369       auto locs = locations;
370       if (has_index && index == 1) locs = output_index1_locations;
371 
372       for (uint32_t i = start; i < end; ++i) {
373         if (!locs->insert(i).second) {
374           return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
375                  << "Entry-point has conflicting " << storage_class
376                  << " location assignment at location " << i / 4
377                  << ", component " << i % 4;
378         }
379       }
380     }
381   } else {
382     // For Block-decorated structs with no location assigned to the variable,
383     // each member of the block must be assigned a location. Also record any
384     // member component assignments. The validator allows duplicate decorations
385     // if they agree on the location/component.
386     std::unordered_map<uint32_t, uint32_t> member_locations;
387     std::unordered_map<uint32_t, uint32_t> member_components;
388     for (auto& dec : _.id_decorations(type_id)) {
389       if (dec.dec_type() == SpvDecorationLocation) {
390         auto where = member_locations.find(dec.struct_member_index());
391         if (where == member_locations.end()) {
392           member_locations[dec.struct_member_index()] = dec.params()[0];
393         } else if (where->second != dec.params()[0]) {
394           return _.diag(SPV_ERROR_INVALID_DATA, type)
395                  << "Member index " << dec.struct_member_index()
396                  << " has conflicting location assignments";
397         }
398       } else if (dec.dec_type() == SpvDecorationComponent) {
399         auto where = member_components.find(dec.struct_member_index());
400         if (where == member_components.end()) {
401           member_components[dec.struct_member_index()] = dec.params()[0];
402         } else if (where->second != dec.params()[0]) {
403           return _.diag(SPV_ERROR_INVALID_DATA, type)
404                  << "Member index " << dec.struct_member_index()
405                  << " has conflicting component assignments";
406         }
407       }
408     }
409 
410     for (uint32_t i = 1; i < type->operands().size(); ++i) {
411       auto where = member_locations.find(i - 1);
412       if (where == member_locations.end()) {
413         return _.diag(SPV_ERROR_INVALID_DATA, type)
414                << "Member index " << i - 1
415                << " is missing a location assignment";
416       }
417 
418       location = where->second;
419       auto member = _.FindDef(type->GetOperandAs<uint32_t>(i));
420       uint32_t num_locations = 0;
421       if (auto error = NumConsumedLocations(_, member, &num_locations))
422         return error;
423 
424       // If the component is not specified, it is assumed to be zero.
425       uint32_t num_components = NumConsumedComponents(_, member);
426       component = 0;
427       if (member_components.count(i - 1)) {
428         component = member_components[i - 1];
429       }
430 
431       uint32_t start = location * 4;
432       if (kMaxLocations <= start) {
433         // Too many locations, give up.
434         continue;
435       }
436 
437       if (member->opcode() == SpvOpTypeArray && num_components >= 1 &&
438           num_components < 4) {
439         // When an array has an element that takes less than a location in
440         // size, calculate the used locations in a strided manner.
441         for (uint32_t l = location; l < num_locations + location; ++l) {
442           for (uint32_t c = component; c < component + num_components; ++c) {
443             uint32_t check = 4 * l + c;
444             if (!locations->insert(check).second) {
445               return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
446                      << "Entry-point has conflicting " << storage_class
447                      << " location assignment at location " << l
448                      << ", component " << c;
449             }
450           }
451         }
452       } else {
453         // TODO: There is a hole here is the member is an array of 3- or
454         // 4-element vectors of 64-bit types.
455         uint32_t end = (location + num_locations) * 4;
456         if (num_components != 0) {
457           start += component;
458           end = location * 4 + component + num_components;
459         }
460         for (uint32_t l = start; l < end; ++l) {
461           if (!locations->insert(l).second) {
462             return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
463                    << "Entry-point has conflicting " << storage_class
464                    << " location assignment at location " << l / 4
465                    << ", component " << l % 4;
466           }
467         }
468       }
469     }
470   }
471 
472   return SPV_SUCCESS;
473 }
474 
ValidateLocations(ValidationState_t & _,const Instruction * entry_point)475 spv_result_t ValidateLocations(ValidationState_t& _,
476                                const Instruction* entry_point) {
477   // According to Vulkan 14.1 only the following execution models have
478   // locations assigned.
479   switch (entry_point->GetOperandAs<SpvExecutionModel>(0)) {
480     case SpvExecutionModelVertex:
481     case SpvExecutionModelTessellationControl:
482     case SpvExecutionModelTessellationEvaluation:
483     case SpvExecutionModelGeometry:
484     case SpvExecutionModelFragment:
485       break;
486     default:
487       return SPV_SUCCESS;
488   }
489 
490   // Locations are stored as a combined location and component values.
491   std::unordered_set<uint32_t> input_locations;
492   std::unordered_set<uint32_t> output_locations_index0;
493   std::unordered_set<uint32_t> output_locations_index1;
494   std::unordered_set<uint32_t> seen;
495   for (uint32_t i = 3; i < entry_point->operands().size(); ++i) {
496     auto interface_id = entry_point->GetOperandAs<uint32_t>(i);
497     auto interface_var = _.FindDef(interface_id);
498     auto storage_class = interface_var->GetOperandAs<SpvStorageClass>(2);
499     if (storage_class != SpvStorageClassInput &&
500         storage_class != SpvStorageClassOutput) {
501       continue;
502     }
503     if (!seen.insert(interface_id).second) {
504       // Pre-1.4 an interface variable could be listed multiple times in an
505       // entry point. Validation for 1.4 or later is done elsewhere.
506       continue;
507     }
508 
509     auto locations = (storage_class == SpvStorageClassInput)
510                          ? &input_locations
511                          : &output_locations_index0;
512     if (auto error = GetLocationsForVariable(
513             _, entry_point, interface_var, locations, &output_locations_index1))
514       return error;
515   }
516 
517   return SPV_SUCCESS;
518 }
519 
520 }  // namespace
521 
ValidateInterfaces(ValidationState_t & _)522 spv_result_t ValidateInterfaces(ValidationState_t& _) {
523   bool is_spv_1_4 = _.version() >= SPV_SPIRV_VERSION_WORD(1, 4);
524   for (auto& inst : _.ordered_instructions()) {
525     if (is_interface_variable(&inst, is_spv_1_4)) {
526       if (auto error = check_interface_variable(_, &inst)) {
527         return error;
528       }
529     }
530   }
531 
532   if (spvIsVulkanEnv(_.context()->target_env)) {
533     for (auto& inst : _.ordered_instructions()) {
534       if (inst.opcode() == SpvOpEntryPoint) {
535         if (auto error = ValidateLocations(_, &inst)) {
536           return error;
537         }
538       }
539       if (inst.opcode() == SpvOpTypeVoid) break;
540     }
541   }
542 
543   return SPV_SUCCESS;
544 }
545 
546 }  // namespace val
547 }  // namespace spvtools
548