1 // Copyright (c) 2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "transformation_add_loop_preheader.h"
16 
17 #include "source/fuzz/fuzzer_util.h"
18 #include "source/opt/instruction.h"
19 
20 namespace spvtools {
21 namespace fuzz {
TransformationAddLoopPreheader(protobufs::TransformationAddLoopPreheader message)22 TransformationAddLoopPreheader::TransformationAddLoopPreheader(
23     protobufs::TransformationAddLoopPreheader message)
24     : message_(std::move(message)) {}
25 
TransformationAddLoopPreheader(uint32_t loop_header_block,uint32_t fresh_id,std::vector<uint32_t> phi_id)26 TransformationAddLoopPreheader::TransformationAddLoopPreheader(
27     uint32_t loop_header_block, uint32_t fresh_id,
28     std::vector<uint32_t> phi_id) {
29   message_.set_loop_header_block(loop_header_block);
30   message_.set_fresh_id(fresh_id);
31   for (auto id : phi_id) {
32     message_.add_phi_id(id);
33   }
34 }
35 
IsApplicable(opt::IRContext * ir_context,const TransformationContext &) const36 bool TransformationAddLoopPreheader::IsApplicable(
37     opt::IRContext* ir_context,
38     const TransformationContext& /* unused */) const {
39   // |message_.loop_header_block()| must be the id of a loop header block.
40   opt::BasicBlock* loop_header_block =
41       fuzzerutil::MaybeFindBlock(ir_context, message_.loop_header_block());
42   if (!loop_header_block || !loop_header_block->IsLoopHeader()) {
43     return false;
44   }
45 
46   // The id for the preheader must actually be fresh.
47   std::set<uint32_t> used_ids;
48   if (!CheckIdIsFreshAndNotUsedByThisTransformation(message_.fresh_id(),
49                                                     ir_context, &used_ids)) {
50     return false;
51   }
52 
53   size_t num_predecessors =
54       ir_context->cfg()->preds(message_.loop_header_block()).size();
55 
56   // The block must have at least 2 predecessors (the back-edge block and
57   // another predecessor outside of the loop)
58   if (num_predecessors < 2) {
59     return false;
60   }
61 
62   // If the block only has one predecessor outside of the loop (and thus 2 in
63   // total), then no additional fresh ids are necessary.
64   if (num_predecessors == 2) {
65     return true;
66   }
67 
68   // Count the number of OpPhi instructions.
69   int32_t num_phi_insts = 0;
70   loop_header_block->ForEachPhiInst(
71       [&num_phi_insts](opt::Instruction* /* unused */) { num_phi_insts++; });
72 
73   // There must be enough fresh ids for the OpPhi instructions.
74   if (num_phi_insts > message_.phi_id_size()) {
75     return false;
76   }
77 
78   // Check that the needed ids are fresh and distinct.
79   for (int32_t i = 0; i < num_phi_insts; i++) {
80     if (!CheckIdIsFreshAndNotUsedByThisTransformation(message_.phi_id(i),
81                                                       ir_context, &used_ids)) {
82       return false;
83     }
84   }
85 
86   return true;
87 }
88 
Apply(opt::IRContext * ir_context,TransformationContext *) const89 void TransformationAddLoopPreheader::Apply(
90     opt::IRContext* ir_context,
91     TransformationContext* /* transformation_context */) const {
92   // Find the loop header.
93   opt::BasicBlock* loop_header =
94       fuzzerutil::MaybeFindBlock(ir_context, message_.loop_header_block());
95 
96   auto dominator_analysis =
97       ir_context->GetDominatorAnalysis(loop_header->GetParent());
98 
99   uint32_t back_edge_block_id = 0;
100 
101   // Update the branching instructions of the out-of-loop predecessors of the
102   // header. Set |back_edge_block_id| to be the id of the back-edge block.
103   ir_context->get_def_use_mgr()->ForEachUse(
104       loop_header->id(),
105       [this, &ir_context, &dominator_analysis, &loop_header,
106        &back_edge_block_id](opt::Instruction* use_inst, uint32_t use_index) {
107         if (dominator_analysis->Dominates(loop_header->GetLabelInst(),
108                                           use_inst)) {
109           // If |use_inst| is a branch instruction dominated by the header, the
110           // block containing it is the back-edge block.
111           if (use_inst->IsBranch()) {
112             assert(back_edge_block_id == 0 &&
113                    "There should only be one back-edge block");
114             back_edge_block_id = ir_context->get_instr_block(use_inst)->id();
115           }
116           // References to the header inside the loop should not be updated
117           return;
118         }
119 
120         // If |use_inst| is not a branch or merge instruction, it should not be
121         // changed.
122         if (!use_inst->IsBranch() &&
123             use_inst->opcode() != SpvOpSelectionMerge &&
124             use_inst->opcode() != SpvOpLoopMerge) {
125           return;
126         }
127 
128         // Update the reference.
129         use_inst->SetOperand(use_index, {message_.fresh_id()});
130       });
131 
132   assert(back_edge_block_id && "The back-edge block should have been found");
133 
134   // Make a new block for the preheader.
135   std::unique_ptr<opt::BasicBlock> preheader = MakeUnique<opt::BasicBlock>(
136       std::unique_ptr<opt::Instruction>(new opt::Instruction(
137           ir_context, SpvOpLabel, 0, message_.fresh_id(), {})));
138 
139   uint32_t phi_ids_used = 0;
140 
141   // Update the OpPhi instructions and, if there is more than one out-of-loop
142   // predecessor, add necessary OpPhi instructions so the preheader.
143   loop_header->ForEachPhiInst([this, &ir_context, &preheader,
144                                &back_edge_block_id,
145                                &phi_ids_used](opt::Instruction* phi_inst) {
146     // The loop header must have at least 2 incoming edges (the back edge, and
147     // at least one from outside the loop).
148     assert(phi_inst->NumInOperands() >= 4);
149 
150     if (phi_inst->NumInOperands() == 4) {
151       // There is just one out-of-loop predecessor, so no additional
152       // instructions in the preheader are necessary. The reference to the
153       // original out-of-loop predecessor needs to be updated so that it refers
154       // to the preheader.
155       uint32_t index_of_out_of_loop_pred_id =
156           phi_inst->GetInOperand(1).words[0] == back_edge_block_id ? 3 : 1;
157       phi_inst->SetInOperand(index_of_out_of_loop_pred_id, {preheader->id()});
158     } else {
159       // There is more than one out-of-loop predecessor, so an OpPhi instruction
160       // needs to be added to the preheader, and its value will depend on all
161       // the current out-of-loop predecessors of the header.
162 
163       // Get the operand list and the value corresponding to the back-edge
164       // block.
165       std::vector<opt::Operand> preheader_in_operands;
166       uint32_t back_edge_val = 0;
167 
168       for (uint32_t i = 0; i < phi_inst->NumInOperands(); i += 2) {
169         // Only add operands if they don't refer to the back-edge block.
170         if (phi_inst->GetInOperand(i + 1).words[0] == back_edge_block_id) {
171           back_edge_val = phi_inst->GetInOperand(i).words[0];
172         } else {
173           preheader_in_operands.push_back(std::move(phi_inst->GetInOperand(i)));
174           preheader_in_operands.push_back(
175               std::move(phi_inst->GetInOperand(i + 1)));
176         }
177       }
178 
179       // Add the new instruction to the preheader.
180       uint32_t fresh_phi_id = message_.phi_id(phi_ids_used++);
181 
182       // Update id bound.
183       fuzzerutil::UpdateModuleIdBound(ir_context, fresh_phi_id);
184 
185       preheader->AddInstruction(std::unique_ptr<opt::Instruction>(
186           new opt::Instruction(ir_context, SpvOpPhi, phi_inst->type_id(),
187                                fresh_phi_id, preheader_in_operands)));
188 
189       // Update the OpPhi instruction in the header so that it refers to the
190       // back edge block and the preheader as the predecessors, and it uses the
191       // newly-defined OpPhi in the preheader for the corresponding value.
192       phi_inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {fresh_phi_id}},
193                                {SPV_OPERAND_TYPE_ID, {preheader->id()}},
194                                {SPV_OPERAND_TYPE_ID, {back_edge_val}},
195                                {SPV_OPERAND_TYPE_ID, {back_edge_block_id}}});
196     }
197   });
198 
199   // Update id bound.
200   fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id());
201 
202   // Add an unconditional branch from the preheader to the header.
203   preheader->AddInstruction(
204       std::unique_ptr<opt::Instruction>(new opt::Instruction(
205           ir_context, SpvOpBranch, 0, 0,
206           std::initializer_list<opt::Operand>{opt::Operand(
207               spv_operand_type_t::SPV_OPERAND_TYPE_ID, {loop_header->id()})})));
208 
209   // Insert the preheader in the module.
210   loop_header->GetParent()->InsertBasicBlockBefore(std::move(preheader),
211                                                    loop_header);
212 
213   // Invalidate analyses because the structure of the program changed.
214   ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
215 }
216 
ToMessage() const217 protobufs::Transformation TransformationAddLoopPreheader::ToMessage() const {
218   protobufs::Transformation result;
219   *result.mutable_add_loop_preheader() = message_;
220   return result;
221 }
222 
GetFreshIds() const223 std::unordered_set<uint32_t> TransformationAddLoopPreheader::GetFreshIds()
224     const {
225   std::unordered_set<uint32_t> result = {message_.fresh_id()};
226   for (auto id : message_.phi_id()) {
227     result.insert(id);
228   }
229   return result;
230 }
231 
232 }  // namespace fuzz
233 }  // namespace spvtools
234