1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/cfg.h"
16 
17 #include <memory>
18 #include <utility>
19 
20 #include "source/cfa.h"
21 #include "source/opt/ir_builder.h"
22 #include "source/opt/ir_context.h"
23 #include "source/opt/module.h"
24 
25 namespace spvtools {
26 namespace opt {
27 namespace {
28 
29 using cbb_ptr = const opt::BasicBlock*;
30 
31 // Universal Limit of ResultID + 1
32 const int kMaxResultId = 0x400000;
33 
34 }  // namespace
35 
CFG(Module * module)36 CFG::CFG(Module* module)
37     : module_(module),
38       pseudo_entry_block_(std::unique_ptr<Instruction>(
39           new Instruction(module->context(), SpvOpLabel, 0, 0, {}))),
40       pseudo_exit_block_(std::unique_ptr<Instruction>(new Instruction(
41           module->context(), SpvOpLabel, 0, kMaxResultId, {}))) {
42   for (auto& fn : *module) {
43     for (auto& blk : fn) {
44       RegisterBlock(&blk);
45     }
46   }
47 }
48 
AddEdges(BasicBlock * blk)49 void CFG::AddEdges(BasicBlock* blk) {
50   uint32_t blk_id = blk->id();
51   // Force the creation of an entry, not all basic block have predecessors
52   // (such as the entry blocks and some unreachables).
53   label2preds_[blk_id];
54   const auto* const_blk = blk;
55   const_blk->ForEachSuccessorLabel(
56       [blk_id, this](const uint32_t succ_id) { AddEdge(blk_id, succ_id); });
57 }
58 
RemoveNonExistingEdges(uint32_t blk_id)59 void CFG::RemoveNonExistingEdges(uint32_t blk_id) {
60   std::vector<uint32_t> updated_pred_list;
61   for (uint32_t id : preds(blk_id)) {
62     const BasicBlock* pred_blk = block(id);
63     bool has_branch = false;
64     pred_blk->ForEachSuccessorLabel([&has_branch, blk_id](uint32_t succ) {
65       if (succ == blk_id) {
66         has_branch = true;
67       }
68     });
69     if (has_branch) updated_pred_list.push_back(id);
70   }
71 
72   label2preds_.at(blk_id) = std::move(updated_pred_list);
73 }
74 
ComputeStructuredOrder(Function * func,BasicBlock * root,std::list<BasicBlock * > * order)75 void CFG::ComputeStructuredOrder(Function* func, BasicBlock* root,
76                                  std::list<BasicBlock*>* order) {
77   assert(module_->context()->get_feature_mgr()->HasCapability(
78              SpvCapabilityShader) &&
79          "This only works on structured control flow");
80 
81   // Compute structured successors and do DFS.
82   ComputeStructuredSuccessors(func);
83   auto ignore_block = [](cbb_ptr) {};
84   auto ignore_edge = [](cbb_ptr, cbb_ptr) {};
85   auto get_structured_successors = [this](const BasicBlock* b) {
86     return &(block2structured_succs_[b]);
87   };
88 
89   // TODO(greg-lunarg): Get rid of const_cast by making moving const
90   // out of the cfa.h prototypes and into the invoking code.
91   auto post_order = [&](cbb_ptr b) {
92     order->push_front(const_cast<BasicBlock*>(b));
93   };
94   CFA<BasicBlock>::DepthFirstTraversal(root, get_structured_successors,
95                                        ignore_block, post_order, ignore_edge);
96 }
97 
ForEachBlockInPostOrder(BasicBlock * bb,const std::function<void (BasicBlock *)> & f)98 void CFG::ForEachBlockInPostOrder(BasicBlock* bb,
99                                   const std::function<void(BasicBlock*)>& f) {
100   std::vector<BasicBlock*> po;
101   std::unordered_set<BasicBlock*> seen;
102   ComputePostOrderTraversal(bb, &po, &seen);
103 
104   for (BasicBlock* current_bb : po) {
105     if (!IsPseudoExitBlock(current_bb) && !IsPseudoEntryBlock(current_bb)) {
106       f(current_bb);
107     }
108   }
109 }
110 
ForEachBlockInReversePostOrder(BasicBlock * bb,const std::function<void (BasicBlock *)> & f)111 void CFG::ForEachBlockInReversePostOrder(
112     BasicBlock* bb, const std::function<void(BasicBlock*)>& f) {
113   WhileEachBlockInReversePostOrder(bb, [f](BasicBlock* b) {
114     f(b);
115     return true;
116   });
117 }
118 
WhileEachBlockInReversePostOrder(BasicBlock * bb,const std::function<bool (BasicBlock *)> & f)119 bool CFG::WhileEachBlockInReversePostOrder(
120     BasicBlock* bb, const std::function<bool(BasicBlock*)>& f) {
121   std::vector<BasicBlock*> po;
122   std::unordered_set<BasicBlock*> seen;
123   ComputePostOrderTraversal(bb, &po, &seen);
124 
125   for (auto current_bb = po.rbegin(); current_bb != po.rend(); ++current_bb) {
126     if (!IsPseudoExitBlock(*current_bb) && !IsPseudoEntryBlock(*current_bb)) {
127       if (!f(*current_bb)) {
128         return false;
129       }
130     }
131   }
132   return true;
133 }
134 
ComputeStructuredSuccessors(Function * func)135 void CFG::ComputeStructuredSuccessors(Function* func) {
136   block2structured_succs_.clear();
137   for (auto& blk : *func) {
138     // If no predecessors in function, make successor to pseudo entry.
139     if (label2preds_[blk.id()].size() == 0)
140       block2structured_succs_[&pseudo_entry_block_].push_back(&blk);
141 
142     // If header, make merge block first successor and continue block second
143     // successor if there is one.
144     uint32_t mbid = blk.MergeBlockIdIfAny();
145     if (mbid != 0) {
146       block2structured_succs_[&blk].push_back(block(mbid));
147       uint32_t cbid = blk.ContinueBlockIdIfAny();
148       if (cbid != 0) {
149         block2structured_succs_[&blk].push_back(block(cbid));
150       }
151     }
152 
153     // Add true successors.
154     const auto& const_blk = blk;
155     const_blk.ForEachSuccessorLabel([&blk, this](const uint32_t sbid) {
156       block2structured_succs_[&blk].push_back(block(sbid));
157     });
158   }
159 }
160 
ComputePostOrderTraversal(BasicBlock * bb,std::vector<BasicBlock * > * order,std::unordered_set<BasicBlock * > * seen)161 void CFG::ComputePostOrderTraversal(BasicBlock* bb,
162                                     std::vector<BasicBlock*>* order,
163                                     std::unordered_set<BasicBlock*>* seen) {
164   std::vector<BasicBlock*> stack;
165   stack.push_back(bb);
166   while (!stack.empty()) {
167     bb = stack.back();
168     seen->insert(bb);
169     static_cast<const BasicBlock*>(bb)->WhileEachSuccessorLabel(
170         [&seen, &stack, this](const uint32_t sbid) {
171           BasicBlock* succ_bb = id2block_[sbid];
172           if (!seen->count(succ_bb)) {
173             stack.push_back(succ_bb);
174             return false;
175           }
176           return true;
177         });
178     if (stack.back() == bb) {
179       order->push_back(bb);
180       stack.pop_back();
181     }
182   }
183 }
184 
SplitLoopHeader(BasicBlock * bb)185 BasicBlock* CFG::SplitLoopHeader(BasicBlock* bb) {
186   assert(bb->GetLoopMergeInst() && "Expecting bb to be the header of a loop.");
187 
188   Function* fn = bb->GetParent();
189   IRContext* context = module_->context();
190 
191   // Get the new header id up front.  If we are out of ids, then we cannot split
192   // the loop.
193   uint32_t new_header_id = context->TakeNextId();
194   if (new_header_id == 0) {
195     return nullptr;
196   }
197 
198   // Find the insertion point for the new bb.
199   Function::iterator header_it = std::find_if(
200       fn->begin(), fn->end(),
201       [bb](BasicBlock& block_in_func) { return &block_in_func == bb; });
202   assert(header_it != fn->end());
203 
204   const std::vector<uint32_t>& pred = preds(bb->id());
205   // Find the back edge
206   BasicBlock* latch_block = nullptr;
207   Function::iterator latch_block_iter = header_it;
208   while (++latch_block_iter != fn->end()) {
209     // If blocks are in the proper order, then the only branch that appears
210     // after the header is the latch.
211     if (std::find(pred.begin(), pred.end(), latch_block_iter->id()) !=
212         pred.end()) {
213       break;
214     }
215   }
216   assert(latch_block_iter != fn->end() && "Could not find the latch.");
217   latch_block = &*latch_block_iter;
218 
219   RemoveSuccessorEdges(bb);
220 
221   // Create the new header bb basic bb.
222   // Leave the phi instructions behind.
223   auto iter = bb->begin();
224   while (iter->opcode() == SpvOpPhi) {
225     ++iter;
226   }
227 
228   BasicBlock* new_header = bb->SplitBasicBlock(context, new_header_id, iter);
229   context->AnalyzeDefUse(new_header->GetLabelInst());
230 
231   // Update cfg
232   RegisterBlock(new_header);
233 
234   // Update bb mappings.
235   context->set_instr_block(new_header->GetLabelInst(), new_header);
236   new_header->ForEachInst([new_header, context](Instruction* inst) {
237     context->set_instr_block(inst, new_header);
238   });
239 
240   // Adjust the OpPhi instructions as needed.
241   bb->ForEachPhiInst([latch_block, bb, new_header, context](Instruction* phi) {
242     std::vector<uint32_t> preheader_phi_ops;
243     std::vector<Operand> header_phi_ops;
244 
245     // Identify where the original inputs to original OpPhi belong: header or
246     // preheader.
247     for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
248       uint32_t def_id = phi->GetSingleWordInOperand(i);
249       uint32_t branch_id = phi->GetSingleWordInOperand(i + 1);
250       if (branch_id == latch_block->id()) {
251         header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {def_id}});
252         header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {branch_id}});
253       } else {
254         preheader_phi_ops.push_back(def_id);
255         preheader_phi_ops.push_back(branch_id);
256       }
257     }
258 
259     // Create a phi instruction if and only if the preheader_phi_ops has more
260     // than one pair.
261     if (preheader_phi_ops.size() > 2) {
262       InstructionBuilder builder(
263           context, &*bb->begin(),
264           IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
265 
266       Instruction* new_phi = builder.AddPhi(phi->type_id(), preheader_phi_ops);
267 
268       // Add the OpPhi to the header bb.
269       header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {new_phi->result_id()}});
270       header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
271     } else {
272       // An OpPhi with a single entry is just a copy.  In this case use the same
273       // instruction in the new header.
274       header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {preheader_phi_ops[0]}});
275       header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
276     }
277 
278     phi->RemoveFromList();
279     std::unique_ptr<Instruction> phi_owner(phi);
280     phi->SetInOperands(std::move(header_phi_ops));
281     new_header->begin()->InsertBefore(std::move(phi_owner));
282     context->set_instr_block(phi, new_header);
283     context->AnalyzeUses(phi);
284   });
285 
286   // Add a branch to the new header.
287   InstructionBuilder branch_builder(
288       context, bb,
289       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
290   bb->AddInstruction(
291       MakeUnique<Instruction>(context, SpvOpBranch, 0, 0,
292                               std::initializer_list<Operand>{
293                                   {SPV_OPERAND_TYPE_ID, {new_header->id()}}}));
294   context->AnalyzeUses(bb->terminator());
295   context->set_instr_block(bb->terminator(), bb);
296   label2preds_[new_header->id()].push_back(bb->id());
297 
298   // Update the latch to branch to the new header.
299   latch_block->ForEachSuccessorLabel([bb, new_header_id](uint32_t* id) {
300     if (*id == bb->id()) {
301       *id = new_header_id;
302     }
303   });
304   Instruction* latch_branch = latch_block->terminator();
305   context->AnalyzeUses(latch_branch);
306   label2preds_[new_header->id()].push_back(latch_block->id());
307 
308   auto& block_preds = label2preds_[bb->id()];
309   auto latch_pos =
310       std::find(block_preds.begin(), block_preds.end(), latch_block->id());
311   assert(latch_pos != block_preds.end() && "The cfg was invalid.");
312   block_preds.erase(latch_pos);
313 
314   // Update the loop descriptors
315   if (context->AreAnalysesValid(IRContext::kAnalysisLoopAnalysis)) {
316     LoopDescriptor* loop_desc = context->GetLoopDescriptor(bb->GetParent());
317     Loop* loop = (*loop_desc)[bb->id()];
318 
319     loop->AddBasicBlock(new_header_id);
320     loop->SetHeaderBlock(new_header);
321     loop_desc->SetBasicBlockToLoop(new_header_id, loop);
322 
323     loop->RemoveBasicBlock(bb->id());
324     loop->SetPreHeaderBlock(bb);
325 
326     Loop* parent_loop = loop->GetParent();
327     if (parent_loop != nullptr) {
328       parent_loop->AddBasicBlock(bb->id());
329       loop_desc->SetBasicBlockToLoop(bb->id(), parent_loop);
330     } else {
331       loop_desc->SetBasicBlockToLoop(bb->id(), nullptr);
332     }
333   }
334   return new_header;
335 }
336 
337 }  // namespace opt
338 }  // namespace spvtools
339