1 // Copyright (c) 2020 Vasyl Teliman
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/fuzz/transformation_propagate_instruction_up.h"
16 
17 #include "source/fuzz/fuzzer_util.h"
18 #include "source/fuzz/instruction_descriptor.h"
19 
20 namespace spvtools {
21 namespace fuzz {
22 namespace {
23 
GetResultIdFromLabelId(const opt::Instruction & phi_inst,uint32_t label_id)24 uint32_t GetResultIdFromLabelId(const opt::Instruction& phi_inst,
25                                 uint32_t label_id) {
26   assert(phi_inst.opcode() == SpvOpPhi && "|phi_inst| is not an OpPhi");
27 
28   for (uint32_t i = 1; i < phi_inst.NumInOperands(); i += 2) {
29     if (phi_inst.GetSingleWordInOperand(i) == label_id) {
30       return phi_inst.GetSingleWordInOperand(i - 1);
31     }
32   }
33 
34   return 0;
35 }
36 
ContainsPointers(const opt::analysis::Type & type)37 bool ContainsPointers(const opt::analysis::Type& type) {
38   switch (type.kind()) {
39     case opt::analysis::Type::kPointer:
40       return true;
41     case opt::analysis::Type::kStruct:
42       return std::any_of(type.AsStruct()->element_types().begin(),
43                          type.AsStruct()->element_types().end(),
44                          [](const opt::analysis::Type* element_type) {
45                            return ContainsPointers(*element_type);
46                          });
47     default:
48       return false;
49   }
50 }
51 
HasValidDependencies(opt::IRContext * ir_context,opt::Instruction * inst)52 bool HasValidDependencies(opt::IRContext* ir_context, opt::Instruction* inst) {
53   const auto* inst_block = ir_context->get_instr_block(inst);
54   assert(inst_block &&
55          "This function shouldn't be applied to global instructions or function"
56          "parameters");
57 
58   for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
59     const auto& operand = inst->GetInOperand(i);
60     if (operand.type != SPV_OPERAND_TYPE_ID) {
61       // Consider only <id> operands.
62       continue;
63     }
64 
65     auto* dependency = ir_context->get_def_use_mgr()->GetDef(operand.words[0]);
66     assert(dependency && "Operand has invalid id");
67 
68     if (ir_context->get_instr_block(dependency) == inst_block &&
69         dependency->opcode() != SpvOpPhi) {
70       // |dependency| is "valid" if it's an OpPhi from the same basic block or
71       // an instruction from a different basic block.
72       return false;
73     }
74   }
75 
76   return true;
77 }
78 
79 }  // namespace
80 
TransformationPropagateInstructionUp(protobufs::TransformationPropagateInstructionUp message)81 TransformationPropagateInstructionUp::TransformationPropagateInstructionUp(
82     protobufs::TransformationPropagateInstructionUp message)
83     : message_(std::move(message)) {}
84 
TransformationPropagateInstructionUp(uint32_t block_id,const std::map<uint32_t,uint32_t> & predecessor_id_to_fresh_id)85 TransformationPropagateInstructionUp::TransformationPropagateInstructionUp(
86     uint32_t block_id,
87     const std::map<uint32_t, uint32_t>& predecessor_id_to_fresh_id) {
88   message_.set_block_id(block_id);
89   *message_.mutable_predecessor_id_to_fresh_id() =
90       fuzzerutil::MapToRepeatedUInt32Pair(predecessor_id_to_fresh_id);
91 }
92 
IsApplicable(opt::IRContext * ir_context,const TransformationContext &) const93 bool TransformationPropagateInstructionUp::IsApplicable(
94     opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
95   // Check that we can apply this transformation to the |block_id|.
96   if (!IsApplicableToBlock(ir_context, message_.block_id())) {
97     return false;
98   }
99 
100   const auto predecessor_id_to_fresh_id = fuzzerutil::RepeatedUInt32PairToMap(
101       message_.predecessor_id_to_fresh_id());
102   for (auto id : ir_context->cfg()->preds(message_.block_id())) {
103     // Each predecessor must have a fresh id in the |predecessor_id_to_fresh_id|
104     // map.
105     if (!predecessor_id_to_fresh_id.count(id)) {
106       return false;
107     }
108   }
109 
110   std::vector<uint32_t> maybe_fresh_ids;
111   maybe_fresh_ids.reserve(predecessor_id_to_fresh_id.size());
112   for (const auto& entry : predecessor_id_to_fresh_id) {
113     maybe_fresh_ids.push_back(entry.second);
114   }
115 
116   // All ids must be unique and fresh.
117   return !fuzzerutil::HasDuplicates(maybe_fresh_ids) &&
118          std::all_of(maybe_fresh_ids.begin(), maybe_fresh_ids.end(),
119                      [ir_context](uint32_t id) {
120                        return fuzzerutil::IsFreshId(ir_context, id);
121                      });
122 }
123 
Apply(opt::IRContext * ir_context,TransformationContext *) const124 void TransformationPropagateInstructionUp::Apply(
125     opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
126   auto* inst = GetInstructionToPropagate(ir_context, message_.block_id());
127   assert(inst &&
128          "The block must have at least one supported instruction to propagate");
129   assert(inst->result_id() && inst->type_id() &&
130          "|inst| must have a result id and a type id");
131 
132   opt::Instruction::OperandList op_phi_operands;
133   const auto predecessor_id_to_fresh_id = fuzzerutil::RepeatedUInt32PairToMap(
134       message_.predecessor_id_to_fresh_id());
135   std::unordered_set<uint32_t> visited_predecessors;
136   for (auto predecessor_id : ir_context->cfg()->preds(message_.block_id())) {
137     // A block can have multiple identical predecessors.
138     if (visited_predecessors.count(predecessor_id)) {
139       continue;
140     }
141 
142     visited_predecessors.insert(predecessor_id);
143 
144     auto new_result_id = predecessor_id_to_fresh_id.at(predecessor_id);
145 
146     // Compute InOperands for the OpPhi instruction to be inserted later.
147     op_phi_operands.push_back({SPV_OPERAND_TYPE_ID, {new_result_id}});
148     op_phi_operands.push_back({SPV_OPERAND_TYPE_ID, {predecessor_id}});
149 
150     // Create a clone of the |inst| to be inserted into the |predecessor_id|.
151     std::unique_ptr<opt::Instruction> clone(inst->Clone(ir_context));
152     clone->SetResultId(new_result_id);
153 
154     fuzzerutil::UpdateModuleIdBound(ir_context, new_result_id);
155 
156     // Adjust |clone|'s operands to account for possible dependencies on OpPhi
157     // instructions from the same basic block.
158     for (uint32_t i = 0; i < clone->NumInOperands(); ++i) {
159       auto& operand = clone->GetInOperand(i);
160       if (operand.type != SPV_OPERAND_TYPE_ID) {
161         // Consider only ids.
162         continue;
163       }
164 
165       const auto* dependency_inst =
166           ir_context->get_def_use_mgr()->GetDef(operand.words[0]);
167       assert(dependency_inst && "|clone| depends on an invalid id");
168 
169       if (ir_context->get_instr_block(dependency_inst->result_id()) !=
170           ir_context->cfg()->block(message_.block_id())) {
171         // We don't need to adjust anything if |dependency_inst| is from a
172         // different block, a global instruction or a function parameter.
173         continue;
174       }
175 
176       assert(dependency_inst->opcode() == SpvOpPhi &&
177              "Propagated instruction can depend only on OpPhis from the same "
178              "basic block or instructions from different basic blocks");
179 
180       auto new_id = GetResultIdFromLabelId(*dependency_inst, predecessor_id);
181       assert(new_id && "OpPhi instruction is missing a predecessor");
182       operand.words[0] = new_id;
183     }
184 
185     auto* insert_before_inst = fuzzerutil::GetLastInsertBeforeInstruction(
186         ir_context, predecessor_id, clone->opcode());
187     assert(insert_before_inst && "Can't insert |clone| into |predecessor_id");
188 
189     insert_before_inst->InsertBefore(std::move(clone));
190   }
191 
192   // Insert an OpPhi instruction into the basic block of |inst|.
193   ir_context->get_instr_block(inst)->begin()->InsertBefore(
194       MakeUnique<opt::Instruction>(ir_context, SpvOpPhi, inst->type_id(),
195                                    inst->result_id(),
196                                    std::move(op_phi_operands)));
197 
198   // Remove |inst| from the basic block.
199   ir_context->KillInst(inst);
200 
201   // We have changed the module so most analyzes are now invalid.
202   ir_context->InvalidateAnalysesExceptFor(
203       opt::IRContext::Analysis::kAnalysisNone);
204 }
205 
ToMessage() const206 protobufs::Transformation TransformationPropagateInstructionUp::ToMessage()
207     const {
208   protobufs::Transformation result;
209   *result.mutable_propagate_instruction_up() = message_;
210   return result;
211 }
212 
// Returns true if |opcode| is on the allow-list of opcodes this
// transformation may move into the predecessors of a block. Per the TODO
// below, only "simple" instructions that don't work with memory are listed.
// This is an opcode-level check only: whether a particular instruction can be
// propagated also depends on its result type and operands (see
// GetInstructionToPropagate).
bool TransformationPropagateInstructionUp::IsOpcodeSupported(SpvOp opcode) {
  // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3605):
  //  We only support "simple" instructions that don't work with memory.
  //  We should extend this so that we support the ones that modify the memory
  //  too.
  switch (opcode) {
    // Undefined values, access chains and array length queries.
    case SpvOpUndef:
    case SpvOpAccessChain:
    case SpvOpInBoundsAccessChain:
    case SpvOpArrayLength:
    // Vector/composite manipulation and copies.
    case SpvOpVectorExtractDynamic:
    case SpvOpVectorInsertDynamic:
    case SpvOpVectorShuffle:
    case SpvOpCompositeConstruct:
    case SpvOpCompositeExtract:
    case SpvOpCompositeInsert:
    case SpvOpCopyObject:
    case SpvOpTranspose:
    // Conversions.
    case SpvOpConvertFToU:
    case SpvOpConvertFToS:
    case SpvOpConvertSToF:
    case SpvOpConvertUToF:
    case SpvOpUConvert:
    case SpvOpSConvert:
    case SpvOpFConvert:
    case SpvOpQuantizeToF16:
    case SpvOpSatConvertSToU:
    case SpvOpSatConvertUToS:
    case SpvOpBitcast:
    // Arithmetic.
    case SpvOpSNegate:
    case SpvOpFNegate:
    case SpvOpIAdd:
    case SpvOpFAdd:
    case SpvOpISub:
    case SpvOpFSub:
    case SpvOpIMul:
    case SpvOpFMul:
    case SpvOpUDiv:
    case SpvOpSDiv:
    case SpvOpFDiv:
    case SpvOpUMod:
    case SpvOpSRem:
    case SpvOpSMod:
    case SpvOpFRem:
    case SpvOpFMod:
    case SpvOpVectorTimesScalar:
    case SpvOpMatrixTimesScalar:
    case SpvOpVectorTimesMatrix:
    case SpvOpMatrixTimesVector:
    case SpvOpMatrixTimesMatrix:
    case SpvOpOuterProduct:
    case SpvOpDot:
    case SpvOpIAddCarry:
    case SpvOpISubBorrow:
    case SpvOpUMulExtended:
    case SpvOpSMulExtended:
    // Relational and logical operations.
    case SpvOpAny:
    case SpvOpAll:
    case SpvOpIsNan:
    case SpvOpIsInf:
    case SpvOpIsFinite:
    case SpvOpIsNormal:
    case SpvOpSignBitSet:
    case SpvOpLessOrGreater:
    case SpvOpOrdered:
    case SpvOpUnordered:
    case SpvOpLogicalEqual:
    case SpvOpLogicalNotEqual:
    case SpvOpLogicalOr:
    case SpvOpLogicalAnd:
    case SpvOpLogicalNot:
    case SpvOpSelect:
    case SpvOpIEqual:
    case SpvOpINotEqual:
    case SpvOpUGreaterThan:
    case SpvOpSGreaterThan:
    case SpvOpUGreaterThanEqual:
    case SpvOpSGreaterThanEqual:
    case SpvOpULessThan:
    case SpvOpSLessThan:
    case SpvOpULessThanEqual:
    case SpvOpSLessThanEqual:
    case SpvOpFOrdEqual:
    case SpvOpFUnordEqual:
    case SpvOpFOrdNotEqual:
    case SpvOpFUnordNotEqual:
    case SpvOpFOrdLessThan:
    case SpvOpFUnordLessThan:
    case SpvOpFOrdGreaterThan:
    case SpvOpFUnordGreaterThan:
    case SpvOpFOrdLessThanEqual:
    case SpvOpFUnordLessThanEqual:
    case SpvOpFOrdGreaterThanEqual:
    case SpvOpFUnordGreaterThanEqual:
    // Bit manipulation.
    case SpvOpShiftRightLogical:
    case SpvOpShiftRightArithmetic:
    case SpvOpShiftLeftLogical:
    case SpvOpBitwiseOr:
    case SpvOpBitwiseXor:
    case SpvOpBitwiseAnd:
    case SpvOpNot:
    case SpvOpBitFieldInsert:
    case SpvOpBitFieldSExtract:
    case SpvOpBitFieldUExtract:
    case SpvOpBitReverse:
    case SpvOpBitCount:
    // Logical copies and pointer comparisons.
    case SpvOpCopyLogical:
    case SpvOpPtrEqual:
    case SpvOpPtrNotEqual:
      return true;
    default:
      return false;
  }
}
327 
328 opt::Instruction*
GetInstructionToPropagate(opt::IRContext * ir_context,uint32_t block_id)329 TransformationPropagateInstructionUp::GetInstructionToPropagate(
330     opt::IRContext* ir_context, uint32_t block_id) {
331   auto* block = ir_context->cfg()->block(block_id);
332   assert(block && "|block_id| is invalid");
333 
334   for (auto& inst : *block) {
335     // We look for the first instruction in the block that satisfies the
336     // following rules:
337     // - it's not an OpPhi
338     // - it must be supported by this transformation
339     // - it may depend only on instructions from different basic blocks or on
340     //   OpPhi instructions from the same basic block.
341     if (inst.opcode() == SpvOpPhi || !IsOpcodeSupported(inst.opcode()) ||
342         !inst.type_id() || !inst.result_id()) {
343       continue;
344     }
345 
346     const auto* inst_type = ir_context->get_type_mgr()->GetType(inst.type_id());
347     assert(inst_type && "|inst| has invalid type");
348 
349     if (inst_type->AsSampledImage()) {
350       // OpTypeSampledImage cannot be used as an argument to OpPhi instructions,
351       // thus we cannot support this type.
352       continue;
353     }
354 
355     if (!ir_context->get_feature_mgr()->HasCapability(
356             SpvCapabilityVariablePointersStorageBuffer) &&
357         ContainsPointers(*inst_type)) {
358       // OpPhi supports pointer operands only with VariablePointers or
359       // VariablePointersStorageBuffer capabilities.
360       //
361       // Note that VariablePointers capability implicitly declares
362       // VariablePointersStorageBuffer capability.
363       continue;
364     }
365 
366     if (!HasValidDependencies(ir_context, &inst)) {
367       continue;
368     }
369 
370     return &inst;
371   }
372 
373   return nullptr;
374 }
375 
IsApplicableToBlock(opt::IRContext * ir_context,uint32_t block_id)376 bool TransformationPropagateInstructionUp::IsApplicableToBlock(
377     opt::IRContext* ir_context, uint32_t block_id) {
378   // Check that |block_id| is valid.
379   const auto* label_inst = ir_context->get_def_use_mgr()->GetDef(block_id);
380   if (!label_inst || label_inst->opcode() != SpvOpLabel) {
381     return false;
382   }
383 
384   // Check that |block| has predecessors.
385   const auto& predecessors = ir_context->cfg()->preds(block_id);
386   if (predecessors.empty()) {
387     return false;
388   }
389 
390   // The block must contain an instruction to propagate.
391   const auto* inst_to_propagate =
392       GetInstructionToPropagate(ir_context, block_id);
393   if (!inst_to_propagate) {
394     return false;
395   }
396 
397   // We should be able to insert |inst_to_propagate| into every predecessor of
398   // |block|.
399   return std::all_of(predecessors.begin(), predecessors.end(),
400                      [ir_context, inst_to_propagate](uint32_t predecessor_id) {
401                        return fuzzerutil::GetLastInsertBeforeInstruction(
402                                   ir_context, predecessor_id,
403                                   inst_to_propagate->opcode()) != nullptr;
404                      });
405 }
406 
GetFreshIds() const407 std::unordered_set<uint32_t> TransformationPropagateInstructionUp::GetFreshIds()
408     const {
409   std::unordered_set<uint32_t> result;
410   for (auto& pair : message_.predecessor_id_to_fresh_id()) {
411     result.insert(pair.second());
412   }
413   return result;
414 }
415 
416 }  // namespace fuzz
417 }  // namespace spvtools
418