1 // Copyright (c) 2015-2016 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <algorithm>
16 #include <cassert>
17 #include <functional>
18 #include <iostream>
19 #include <iterator>
20 #include <map>
21 #include <string>
22 #include <tuple>
23 #include <unordered_map>
24 #include <unordered_set>
25 #include <utility>
26 #include <vector>
27 
28 #include "source/cfa.h"
29 #include "source/opcode.h"
30 #include "source/spirv_target_env.h"
31 #include "source/spirv_validator_options.h"
32 #include "source/val/basic_block.h"
33 #include "source/val/construct.h"
34 #include "source/val/function.h"
35 #include "source/val/validate.h"
36 #include "source/val/validation_state.h"
37 
38 namespace spvtools {
39 namespace val {
40 namespace {
41 
ValidatePhi(ValidationState_t & _,const Instruction * inst)42 spv_result_t ValidatePhi(ValidationState_t& _, const Instruction* inst) {
43   auto block = inst->block();
44   size_t num_in_ops = inst->words().size() - 3;
45   if (num_in_ops % 2 != 0) {
46     return _.diag(SPV_ERROR_INVALID_ID, inst)
47            << "OpPhi does not have an equal number of incoming values and "
48               "basic blocks.";
49   }
50 
51   const Instruction* type_inst = _.FindDef(inst->type_id());
52   assert(type_inst);
53 
54   const SpvOp type_opcode = type_inst->opcode();
55   if (type_opcode == SpvOpTypePointer &&
56       _.addressing_model() == SpvAddressingModelLogical) {
57     if (!_.features().variable_pointers &&
58         !_.features().variable_pointers_storage_buffer) {
59       return _.diag(SPV_ERROR_INVALID_DATA, inst)
60              << "Using pointers with OpPhi requires capability "
61              << "VariablePointers or VariablePointersStorageBuffer";
62     }
63   }
64 
65   if (!_.options()->before_hlsl_legalization) {
66     if (type_opcode == SpvOpTypeSampledImage ||
67         (_.HasCapability(SpvCapabilityShader) &&
68          (type_opcode == SpvOpTypeImage || type_opcode == SpvOpTypeSampler))) {
69       return _.diag(SPV_ERROR_INVALID_ID, inst)
70              << "Result type cannot be Op" << spvOpcodeString(type_opcode);
71     }
72   }
73 
74   // Create a uniqued vector of predecessor ids for comparison against
75   // incoming values. OpBranchConditional %cond %label %label produces two
76   // predecessors in the CFG.
77   std::vector<uint32_t> pred_ids;
78   std::transform(block->predecessors()->begin(), block->predecessors()->end(),
79                  std::back_inserter(pred_ids),
80                  [](const BasicBlock* b) { return b->id(); });
81   std::sort(pred_ids.begin(), pred_ids.end());
82   pred_ids.erase(std::unique(pred_ids.begin(), pred_ids.end()), pred_ids.end());
83 
84   size_t num_edges = num_in_ops / 2;
85   if (num_edges != pred_ids.size()) {
86     return _.diag(SPV_ERROR_INVALID_ID, inst)
87            << "OpPhi's number of incoming blocks (" << num_edges
88            << ") does not match block's predecessor count ("
89            << block->predecessors()->size() << ").";
90   }
91 
92   std::unordered_set<uint32_t> observed_predecessors;
93 
94   for (size_t i = 3; i < inst->words().size(); ++i) {
95     auto inc_id = inst->word(i);
96     if (i % 2 == 1) {
97       // Incoming value type must match the phi result type.
98       auto inc_type_id = _.GetTypeId(inc_id);
99       if (inst->type_id() != inc_type_id) {
100         return _.diag(SPV_ERROR_INVALID_ID, inst)
101                << "OpPhi's result type <id> " << _.getIdName(inst->type_id())
102                << " does not match incoming value <id> " << _.getIdName(inc_id)
103                << " type <id> " << _.getIdName(inc_type_id) << ".";
104       }
105     } else {
106       if (_.GetIdOpcode(inc_id) != SpvOpLabel) {
107         return _.diag(SPV_ERROR_INVALID_ID, inst)
108                << "OpPhi's incoming basic block <id> " << _.getIdName(inc_id)
109                << " is not an OpLabel.";
110       }
111 
112       // Incoming basic block must be an immediate predecessor of the phi's
113       // block.
114       if (!std::binary_search(pred_ids.begin(), pred_ids.end(), inc_id)) {
115         return _.diag(SPV_ERROR_INVALID_ID, inst)
116                << "OpPhi's incoming basic block <id> " << _.getIdName(inc_id)
117                << " is not a predecessor of <id> " << _.getIdName(block->id())
118                << ".";
119       }
120 
121       // We must not have already seen this predecessor as one of the phi's
122       // operands.
123       if (observed_predecessors.count(inc_id) != 0) {
124         return _.diag(SPV_ERROR_INVALID_ID, inst)
125                << "OpPhi references incoming basic block <id> "
126                << _.getIdName(inc_id) << " multiple times.";
127       }
128 
129       // Note the fact that we have now observed this predecessor.
130       observed_predecessors.insert(inc_id);
131     }
132   }
133 
134   return SPV_SUCCESS;
135 }
136 
ValidateBranch(ValidationState_t & _,const Instruction * inst)137 spv_result_t ValidateBranch(ValidationState_t& _, const Instruction* inst) {
138   // target operands must be OpLabel
139   const auto id = inst->GetOperandAs<uint32_t>(0);
140   const auto target = _.FindDef(id);
141   if (!target || SpvOpLabel != target->opcode()) {
142     return _.diag(SPV_ERROR_INVALID_ID, inst)
143            << "'Target Label' operands for OpBranch must be the ID "
144               "of an OpLabel instruction";
145   }
146 
147   return SPV_SUCCESS;
148 }
149 
ValidateBranchConditional(ValidationState_t & _,const Instruction * inst)150 spv_result_t ValidateBranchConditional(ValidationState_t& _,
151                                        const Instruction* inst) {
152   // num_operands is either 3 or 5 --- if 5, the last two need to be literal
153   // integers
154   const auto num_operands = inst->operands().size();
155   if (num_operands != 3 && num_operands != 5) {
156     return _.diag(SPV_ERROR_INVALID_ID, inst)
157            << "OpBranchConditional requires either 3 or 5 parameters";
158   }
159 
160   // grab the condition operand and check that it is a bool
161   const auto cond_id = inst->GetOperandAs<uint32_t>(0);
162   const auto cond_op = _.FindDef(cond_id);
163   if (!cond_op || !cond_op->type_id() ||
164       !_.IsBoolScalarType(cond_op->type_id())) {
165     return _.diag(SPV_ERROR_INVALID_ID, inst) << "Condition operand for "
166                                                  "OpBranchConditional must be "
167                                                  "of boolean type";
168   }
169 
170   // target operands must be OpLabel
171   // note that we don't need to check that the target labels are in the same
172   // function,
173   // PerformCfgChecks already checks for that
174   const auto true_id = inst->GetOperandAs<uint32_t>(1);
175   const auto true_target = _.FindDef(true_id);
176   if (!true_target || SpvOpLabel != true_target->opcode()) {
177     return _.diag(SPV_ERROR_INVALID_ID, inst)
178            << "The 'True Label' operand for OpBranchConditional must be the "
179               "ID of an OpLabel instruction";
180   }
181 
182   const auto false_id = inst->GetOperandAs<uint32_t>(2);
183   const auto false_target = _.FindDef(false_id);
184   if (!false_target || SpvOpLabel != false_target->opcode()) {
185     return _.diag(SPV_ERROR_INVALID_ID, inst)
186            << "The 'False Label' operand for OpBranchConditional must be the "
187               "ID of an OpLabel instruction";
188   }
189 
190   return SPV_SUCCESS;
191 }
192 
ValidateSwitch(ValidationState_t & _,const Instruction * inst)193 spv_result_t ValidateSwitch(ValidationState_t& _, const Instruction* inst) {
194   const auto num_operands = inst->operands().size();
195   // At least two operands (selector, default), any more than that are
196   // literal/target.
197 
198   // target operands must be OpLabel
199   for (size_t i = 2; i < num_operands; i += 2) {
200     // literal, id
201     const auto id = inst->GetOperandAs<uint32_t>(i + 1);
202     const auto target = _.FindDef(id);
203     if (!target || SpvOpLabel != target->opcode()) {
204       return _.diag(SPV_ERROR_INVALID_ID, inst)
205              << "'Target Label' operands for OpSwitch must be IDs of an "
206                 "OpLabel instruction";
207     }
208   }
209 
210   return SPV_SUCCESS;
211 }
212 
ValidateReturnValue(ValidationState_t & _,const Instruction * inst)213 spv_result_t ValidateReturnValue(ValidationState_t& _,
214                                  const Instruction* inst) {
215   const auto value_id = inst->GetOperandAs<uint32_t>(0);
216   const auto value = _.FindDef(value_id);
217   if (!value || !value->type_id()) {
218     return _.diag(SPV_ERROR_INVALID_ID, inst)
219            << "OpReturnValue Value <id> '" << _.getIdName(value_id)
220            << "' does not represent a value.";
221   }
222   auto value_type = _.FindDef(value->type_id());
223   if (!value_type || SpvOpTypeVoid == value_type->opcode()) {
224     return _.diag(SPV_ERROR_INVALID_ID, inst)
225            << "OpReturnValue value's type <id> '"
226            << _.getIdName(value->type_id()) << "' is missing or void.";
227   }
228 
229   const bool uses_variable_pointer =
230       _.features().variable_pointers ||
231       _.features().variable_pointers_storage_buffer;
232 
233   if (_.addressing_model() == SpvAddressingModelLogical &&
234       SpvOpTypePointer == value_type->opcode() && !uses_variable_pointer &&
235       !_.options()->relax_logical_pointer) {
236     return _.diag(SPV_ERROR_INVALID_ID, inst)
237            << "OpReturnValue value's type <id> '"
238            << _.getIdName(value->type_id())
239            << "' is a pointer, which is invalid in the Logical addressing "
240               "model.";
241   }
242 
243   const auto function = inst->function();
244   const auto return_type = _.FindDef(function->GetResultTypeId());
245   if (!return_type || return_type->id() != value_type->id()) {
246     return _.diag(SPV_ERROR_INVALID_ID, inst)
247            << "OpReturnValue Value <id> '" << _.getIdName(value_id)
248            << "'s type does not match OpFunction's return type.";
249   }
250 
251   return SPV_SUCCESS;
252 }
253 
ValidateLoopMerge(ValidationState_t & _,const Instruction * inst)254 spv_result_t ValidateLoopMerge(ValidationState_t& _, const Instruction* inst) {
255   const auto merge_id = inst->GetOperandAs<uint32_t>(0);
256   const auto merge = _.FindDef(merge_id);
257   if (!merge || merge->opcode() != SpvOpLabel) {
258     return _.diag(SPV_ERROR_INVALID_ID, inst)
259            << "Merge Block " << _.getIdName(merge_id) << " must be an OpLabel";
260   }
261   if (merge_id == inst->block()->id()) {
262     return _.diag(SPV_ERROR_INVALID_ID, inst)
263            << "Merge Block may not be the block containing the OpLoopMerge\n";
264   }
265 
266   const auto continue_id = inst->GetOperandAs<uint32_t>(1);
267   const auto continue_target = _.FindDef(continue_id);
268   if (!continue_target || continue_target->opcode() != SpvOpLabel) {
269     return _.diag(SPV_ERROR_INVALID_ID, inst)
270            << "Continue Target " << _.getIdName(continue_id)
271            << " must be an OpLabel";
272   }
273 
274   if (merge_id == continue_id) {
275     return _.diag(SPV_ERROR_INVALID_ID, inst)
276            << "Merge Block and Continue Target must be different ids";
277   }
278 
279   const auto loop_control = inst->GetOperandAs<uint32_t>(2);
280   if ((loop_control >> SpvLoopControlUnrollShift) & 0x1 &&
281       (loop_control >> SpvLoopControlDontUnrollShift) & 0x1) {
282     return _.diag(SPV_ERROR_INVALID_DATA, inst)
283            << "Unroll and DontUnroll loop controls must not both be specified";
284   }
285   if ((loop_control >> SpvLoopControlDontUnrollShift) & 0x1 &&
286       (loop_control >> SpvLoopControlPeelCountShift) & 0x1) {
287     return _.diag(SPV_ERROR_INVALID_DATA, inst) << "PeelCount and DontUnroll "
288                                                    "loop controls must not "
289                                                    "both be specified";
290   }
291   if ((loop_control >> SpvLoopControlDontUnrollShift) & 0x1 &&
292       (loop_control >> SpvLoopControlPartialCountShift) & 0x1) {
293     return _.diag(SPV_ERROR_INVALID_DATA, inst) << "PartialCount and "
294                                                    "DontUnroll loop controls "
295                                                    "must not both be specified";
296   }
297 
298   uint32_t operand = 3;
299   if ((loop_control >> SpvLoopControlDependencyLengthShift) & 0x1) {
300     ++operand;
301   }
302   if ((loop_control >> SpvLoopControlMinIterationsShift) & 0x1) {
303     ++operand;
304   }
305   if ((loop_control >> SpvLoopControlMaxIterationsShift) & 0x1) {
306     ++operand;
307   }
308   if ((loop_control >> SpvLoopControlIterationMultipleShift) & 0x1) {
309     if (inst->operands().size() < operand ||
310         inst->GetOperandAs<uint32_t>(operand) == 0) {
311       return _.diag(SPV_ERROR_INVALID_DATA, inst) << "IterationMultiple loop "
312                                                      "control operand must be "
313                                                      "greater than zero";
314     }
315     ++operand;
316   }
317   if ((loop_control >> SpvLoopControlPeelCountShift) & 0x1) {
318     ++operand;
319   }
320   if ((loop_control >> SpvLoopControlPartialCountShift) & 0x1) {
321     ++operand;
322   }
323 
324   // That the right number of operands is present is checked by the parser. The
325   // above code tracks operands for expanded validation checking in the future.
326 
327   return SPV_SUCCESS;
328 }
329 
330 }  // namespace
331 
printDominatorList(const BasicBlock & b)332 void printDominatorList(const BasicBlock& b) {
333   std::cout << b.id() << " is dominated by: ";
334   const BasicBlock* bb = &b;
335   while (bb->immediate_dominator() != bb) {
336     bb = bb->immediate_dominator();
337     std::cout << bb->id() << " ";
338   }
339 }
340 
// Calls ASSERT_FUNC(_, TARGET). If that returns a non-zero spv_result_t, the
// macro returns it from the enclosing function. It expects an object named `_`
// in scope. The do/while(0) wrapper makes the macro a single statement, so it
// behaves correctly inside an unbraced if/else (no dangling else) and still
// requires a trailing semicolon.
#define CFG_ASSERT(ASSERT_FUNC, TARGET)                            \
  do {                                                             \
    if (spv_result_t rcode = ASSERT_FUNC(_, TARGET)) return rcode; \
  } while (0)
343 
FirstBlockAssert(ValidationState_t & _,uint32_t target)344 spv_result_t FirstBlockAssert(ValidationState_t& _, uint32_t target) {
345   if (_.current_function().IsFirstBlock(target)) {
346     return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(_.current_function().id()))
347            << "First block " << _.getIdName(target) << " of function "
348            << _.getIdName(_.current_function().id()) << " is targeted by block "
349            << _.getIdName(_.current_function().current_block()->id());
350   }
351   return SPV_SUCCESS;
352 }
353 
MergeBlockAssert(ValidationState_t & _,uint32_t merge_block)354 spv_result_t MergeBlockAssert(ValidationState_t& _, uint32_t merge_block) {
355   if (_.current_function().IsBlockType(merge_block, kBlockTypeMerge)) {
356     return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(_.current_function().id()))
357            << "Block " << _.getIdName(merge_block)
358            << " is already a merge block for another header";
359   }
360   return SPV_SUCCESS;
361 }
362 
363 /// Update the continue construct's exit blocks once the backedge blocks are
364 /// identified in the CFG.
UpdateContinueConstructExitBlocks(Function & function,const std::vector<std::pair<uint32_t,uint32_t>> & back_edges)365 void UpdateContinueConstructExitBlocks(
366     Function& function,
367     const std::vector<std::pair<uint32_t, uint32_t>>& back_edges) {
368   auto& constructs = function.constructs();
369   // TODO(umar): Think of a faster way to do this
370   for (auto& edge : back_edges) {
371     uint32_t back_edge_block_id;
372     uint32_t loop_header_block_id;
373     std::tie(back_edge_block_id, loop_header_block_id) = edge;
374     auto is_this_header = [=](Construct& c) {
375       return c.type() == ConstructType::kLoop &&
376              c.entry_block()->id() == loop_header_block_id;
377     };
378 
379     for (auto construct : constructs) {
380       if (is_this_header(construct)) {
381         Construct* continue_construct =
382             construct.corresponding_constructs().back();
383         assert(continue_construct->type() == ConstructType::kContinue);
384 
385         BasicBlock* back_edge_block;
386         std::tie(back_edge_block, std::ignore) =
387             function.GetBlock(back_edge_block_id);
388         continue_construct->set_exit(back_edge_block);
389       }
390     }
391   }
392 }
393 
ConstructNames(ConstructType type)394 std::tuple<std::string, std::string, std::string> ConstructNames(
395     ConstructType type) {
396   std::string construct_name, header_name, exit_name;
397 
398   switch (type) {
399     case ConstructType::kSelection:
400       construct_name = "selection";
401       header_name = "selection header";
402       exit_name = "merge block";
403       break;
404     case ConstructType::kLoop:
405       construct_name = "loop";
406       header_name = "loop header";
407       exit_name = "merge block";
408       break;
409     case ConstructType::kContinue:
410       construct_name = "continue";
411       header_name = "continue target";
412       exit_name = "back-edge block";
413       break;
414     case ConstructType::kCase:
415       construct_name = "case";
416       header_name = "case entry block";
417       exit_name = "case exit block";
418       break;
419     default:
420       assert(1 == 0 && "Not defined type");
421   }
422 
423   return std::make_tuple(construct_name, header_name, exit_name);
424 }
425 
426 /// Constructs an error message for construct validation errors
ConstructErrorString(const Construct & construct,const std::string & header_string,const std::string & exit_string,const std::string & dominate_text)427 std::string ConstructErrorString(const Construct& construct,
428                                  const std::string& header_string,
429                                  const std::string& exit_string,
430                                  const std::string& dominate_text) {
431   std::string construct_name, header_name, exit_name;
432   std::tie(construct_name, header_name, exit_name) =
433       ConstructNames(construct.type());
434 
435   // TODO(umar): Add header block for continue constructs to error message
436   return "The " + construct_name + " construct with the " + header_name + " " +
437          header_string + " " + dominate_text + " the " + exit_name + " " +
438          exit_string;
439 }
440 
441 // Finds the fall through case construct of |target_block| and records it in
442 // |case_fall_through|. Returns SPV_ERROR_INVALID_CFG if the case construct
443 // headed by |target_block| branches to multiple case constructs.
FindCaseFallThrough(ValidationState_t & _,BasicBlock * target_block,uint32_t * case_fall_through,const BasicBlock * merge,const std::unordered_set<uint32_t> & case_targets,Function * function)444 spv_result_t FindCaseFallThrough(
445     ValidationState_t& _, BasicBlock* target_block, uint32_t* case_fall_through,
446     const BasicBlock* merge, const std::unordered_set<uint32_t>& case_targets,
447     Function* function) {
448   std::vector<BasicBlock*> stack;
449   stack.push_back(target_block);
450   std::unordered_set<const BasicBlock*> visited;
451   bool target_reachable = target_block->reachable();
452   int target_depth = function->GetBlockDepth(target_block);
453   while (!stack.empty()) {
454     auto block = stack.back();
455     stack.pop_back();
456 
457     if (block == merge) continue;
458 
459     if (!visited.insert(block).second) continue;
460 
461     if (target_reachable && block->reachable() &&
462         target_block->dominates(*block)) {
463       // Still in the case construct.
464       for (auto successor : *block->successors()) {
465         stack.push_back(successor);
466       }
467     } else {
468       // Exiting the case construct to non-merge block.
469       if (!case_targets.count(block->id())) {
470         int depth = function->GetBlockDepth(block);
471         if ((depth < target_depth) ||
472             (depth == target_depth && block->is_type(kBlockTypeContinue))) {
473           continue;
474         }
475 
476         return _.diag(SPV_ERROR_INVALID_CFG, target_block->label())
477                << "Case construct that targets "
478                << _.getIdName(target_block->id())
479                << " has invalid branch to block " << _.getIdName(block->id())
480                << " (not another case construct, corresponding merge, outer "
481                   "loop merge or outer loop continue)";
482       }
483 
484       if (*case_fall_through == 0u) {
485         if (target_block != block) {
486           *case_fall_through = block->id();
487         }
488       } else if (*case_fall_through != block->id()) {
489         // Case construct has at most one branch to another case construct.
490         return _.diag(SPV_ERROR_INVALID_CFG, target_block->label())
491                << "Case construct that targets "
492                << _.getIdName(target_block->id())
493                << " has branches to multiple other case construct targets "
494                << _.getIdName(*case_fall_through) << " and "
495                << _.getIdName(block->id());
496       }
497     }
498   }
499 
500   return SPV_SUCCESS;
501 }
502 
StructuredSwitchChecks(ValidationState_t & _,Function * function,const Instruction * switch_inst,const BasicBlock * header,const BasicBlock * merge)503 spv_result_t StructuredSwitchChecks(ValidationState_t& _, Function* function,
504                                     const Instruction* switch_inst,
505                                     const BasicBlock* header,
506                                     const BasicBlock* merge) {
507   std::unordered_set<uint32_t> case_targets;
508   for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) {
509     uint32_t target = switch_inst->GetOperandAs<uint32_t>(i);
510     if (target != merge->id()) case_targets.insert(target);
511   }
512   // Tracks how many times each case construct is targeted by another case
513   // construct.
514   std::map<uint32_t, uint32_t> num_fall_through_targeted;
515   uint32_t default_case_fall_through = 0u;
516   uint32_t default_target = switch_inst->GetOperandAs<uint32_t>(1u);
517   bool default_appears_multiple_times = false;
518   for (uint32_t i = 3; i < switch_inst->operands().size(); i += 2) {
519     if (default_target == switch_inst->GetOperandAs<uint32_t>(i)) {
520       default_appears_multiple_times = true;
521       break;
522     }
523   }
524   std::unordered_map<uint32_t, uint32_t> seen_to_fall_through;
525   for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) {
526     uint32_t target = switch_inst->GetOperandAs<uint32_t>(i);
527     if (target == merge->id()) continue;
528 
529     uint32_t case_fall_through = 0u;
530     auto seen_iter = seen_to_fall_through.find(target);
531     if (seen_iter == seen_to_fall_through.end()) {
532       const auto target_block = function->GetBlock(target).first;
533       // OpSwitch must dominate all its case constructs.
534       if (header->reachable() && target_block->reachable() &&
535           !header->dominates(*target_block)) {
536         return _.diag(SPV_ERROR_INVALID_CFG, header->label())
537                << "Selection header " << _.getIdName(header->id())
538                << " does not dominate its case construct "
539                << _.getIdName(target);
540       }
541 
542       if (auto error = FindCaseFallThrough(_, target_block, &case_fall_through,
543                                            merge, case_targets, function)) {
544         return error;
545       }
546 
547       // Track how many time the fall through case has been targeted.
548       if (case_fall_through != 0u) {
549         auto where = num_fall_through_targeted.lower_bound(case_fall_through);
550         if (where == num_fall_through_targeted.end() ||
551             where->first != case_fall_through) {
552           num_fall_through_targeted.insert(
553               where, std::make_pair(case_fall_through, 1));
554         } else {
555           where->second++;
556         }
557       }
558       seen_to_fall_through.insert(std::make_pair(target, case_fall_through));
559     } else {
560       case_fall_through = seen_iter->second;
561     }
562 
563     if (case_fall_through == default_target &&
564         !default_appears_multiple_times) {
565       case_fall_through = default_case_fall_through;
566     }
567     if (case_fall_through != 0u) {
568       bool is_default = i == 1;
569       if (is_default) {
570         default_case_fall_through = case_fall_through;
571       } else {
572         // Allow code like:
573         // case x:
574         // case y:
575         //   ...
576         // case z:
577         //
578         // Where x and y target the same block and fall through to z.
579         uint32_t j = i;
580         while ((j + 2 < switch_inst->operands().size()) &&
581                target == switch_inst->GetOperandAs<uint32_t>(j + 2)) {
582           j += 2;
583         }
584         // If Target T1 branches to Target T2, or if Target T1 branches to the
585         // Default target and the Default target branches to Target T2, then T1
586         // must immediately precede T2 in the list of OpSwitch Target operands.
587         if ((switch_inst->operands().size() < j + 2) ||
588             (case_fall_through != switch_inst->GetOperandAs<uint32_t>(j + 2))) {
589           return _.diag(SPV_ERROR_INVALID_CFG, switch_inst)
590                  << "Case construct that targets " << _.getIdName(target)
591                  << " has branches to the case construct that targets "
592                  << _.getIdName(case_fall_through)
593                  << ", but does not immediately precede it in the "
594                     "OpSwitch's target list";
595         }
596       }
597     }
598   }
599 
600   // Each case construct must be branched to by at most one other case
601   // construct.
602   for (const auto& pair : num_fall_through_targeted) {
603     if (pair.second > 1) {
604       return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(pair.first))
605              << "Multiple case constructs have branches to the case construct "
606                 "that targets "
607              << _.getIdName(pair.first);
608     }
609   }
610 
611   return SPV_SUCCESS;
612 }
613 
614 // Validates that all CFG divergences (i.e. conditional branch or switch) are
615 // structured correctly. Either divergence is preceded by a merge instruction
616 // or the divergence introduces at most one unseen label.
ValidateStructuredSelections(ValidationState_t & _,const std::vector<const BasicBlock * > & postorder)617 spv_result_t ValidateStructuredSelections(
618     ValidationState_t& _, const std::vector<const BasicBlock*>& postorder) {
619   std::unordered_set<uint32_t> seen;
620   for (auto iter = postorder.rbegin(); iter != postorder.rend(); ++iter) {
621     const auto* block = *iter;
622     const auto* terminator = block->terminator();
623     if (!terminator) continue;
624     const auto index = terminator - &_.ordered_instructions()[0];
625     auto* merge = &_.ordered_instructions()[index - 1];
626     // Marks merges and continues as seen.
627     if (merge->opcode() == SpvOpSelectionMerge) {
628       seen.insert(merge->GetOperandAs<uint32_t>(0));
629     } else if (merge->opcode() == SpvOpLoopMerge) {
630       seen.insert(merge->GetOperandAs<uint32_t>(0));
631       seen.insert(merge->GetOperandAs<uint32_t>(1));
632     } else {
633       // Only track the pointer if it is a merge instruction.
634       merge = nullptr;
635     }
636 
637     // Skip unreachable blocks.
638     if (!block->reachable()) continue;
639 
640     if (terminator->opcode() == SpvOpBranchConditional) {
641       const auto true_label = terminator->GetOperandAs<uint32_t>(1);
642       const auto false_label = terminator->GetOperandAs<uint32_t>(2);
643       // Mark the upcoming blocks as seen now, but only error out if this block
644       // was missing a merge instruction and both labels hadn't been seen
645       // previously.
646       const bool both_unseen =
647           seen.insert(true_label).second && seen.insert(false_label).second;
648       if (!merge && both_unseen) {
649         return _.diag(SPV_ERROR_INVALID_CFG, terminator)
650                << "Selection must be structured";
651       }
652     } else if (terminator->opcode() == SpvOpSwitch) {
653       uint32_t count = 0;
654       // Mark the targets as seen now, but only error out if this block was
655       // missing a merge instruction and there were multiple unseen labels.
656       for (uint32_t i = 1; i < terminator->operands().size(); i += 2) {
657         const auto target = terminator->GetOperandAs<uint32_t>(i);
658         if (seen.insert(target).second) {
659           count++;
660         }
661       }
662       if (!merge && count > 1) {
663         return _.diag(SPV_ERROR_INVALID_CFG, terminator)
664                << "Selection must be structured";
665       }
666     }
667   }
668 
669   return SPV_SUCCESS;
670 }
671 
StructuredControlFlowChecks(ValidationState_t & _,Function * function,const std::vector<std::pair<uint32_t,uint32_t>> & back_edges,const std::vector<const BasicBlock * > & postorder)672 spv_result_t StructuredControlFlowChecks(
673     ValidationState_t& _, Function* function,
674     const std::vector<std::pair<uint32_t, uint32_t>>& back_edges,
675     const std::vector<const BasicBlock*>& postorder) {
676   /// Check all backedges target only loop headers and have exactly one
677   /// back-edge branching to it
678 
679   // Map a loop header to blocks with back-edges to the loop header.
680   std::map<uint32_t, std::unordered_set<uint32_t>> loop_latch_blocks;
681   for (auto back_edge : back_edges) {
682     uint32_t back_edge_block;
683     uint32_t header_block;
684     std::tie(back_edge_block, header_block) = back_edge;
685     if (!function->IsBlockType(header_block, kBlockTypeLoop)) {
686       return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(back_edge_block))
687              << "Back-edges (" << _.getIdName(back_edge_block) << " -> "
688              << _.getIdName(header_block)
689              << ") can only be formed between a block and a loop header.";
690     }
691     loop_latch_blocks[header_block].insert(back_edge_block);
692   }
693 
694   // Check the loop headers have exactly one back-edge branching to it
695   for (BasicBlock* loop_header : function->ordered_blocks()) {
696     if (!loop_header->reachable()) continue;
697     if (!loop_header->is_type(kBlockTypeLoop)) continue;
698     auto loop_header_id = loop_header->id();
699     auto num_latch_blocks = loop_latch_blocks[loop_header_id].size();
700     if (num_latch_blocks != 1) {
701       return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(loop_header_id))
702              << "Loop header " << _.getIdName(loop_header_id)
703              << " is targeted by " << num_latch_blocks
704              << " back-edge blocks but the standard requires exactly one";
705     }
706   }
707 
708   // Check construct rules
709   for (const Construct& construct : function->constructs()) {
710     auto header = construct.entry_block();
711     auto merge = construct.exit_block();
712 
713     if (header->reachable() && !merge) {
714       std::string construct_name, header_name, exit_name;
715       std::tie(construct_name, header_name, exit_name) =
716           ConstructNames(construct.type());
717       return _.diag(SPV_ERROR_INTERNAL, _.FindDef(header->id()))
718              << "Construct " + construct_name + " with " + header_name + " " +
719                     _.getIdName(header->id()) + " does not have a " +
720                     exit_name + ". This may be a bug in the validator.";
721     }
722 
723     // If the exit block is reachable then it's dominated by the
724     // header.
725     if (merge && merge->reachable()) {
726       if (!header->dominates(*merge)) {
727         return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(merge->id()))
728                << ConstructErrorString(construct, _.getIdName(header->id()),
729                                        _.getIdName(merge->id()),
730                                        "does not dominate");
731       }
732       // If it's really a merge block for a selection or loop, then it must be
733       // *strictly* dominated by the header.
734       if (construct.ExitBlockIsMergeBlock() && (header == merge)) {
735         return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(merge->id()))
736                << ConstructErrorString(construct, _.getIdName(header->id()),
737                                        _.getIdName(merge->id()),
738                                        "does not strictly dominate");
739       }
740     }
741     // Check post-dominance for continue constructs.  But dominance and
742     // post-dominance only make sense when the construct is reachable.
743     if (header->reachable() && construct.type() == ConstructType::kContinue) {
744       if (!merge->postdominates(*header)) {
745         return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(merge->id()))
746                << ConstructErrorString(construct, _.getIdName(header->id()),
747                                        _.getIdName(merge->id()),
748                                        "is not post dominated by");
749       }
750     }
751 
752     Construct::ConstructBlockSet construct_blocks = construct.blocks(function);
753     std::string construct_name, header_name, exit_name;
754     std::tie(construct_name, header_name, exit_name) =
755         ConstructNames(construct.type());
756     for (auto block : construct_blocks) {
757       // Check that all exits from the construct are via structured exits.
758       for (auto succ : *block->successors()) {
759         if (block->reachable() && !construct_blocks.count(succ) &&
760             !construct.IsStructuredExit(_, succ)) {
761           return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
762                  << "block <ID> " << _.getIdName(block->id()) << " exits the "
763                  << construct_name << " headed by <ID> "
764                  << _.getIdName(header->id())
765                  << ", but not via a structured exit";
766         }
767       }
768       if (block == header) continue;
769       // Check that for all non-header blocks, all predecessors are within this
770       // construct.
771       for (auto pred : *block->predecessors()) {
772         if (pred->reachable() && !construct_blocks.count(pred)) {
773           return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(pred->id()))
774                  << "block <ID> " << pred->id() << " branches to the "
775                  << construct_name << " construct, but not to the "
776                  << header_name << " <ID> " << header->id();
777         }
778       }
779 
780       if (block->is_type(BlockType::kBlockTypeSelection) ||
781           block->is_type(BlockType::kBlockTypeLoop)) {
782         size_t index = (block->terminator() - &_.ordered_instructions()[0]) - 1;
783         const auto& merge_inst = _.ordered_instructions()[index];
784         if (merge_inst.opcode() == SpvOpSelectionMerge ||
785             merge_inst.opcode() == SpvOpLoopMerge) {
786           uint32_t merge_id = merge_inst.GetOperandAs<uint32_t>(0);
787           auto merge_block = function->GetBlock(merge_id).first;
788           if (merge_block->reachable() &&
789               !construct_blocks.count(merge_block)) {
790             return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
791                    << "Header block " << _.getIdName(block->id())
792                    << " is contained in the " << construct_name
793                    << " construct headed by " << _.getIdName(header->id())
794                    << ", but its merge block " << _.getIdName(merge_id)
795                    << " is not";
796           }
797         }
798       }
799     }
800 
801     // Checks rules for case constructs.
802     if (construct.type() == ConstructType::kSelection &&
803         header->terminator()->opcode() == SpvOpSwitch) {
804       const auto terminator = header->terminator();
805       if (auto error =
806               StructuredSwitchChecks(_, function, terminator, header, merge)) {
807         return error;
808       }
809     }
810   }
811 
812   if (auto error = ValidateStructuredSelections(_, postorder)) {
813     return error;
814   }
815 
816   return SPV_SUCCESS;
817 }
818 
PerformWebGPUCfgChecks(ValidationState_t & _,Function * function)819 spv_result_t PerformWebGPUCfgChecks(ValidationState_t& _, Function* function) {
820   for (auto& block : function->ordered_blocks()) {
821     if (block->reachable()) continue;
822     if (block->is_type(kBlockTypeMerge)) {
823       // 1. Find the referencing merge and confirm that it is reachable.
824       BasicBlock* merge_header = function->GetMergeHeader(block);
825       assert(merge_header != nullptr);
826       if (!merge_header->reachable()) {
827         return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
828                << "For WebGPU, unreachable merge-blocks must be referenced by "
829                   "a reachable merge instruction.";
830       }
831 
832       // 2. Check that the only instructions are OpLabel and OpUnreachable.
833       auto* label_inst = block->label();
834       auto* terminator_inst = block->terminator();
835       assert(label_inst != nullptr);
836       assert(terminator_inst != nullptr);
837 
838       if (terminator_inst->opcode() != SpvOpUnreachable) {
839         return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
840                << "For WebGPU, unreachable merge-blocks must terminate with "
841                   "OpUnreachable.";
842       }
843 
844       auto label_idx = label_inst - &_.ordered_instructions()[0];
845       auto terminator_idx = terminator_inst - &_.ordered_instructions()[0];
846       if (label_idx + 1 != terminator_idx) {
847         return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
848                << "For WebGPU, unreachable merge-blocks must only contain an "
849                   "OpLabel and OpUnreachable instruction.";
850       }
851 
852       // 3. Use label instruction to confirm there is no uses by branches.
853       for (auto use : label_inst->uses()) {
854         const auto* use_inst = use.first;
855         if (spvOpcodeIsBranch(use_inst->opcode())) {
856           return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
857                  << "For WebGPU, unreachable merge-blocks cannot be the target "
858                     "of a branch.";
859         }
860       }
861     } else if (block->is_type(kBlockTypeContinue)) {
862       // 1. Find referencing loop and confirm that it is reachable.
863       std::vector<BasicBlock*> continue_headers =
864           function->GetContinueHeaders(block);
865       if (continue_headers.empty()) {
866         return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
867                << "For WebGPU, unreachable continue-target must be referenced "
868                   "by a loop instruction.";
869       }
870 
871       std::vector<BasicBlock*> reachable_headers(continue_headers.size());
872       auto iter =
873           std::copy_if(continue_headers.begin(), continue_headers.end(),
874                        reachable_headers.begin(),
875                        [](BasicBlock* header) { return header->reachable(); });
876       reachable_headers.resize(std::distance(reachable_headers.begin(), iter));
877 
878       if (reachable_headers.empty()) {
879         return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
880                << "For WebGPU, unreachable continue-target must be referenced "
881                   "by a reachable loop instruction.";
882       }
883 
884       // 2. Check that the only instructions are OpLabel and OpBranch.
885       auto* label_inst = block->label();
886       auto* terminator_inst = block->terminator();
887       assert(label_inst != nullptr);
888       assert(terminator_inst != nullptr);
889 
890       if (terminator_inst->opcode() != SpvOpBranch) {
891         return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
892                << "For WebGPU, unreachable continue-target must terminate with "
893                   "OpBranch.";
894       }
895 
896       auto label_idx = label_inst - &_.ordered_instructions()[0];
897       auto terminator_idx = terminator_inst - &_.ordered_instructions()[0];
898       if (label_idx + 1 != terminator_idx) {
899         return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
900                << "For WebGPU, unreachable continue-target must only contain "
901                   "an OpLabel and an OpBranch instruction.";
902       }
903 
904       // 3. Use label instruction to confirm there is no uses by branches.
905       for (auto use : label_inst->uses()) {
906         const auto* use_inst = use.first;
907         if (spvOpcodeIsBranch(use_inst->opcode())) {
908           return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
909                  << "For WebGPU, unreachable continue-target cannot be the "
910                     "target of a branch.";
911         }
912       }
913 
914       // 4. Confirm that continue-target has a back edge to a reachable loop
915       //    header block.
916       auto branch_target = terminator_inst->GetOperandAs<uint32_t>(0);
917       for (auto* continue_header : reachable_headers) {
918         if (branch_target != continue_header->id()) {
919           return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
920                  << "For WebGPU, unreachable continue-target must only have a "
921                     "back edge to a single reachable loop instruction.";
922         }
923       }
924     } else {
925       return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
926              << "For WebGPU, all blocks must be reachable, unless they are "
927              << "degenerate cases of merge-block or continue-target.";
928     }
929   }
930   return SPV_SUCCESS;
931 }
932 
PerformCfgChecks(ValidationState_t & _)933 spv_result_t PerformCfgChecks(ValidationState_t& _) {
934   for (auto& function : _.functions()) {
935     // Check all referenced blocks are defined within a function
936     if (function.undefined_block_count() != 0) {
937       std::string undef_blocks("{");
938       bool first = true;
939       for (auto undefined_block : function.undefined_blocks()) {
940         undef_blocks += _.getIdName(undefined_block);
941         if (!first) {
942           undef_blocks += " ";
943         }
944         first = false;
945       }
946       return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(function.id()))
947              << "Block(s) " << undef_blocks << "}"
948              << " are referenced but not defined in function "
949              << _.getIdName(function.id());
950     }
951 
952     // Set each block's immediate dominator and immediate postdominator,
953     // and find all back-edges.
954     //
955     // We want to analyze all the blocks in the function, even in degenerate
956     // control flow cases including unreachable blocks.  So use the augmented
957     // CFG to ensure we cover all the blocks.
958     std::vector<const BasicBlock*> postorder;
959     std::vector<const BasicBlock*> postdom_postorder;
960     std::vector<std::pair<uint32_t, uint32_t>> back_edges;
961     auto ignore_block = [](const BasicBlock*) {};
962     auto ignore_edge = [](const BasicBlock*, const BasicBlock*) {};
963     if (!function.ordered_blocks().empty()) {
964       /// calculate dominators
965       CFA<BasicBlock>::DepthFirstTraversal(
966           function.first_block(), function.AugmentedCFGSuccessorsFunction(),
967           ignore_block, [&](const BasicBlock* b) { postorder.push_back(b); },
968           ignore_edge);
969       auto edges = CFA<BasicBlock>::CalculateDominators(
970           postorder, function.AugmentedCFGPredecessorsFunction());
971       for (auto edge : edges) {
972         if (edge.first != edge.second)
973           edge.first->SetImmediateDominator(edge.second);
974       }
975 
976       /// calculate post dominators
977       CFA<BasicBlock>::DepthFirstTraversal(
978           function.pseudo_exit_block(),
979           function.AugmentedCFGPredecessorsFunction(), ignore_block,
980           [&](const BasicBlock* b) { postdom_postorder.push_back(b); },
981           ignore_edge);
982       auto postdom_edges = CFA<BasicBlock>::CalculateDominators(
983           postdom_postorder, function.AugmentedCFGSuccessorsFunction());
984       for (auto edge : postdom_edges) {
985         edge.first->SetImmediatePostDominator(edge.second);
986       }
987       /// calculate back edges.
988       CFA<BasicBlock>::DepthFirstTraversal(
989           function.pseudo_entry_block(),
990           function
991               .AugmentedCFGSuccessorsFunctionIncludingHeaderToContinueEdge(),
992           ignore_block, ignore_block,
993           [&](const BasicBlock* from, const BasicBlock* to) {
994             back_edges.emplace_back(from->id(), to->id());
995           });
996     }
997     UpdateContinueConstructExitBlocks(function, back_edges);
998 
999     auto& blocks = function.ordered_blocks();
1000     if (!blocks.empty()) {
1001       // Check if the order of blocks in the binary appear before the blocks
1002       // they dominate
1003       for (auto block = begin(blocks) + 1; block != end(blocks); ++block) {
1004         if (auto idom = (*block)->immediate_dominator()) {
1005           if (idom != function.pseudo_entry_block() &&
1006               block == std::find(begin(blocks), block, idom)) {
1007             return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(idom->id()))
1008                    << "Block " << _.getIdName((*block)->id())
1009                    << " appears in the binary before its dominator "
1010                    << _.getIdName(idom->id());
1011           }
1012         }
1013 
1014         // For WebGPU check that all unreachable blocks are degenerate cases for
1015         // merge-block or continue-target.
1016         if (spvIsWebGPUEnv(_.context()->target_env)) {
1017           spv_result_t result = PerformWebGPUCfgChecks(_, &function);
1018           if (result != SPV_SUCCESS) return result;
1019         }
1020       }
1021       // If we have structed control flow, check that no block has a control
1022       // flow nesting depth larger than the limit.
1023       if (_.HasCapability(SpvCapabilityShader)) {
1024         const int control_flow_nesting_depth_limit =
1025             _.options()->universal_limits_.max_control_flow_nesting_depth;
1026         for (auto block = begin(blocks); block != end(blocks); ++block) {
1027           if (function.GetBlockDepth(*block) >
1028               control_flow_nesting_depth_limit) {
1029             return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef((*block)->id()))
1030                    << "Maximum Control Flow nesting depth exceeded.";
1031           }
1032         }
1033       }
1034     }
1035 
1036     /// Structured control flow checks are only required for shader capabilities
1037     if (_.HasCapability(SpvCapabilityShader)) {
1038       if (auto error =
1039               StructuredControlFlowChecks(_, &function, back_edges, postorder))
1040         return error;
1041     }
1042   }
1043   return SPV_SUCCESS;
1044 }
1045 
CfgPass(ValidationState_t & _,const Instruction * inst)1046 spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst) {
1047   SpvOp opcode = inst->opcode();
1048   switch (opcode) {
1049     case SpvOpLabel:
1050       if (auto error = _.current_function().RegisterBlock(inst->id()))
1051         return error;
1052 
1053       // TODO(github:1661) This should be done in the
1054       // ValidationState::RegisterInstruction method but because of the order of
1055       // passes the OpLabel ends up not being part of the basic block it starts.
1056       _.current_function().current_block()->set_label(inst);
1057       break;
1058     case SpvOpLoopMerge: {
1059       uint32_t merge_block = inst->GetOperandAs<uint32_t>(0);
1060       uint32_t continue_block = inst->GetOperandAs<uint32_t>(1);
1061       CFG_ASSERT(MergeBlockAssert, merge_block);
1062 
1063       if (auto error = _.current_function().RegisterLoopMerge(merge_block,
1064                                                               continue_block))
1065         return error;
1066     } break;
1067     case SpvOpSelectionMerge: {
1068       uint32_t merge_block = inst->GetOperandAs<uint32_t>(0);
1069       CFG_ASSERT(MergeBlockAssert, merge_block);
1070 
1071       if (auto error = _.current_function().RegisterSelectionMerge(merge_block))
1072         return error;
1073     } break;
1074     case SpvOpBranch: {
1075       uint32_t target = inst->GetOperandAs<uint32_t>(0);
1076       CFG_ASSERT(FirstBlockAssert, target);
1077 
1078       _.current_function().RegisterBlockEnd({target});
1079     } break;
1080     case SpvOpBranchConditional: {
1081       uint32_t tlabel = inst->GetOperandAs<uint32_t>(1);
1082       uint32_t flabel = inst->GetOperandAs<uint32_t>(2);
1083       CFG_ASSERT(FirstBlockAssert, tlabel);
1084       CFG_ASSERT(FirstBlockAssert, flabel);
1085 
1086       _.current_function().RegisterBlockEnd({tlabel, flabel});
1087     } break;
1088 
1089     case SpvOpSwitch: {
1090       std::vector<uint32_t> cases;
1091       for (size_t i = 1; i < inst->operands().size(); i += 2) {
1092         uint32_t target = inst->GetOperandAs<uint32_t>(i);
1093         CFG_ASSERT(FirstBlockAssert, target);
1094         cases.push_back(target);
1095       }
1096       _.current_function().RegisterBlockEnd({cases});
1097     } break;
1098     case SpvOpReturn: {
1099       const uint32_t return_type = _.current_function().GetResultTypeId();
1100       const Instruction* return_type_inst = _.FindDef(return_type);
1101       assert(return_type_inst);
1102       if (return_type_inst->opcode() != SpvOpTypeVoid)
1103         return _.diag(SPV_ERROR_INVALID_CFG, inst)
1104                << "OpReturn can only be called from a function with void "
1105                << "return type.";
1106       _.current_function().RegisterBlockEnd(std::vector<uint32_t>());
1107       break;
1108     }
1109     case SpvOpKill:
1110     case SpvOpReturnValue:
1111     case SpvOpUnreachable:
1112     case SpvOpTerminateInvocation:
1113       _.current_function().RegisterBlockEnd(std::vector<uint32_t>());
1114       if (opcode == SpvOpKill) {
1115         _.current_function().RegisterExecutionModelLimitation(
1116             SpvExecutionModelFragment,
1117             "OpKill requires Fragment execution model");
1118       }
1119       if (opcode == SpvOpTerminateInvocation) {
1120         _.current_function().RegisterExecutionModelLimitation(
1121             SpvExecutionModelFragment,
1122             "OpTerminateInvocation requires Fragment execution model");
1123       }
1124       break;
1125     default:
1126       break;
1127   }
1128   return SPV_SUCCESS;
1129 }
1130 
ReachabilityPass(ValidationState_t & _)1131 void ReachabilityPass(ValidationState_t& _) {
1132   for (auto& f : _.functions()) {
1133     std::vector<BasicBlock*> stack;
1134     auto entry = f.first_block();
1135     // Skip function declarations.
1136     if (entry) stack.push_back(entry);
1137 
1138     while (!stack.empty()) {
1139       auto block = stack.back();
1140       stack.pop_back();
1141 
1142       if (block->reachable()) continue;
1143 
1144       block->set_reachable(true);
1145       for (auto succ : *block->successors()) {
1146         stack.push_back(succ);
1147       }
1148     }
1149   }
1150 }
1151 
ControlFlowPass(ValidationState_t & _,const Instruction * inst)1152 spv_result_t ControlFlowPass(ValidationState_t& _, const Instruction* inst) {
1153   switch (inst->opcode()) {
1154     case SpvOpPhi:
1155       if (auto error = ValidatePhi(_, inst)) return error;
1156       break;
1157     case SpvOpBranch:
1158       if (auto error = ValidateBranch(_, inst)) return error;
1159       break;
1160     case SpvOpBranchConditional:
1161       if (auto error = ValidateBranchConditional(_, inst)) return error;
1162       break;
1163     case SpvOpReturnValue:
1164       if (auto error = ValidateReturnValue(_, inst)) return error;
1165       break;
1166     case SpvOpSwitch:
1167       if (auto error = ValidateSwitch(_, inst)) return error;
1168       break;
1169     case SpvOpLoopMerge:
1170       if (auto error = ValidateLoopMerge(_, inst)) return error;
1171       break;
1172     default:
1173       break;
1174   }
1175 
1176   return SPV_SUCCESS;
1177 }
1178 
1179 }  // namespace val
1180 }  // namespace spvtools
1181