1 // Copyright (c) 2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/fuzz/fuzzer_pass_add_equation_instructions.h"
16 
17 #include <vector>
18 
19 #include "source/fuzz/fuzzer_util.h"
20 #include "source/fuzz/transformation_equation_instruction.h"
21 
22 namespace spvtools {
23 namespace fuzz {
24 namespace {
25 
IsBitWidthSupported(opt::IRContext * ir_context,uint32_t bit_width)26 bool IsBitWidthSupported(opt::IRContext* ir_context, uint32_t bit_width) {
27   switch (bit_width) {
28     case 32:
29       return true;
30     case 64:
31       return ir_context->get_feature_mgr()->HasCapability(
32                  SpvCapabilityFloat64) &&
33              ir_context->get_feature_mgr()->HasCapability(SpvCapabilityInt64);
34     case 16:
35       return ir_context->get_feature_mgr()->HasCapability(
36                  SpvCapabilityFloat16) &&
37              ir_context->get_feature_mgr()->HasCapability(SpvCapabilityInt16);
38     default:
39       return false;
40   }
41 }
42 
43 }  // namespace
44 
FuzzerPassAddEquationInstructions(opt::IRContext * ir_context,TransformationContext * transformation_context,FuzzerContext * fuzzer_context,protobufs::TransformationSequence * transformations,bool ignore_inapplicable_transformations)45 FuzzerPassAddEquationInstructions::FuzzerPassAddEquationInstructions(
46     opt::IRContext* ir_context, TransformationContext* transformation_context,
47     FuzzerContext* fuzzer_context,
48     protobufs::TransformationSequence* transformations,
49     bool ignore_inapplicable_transformations)
50     : FuzzerPass(ir_context, transformation_context, fuzzer_context,
51                  transformations, ignore_inapplicable_transformations) {}
52 
// Visits every instruction of the module (via
// ForEachInstructionWithInstructionDescriptor). With probability
// GetChanceOfAddingEquationInstruction(), it tries to insert, immediately
// before that instruction, one fresh "equation" instruction. The candidate
// opcodes are OpIAdd, OpISub, OpLogicalNot, OpSNegate, OpConvertUToF,
// OpConvertSToF and OpBitcast. Operands are drawn from the instructions
// available at the insertion point.
//
// Candidate opcodes are tried in random order until one has suitable
// operands. If none does, nothing is added at that point. Conversions and
// bitcasts may first add a missing result type (float/int/vector) to the
// module.
//
// NOTE(review): the order of calls to the fuzzer context (ChoosePercentage,
// RemoveAtRandomIndex, RandomIndex, ChooseEven, GetFreshId) presumably drives
// a shared seeded random source. If so, reordering statements in this
// function would change the pass's output for a given seed.
void FuzzerPassAddEquationInstructions::Apply() {
  ForEachInstructionWithInstructionDescriptor(
      [this](opt::Function* function, opt::BasicBlock* block,
             opt::BasicBlock::iterator inst_it,
             const protobufs::InstructionDescriptor& instruction_descriptor) {
        if (!GetFuzzerContext()->ChoosePercentage(
                GetFuzzerContext()->GetChanceOfAddingEquationInstruction())) {
          return;
        }

        // Check that it is OK to add an equation instruction before the given
        // instruction in principle - e.g. check that this does not lead to
        // inserting before an OpVariable or OpPhi instruction.  We use OpIAdd
        // as an example opcode for this check, to be representative of *some*
        // opcode that defines an equation, even though we may choose a
        // different opcode below.
        if (!fuzzerutil::CanInsertOpcodeBeforeInstruction(SpvOpIAdd, inst_it)) {
          return;
        }

        // Get all available instructions that have both a result id and a
        // result type, are not OpUndef, and whose result id is not marked
        // irrelevant by the fact manager. Only these may serve as operands.
        std::vector<opt::Instruction*> available_instructions =
            FindAvailableInstructions(
                function, block, inst_it,
                [this](opt::IRContext* /*unused*/,
                       opt::Instruction* instruction) -> bool {
                  return instruction->result_id() && instruction->type_id() &&
                         instruction->opcode() != SpvOpUndef &&
                         !GetTransformationContext()
                              ->GetFactManager()
                              ->IdIsIrrelevant(instruction->result_id());
                });

        // Try the opcodes for which we know how to make ids at random until
        // something works. Each iteration removes one opcode from the list,
        // so the loop runs at most once per opcode.
        std::vector<SpvOp> candidate_opcodes = {
            SpvOpIAdd,        SpvOpISub,        SpvOpLogicalNot, SpvOpSNegate,
            SpvOpConvertUToF, SpvOpConvertSToF, SpvOpBitcast};
        do {
          auto opcode =
              GetFuzzerContext()->RemoveAtRandomIndex(&candidate_opcodes);
          switch (opcode) {
            case SpvOpConvertSToF:
            case SpvOpConvertUToF: {
              // Eligible operands are integer scalars/vectors whose component
              // width passes IsBitWidthSupported. The result needs a float
              // type of that same width, which may be created below.
              std::vector<const opt::Instruction*> candidate_instructions;
              for (const auto* inst :
                   GetIntegerInstructions(available_instructions)) {
                const auto* type =
                    GetIRContext()->get_type_mgr()->GetType(inst->type_id());
                assert(type && "|inst| has invalid type");

                if (const auto* vector_type = type->AsVector()) {
                  type = vector_type->element_type();
                }

                if (IsBitWidthSupported(GetIRContext(),
                                        type->AsInteger()->width())) {
                  candidate_instructions.push_back(inst);
                }
              }

              if (candidate_instructions.empty()) {
                // No usable operand: leave the switch and try another opcode.
                break;
              }

              const auto* operand =
                  candidate_instructions[GetFuzzerContext()->RandomIndex(
                      candidate_instructions)];

              const auto* type =
                  GetIRContext()->get_type_mgr()->GetType(operand->type_id());
              assert(type && "Operand has invalid type");

              // Make sure a result type exists in the module.
              if (const auto* vector = type->AsVector()) {
                // We store element count in a separate variable since the
                // call FindOrCreate* functions below might invalidate
                // |vector| pointer.
                // (The order in which function arguments are evaluated is
                // unspecified. Reading vector->element_count() directly as the
                // second argument could therefore happen after
                // FindOrCreateFloatType has run.)
                const auto element_count = vector->element_count();

                FindOrCreateVectorType(
                    FindOrCreateFloatType(
                        vector->element_type()->AsInteger()->width()),
                    element_count);
              } else {
                FindOrCreateFloatType(type->AsInteger()->width());
              }

              ApplyTransformation(TransformationEquationInstruction(
                  GetFuzzerContext()->GetFreshId(), opcode,
                  {operand->result_id()}, instruction_descriptor));
              return;
            }
            case SpvOpBitcast: {
              // GetNumericalInstructions already restricts candidates to
              // integer/float scalars or vectors with a supported bit width.
              const auto candidate_instructions =
                  GetNumericalInstructions(available_instructions);

              if (!candidate_instructions.empty()) {
                const auto* operand_inst =
                    candidate_instructions[GetFuzzerContext()->RandomIndex(
                        candidate_instructions)];
                const auto* operand_type =
                    GetIRContext()->get_type_mgr()->GetType(
                        operand_inst->type_id());
                assert(operand_type && "Operand instruction has invalid type");

                // Make sure a result type exists in the module.
                //
                // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3539):
                //  The only constraint on the types of OpBitcast's parameters
                //  is that they must have the same number of bits. Consider
                //  improving the code below to support this in full.
                //
                // Integer components map to a float of the same width. Float
                // components map to an integer of the same width whose
                // signedness is chosen at random.
                if (const auto* vector = operand_type->AsVector()) {
                  // We store element count in a separate variable since the
                  // call FindOrCreate* functions below might invalidate
                  // |vector| pointer.
                  const auto element_count = vector->element_count();

                  uint32_t element_type_id;
                  if (const auto* int_type =
                          vector->element_type()->AsInteger()) {
                    element_type_id = FindOrCreateFloatType(int_type->width());
                  } else {
                    assert(vector->element_type()->AsFloat() &&
                           "Vector must have numerical elements");
                    element_type_id = FindOrCreateIntegerType(
                        vector->element_type()->AsFloat()->width(),
                        GetFuzzerContext()->ChooseEven());
                  }

                  FindOrCreateVectorType(element_type_id, element_count);
                } else if (const auto* int_type = operand_type->AsInteger()) {
                  FindOrCreateFloatType(int_type->width());
                } else {
                  assert(operand_type->AsFloat() &&
                         "Operand is not a scalar of numerical type");
                  FindOrCreateIntegerType(operand_type->AsFloat()->width(),
                                          GetFuzzerContext()->ChooseEven());
                }

                ApplyTransformation(TransformationEquationInstruction(
                    GetFuzzerContext()->GetFreshId(), opcode,
                    {operand_inst->result_id()}, instruction_descriptor));
                return;
              }
              // No numerical operand available: try another opcode.
            } break;
            case SpvOpIAdd:
            case SpvOpISub: {
              // Instructions of integer (scalar or vector) result type are
              // suitable for these opcodes.
              auto integer_instructions =
                  GetIntegerInstructions(available_instructions);
              if (!integer_instructions.empty()) {
                // There is at least one such instruction, so pick one at random
                // for the LHS of an equation.
                auto lhs = integer_instructions.at(
                    GetFuzzerContext()->RandomIndex(integer_instructions));

                // For the RHS, we can use any instruction with an integer
                // scalar/vector result type of the same number of components
                // and the same bit-width for the underlying integer type.

                // Work out the element count and bit-width.
                auto lhs_type =
                    GetIRContext()->get_type_mgr()->GetType(lhs->type_id());
                uint32_t lhs_element_count;
                uint32_t lhs_bit_width;
                if (lhs_type->AsVector()) {
                  lhs_element_count = lhs_type->AsVector()->element_count();
                  lhs_bit_width = lhs_type->AsVector()
                                      ->element_type()
                                      ->AsInteger()
                                      ->width();
                } else {
                  lhs_element_count = 1;
                  lhs_bit_width = lhs_type->AsInteger()->width();
                }

                // Get all the instructions that match on element count and
                // bit-width.
                auto candidate_rhs_instructions = RestrictToElementBitWidth(
                    RestrictToVectorWidth(integer_instructions,
                                          lhs_element_count),
                    lhs_bit_width);

                // Choose a RHS instruction at random; there is guaranteed to
                // be at least one choice as the LHS will be available.
                auto rhs = candidate_rhs_instructions.at(
                    GetFuzzerContext()->RandomIndex(
                        candidate_rhs_instructions));

                // Add the equation instruction.
                ApplyTransformation(TransformationEquationInstruction(
                    GetFuzzerContext()->GetFreshId(), opcode,
                    {lhs->result_id(), rhs->result_id()},
                    instruction_descriptor));
                return;
              }
              break;
            }
            case SpvOpLogicalNot: {
              // Choose any available instruction of boolean scalar/vector
              // result type and equate its negation with a fresh id.
              auto boolean_instructions =
                  GetBooleanInstructions(available_instructions);
              if (!boolean_instructions.empty()) {
                ApplyTransformation(TransformationEquationInstruction(
                    GetFuzzerContext()->GetFreshId(), opcode,
                    {boolean_instructions
                         .at(GetFuzzerContext()->RandomIndex(
                             boolean_instructions))
                         ->result_id()},
                    instruction_descriptor));
                return;
              }
              break;
            }
            case SpvOpSNegate: {
              // Similar to OpLogicalNot, but for signed integer negation.
              // Any integer scalar/vector operand is accepted here.
              auto integer_instructions =
                  GetIntegerInstructions(available_instructions);
              if (!integer_instructions.empty()) {
                ApplyTransformation(TransformationEquationInstruction(
                    GetFuzzerContext()->GetFreshId(), opcode,
                    {integer_instructions
                         .at(GetFuzzerContext()->RandomIndex(
                             integer_instructions))
                         ->result_id()},
                    instruction_descriptor));
                return;
              }
              break;
            }
            default:
              // Unreachable: |candidate_opcodes| only holds opcodes handled
              // by the cases above.
              assert(false && "Unexpected opcode.");
              break;
          }
        } while (!candidate_opcodes.empty());
        // Reaching here means that we did not manage to apply any
        // transformation at this point of the module.
      });
}
296 
297 std::vector<opt::Instruction*>
GetIntegerInstructions(const std::vector<opt::Instruction * > & instructions) const298 FuzzerPassAddEquationInstructions::GetIntegerInstructions(
299     const std::vector<opt::Instruction*>& instructions) const {
300   std::vector<opt::Instruction*> result;
301   for (auto& inst : instructions) {
302     auto type = GetIRContext()->get_type_mgr()->GetType(inst->type_id());
303     if (type->AsInteger() ||
304         (type->AsVector() && type->AsVector()->element_type()->AsInteger())) {
305       result.push_back(inst);
306     }
307   }
308   return result;
309 }
310 
311 std::vector<opt::Instruction*>
GetFloatInstructions(const std::vector<opt::Instruction * > & instructions) const312 FuzzerPassAddEquationInstructions::GetFloatInstructions(
313     const std::vector<opt::Instruction*>& instructions) const {
314   std::vector<opt::Instruction*> result;
315   for (auto& inst : instructions) {
316     auto type = GetIRContext()->get_type_mgr()->GetType(inst->type_id());
317     if (type->AsFloat() ||
318         (type->AsVector() && type->AsVector()->element_type()->AsFloat())) {
319       result.push_back(inst);
320     }
321   }
322   return result;
323 }
324 
325 std::vector<opt::Instruction*>
GetBooleanInstructions(const std::vector<opt::Instruction * > & instructions) const326 FuzzerPassAddEquationInstructions::GetBooleanInstructions(
327     const std::vector<opt::Instruction*>& instructions) const {
328   std::vector<opt::Instruction*> result;
329   for (auto& inst : instructions) {
330     auto type = GetIRContext()->get_type_mgr()->GetType(inst->type_id());
331     if (type->AsBool() ||
332         (type->AsVector() && type->AsVector()->element_type()->AsBool())) {
333       result.push_back(inst);
334     }
335   }
336   return result;
337 }
338 
339 std::vector<opt::Instruction*>
RestrictToVectorWidth(const std::vector<opt::Instruction * > & instructions,uint32_t vector_width) const340 FuzzerPassAddEquationInstructions::RestrictToVectorWidth(
341     const std::vector<opt::Instruction*>& instructions,
342     uint32_t vector_width) const {
343   std::vector<opt::Instruction*> result;
344   for (auto& inst : instructions) {
345     auto type = GetIRContext()->get_type_mgr()->GetType(inst->type_id());
346     // Get the vector width of |inst|, which is 1 if |inst| is a scalar and is
347     // otherwise derived from its vector type.
348     uint32_t other_vector_width =
349         type->AsVector() ? type->AsVector()->element_count() : 1;
350     // Keep |inst| if the vector widths match.
351     if (vector_width == other_vector_width) {
352       result.push_back(inst);
353     }
354   }
355   return result;
356 }
357 
358 std::vector<opt::Instruction*>
RestrictToElementBitWidth(const std::vector<opt::Instruction * > & instructions,uint32_t bit_width) const359 FuzzerPassAddEquationInstructions::RestrictToElementBitWidth(
360     const std::vector<opt::Instruction*>& instructions,
361     uint32_t bit_width) const {
362   std::vector<opt::Instruction*> result;
363   for (auto& inst : instructions) {
364     const opt::analysis::Type* type =
365         GetIRContext()->get_type_mgr()->GetType(inst->type_id());
366     if (type->AsVector()) {
367       type = type->AsVector()->element_type();
368     }
369     assert((type->AsInteger() || type->AsFloat()) &&
370            "Precondition: all input instructions must "
371            "have integer or float scalar or vector type.");
372     if ((type->AsInteger() && type->AsInteger()->width() == bit_width) ||
373         (type->AsFloat() && type->AsFloat()->width() == bit_width)) {
374       result.push_back(inst);
375     }
376   }
377   return result;
378 }
379 
380 std::vector<opt::Instruction*>
GetNumericalInstructions(const std::vector<opt::Instruction * > & instructions) const381 FuzzerPassAddEquationInstructions::GetNumericalInstructions(
382     const std::vector<opt::Instruction*>& instructions) const {
383   std::vector<opt::Instruction*> result;
384 
385   for (auto* inst : instructions) {
386     const auto* type = GetIRContext()->get_type_mgr()->GetType(inst->type_id());
387     assert(type && "Instruction has invalid type");
388 
389     if (const auto* vector_type = type->AsVector()) {
390       type = vector_type->element_type();
391     }
392 
393     if (!type->AsInteger() && !type->AsFloat()) {
394       // Only numerical scalars or vectors of numerical components are
395       // supported.
396       continue;
397     }
398 
399     if (!IsBitWidthSupported(GetIRContext(), type->AsInteger()
400                                                  ? type->AsInteger()->width()
401                                                  : type->AsFloat()->width())) {
402       continue;
403     }
404 
405     result.push_back(inst);
406   }
407 
408   return result;
409 }
410 
411 }  // namespace fuzz
412 }  // namespace spvtools
413