1 // Copyright (c) 2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/fuzz/fuzzer_pass_add_equation_instructions.h"
16 
17 #include <vector>
18 
19 #include "source/fuzz/fuzzer_util.h"
20 #include "source/fuzz/transformation_equation_instruction.h"
21 
22 namespace spvtools {
23 namespace fuzz {
24 namespace {
25 
IsBitWidthSupported(opt::IRContext * ir_context,uint32_t bit_width)26 bool IsBitWidthSupported(opt::IRContext* ir_context, uint32_t bit_width) {
27   switch (bit_width) {
28     case 32:
29       return true;
30     case 64:
31       return ir_context->get_feature_mgr()->HasCapability(
32                  SpvCapabilityFloat64) &&
33              ir_context->get_feature_mgr()->HasCapability(SpvCapabilityInt64);
34     case 16:
35       return ir_context->get_feature_mgr()->HasCapability(
36                  SpvCapabilityFloat16) &&
37              ir_context->get_feature_mgr()->HasCapability(SpvCapabilityInt16);
38     default:
39       return false;
40   }
41 }
42 
43 }  // namespace
44 
FuzzerPassAddEquationInstructions(opt::IRContext * ir_context,TransformationContext * transformation_context,FuzzerContext * fuzzer_context,protobufs::TransformationSequence * transformations)45 FuzzerPassAddEquationInstructions::FuzzerPassAddEquationInstructions(
46     opt::IRContext* ir_context, TransformationContext* transformation_context,
47     FuzzerContext* fuzzer_context,
48     protobufs::TransformationSequence* transformations)
49     : FuzzerPass(ir_context, transformation_context, fuzzer_context,
50                  transformations) {}
51 
52 FuzzerPassAddEquationInstructions::~FuzzerPassAddEquationInstructions() =
53     default;
54 
// Visits every instruction in the module. At a randomly chosen subset of them
// (governed by GetChanceOfAddingEquationInstruction()), it tries to insert a
// new "equation" instruction before that instruction. It does this by
// applying a TransformationEquationInstruction that defines a fresh id.
//
// At each chosen point, the candidate opcodes (OpIAdd, OpISub, OpLogicalNot,
// OpSNegate, OpConvertUToF, OpConvertSToF and OpBitcast) are tried in random
// order, each at most once:
// - An opcode is skipped when no suitable operand is available there.
// - The first opcode that finds operands is applied, and the point is then
//   left alone.
// - If no opcode succeeds, nothing is added at that point.
void FuzzerPassAddEquationInstructions::Apply() {
  ForEachInstructionWithInstructionDescriptor(
      [this](opt::Function* function, opt::BasicBlock* block,
             opt::BasicBlock::iterator inst_it,
             const protobufs::InstructionDescriptor& instruction_descriptor) {
        // Only consider this insertion point with the configured probability.
        if (!GetFuzzerContext()->ChoosePercentage(
                GetFuzzerContext()->GetChanceOfAddingEquationInstruction())) {
          return;
        }

        // Check that it is OK to add an equation instruction before the given
        // instruction in principle - e.g. check that this does not lead to
        // inserting before an OpVariable or OpPhi instruction.  We use OpIAdd
        // as an example opcode for this check, to be representative of *some*
        // opcode that defines an equation, even though we may choose a
        // different opcode below.
        if (!fuzzerutil::CanInsertOpcodeBeforeInstruction(SpvOpIAdd, inst_it)) {
          return;
        }

        // Get all available instructions that have a result id and a type, are
        // not OpUndef, and whose result id is not known to be irrelevant.
        std::vector<opt::Instruction*> available_instructions =
            FindAvailableInstructions(
                function, block, inst_it,
                [this](opt::IRContext* /*unused*/,
                       opt::Instruction* instruction) -> bool {
                  return instruction->result_id() && instruction->type_id() &&
                         instruction->opcode() != SpvOpUndef &&
                         !GetTransformationContext()
                              ->GetFactManager()
                              ->IdIsIrrelevant(instruction->result_id());
                });

        // Try the opcodes for which we know how to make ids at random until
        // something works.
        std::vector<SpvOp> candidate_opcodes = {
            SpvOpIAdd,        SpvOpISub,        SpvOpLogicalNot, SpvOpSNegate,
            SpvOpConvertUToF, SpvOpConvertSToF, SpvOpBitcast};
        do {
          // Remove the chosen opcode so that each opcode is tried at most once
          // at this insertion point. This also guarantees the loop terminates.
          auto opcode =
              GetFuzzerContext()->RemoveAtRandomIndex(&candidate_opcodes);
          switch (opcode) {
            case SpvOpConvertSToF:
            case SpvOpConvertUToF: {
              // Operands must be integer scalars or vectors. Their component
              // bit width must be supported (see IsBitWidthSupported), because
              // a float type of that same width may have to be created below.
              std::vector<const opt::Instruction*> candidate_instructions;
              for (const auto* inst :
                   GetIntegerInstructions(available_instructions)) {
                const auto* type =
                    GetIRContext()->get_type_mgr()->GetType(inst->type_id());
                assert(type && "|inst| has invalid type");

                if (const auto* vector_type = type->AsVector()) {
                  type = vector_type->element_type();
                }

                if (IsBitWidthSupported(GetIRContext(),
                                        type->AsInteger()->width())) {
                  candidate_instructions.push_back(inst);
                }
              }

              if (candidate_instructions.empty()) {
                // No suitable operand; fall through to trying another opcode.
                break;
              }

              const auto* operand =
                  candidate_instructions[GetFuzzerContext()->RandomIndex(
                      candidate_instructions)];

              const auto* type =
                  GetIRContext()->get_type_mgr()->GetType(operand->type_id());
              assert(type && "Operand has invalid type");

              // Make sure a result type exists in the module: a float scalar
              // (or a float vector with the same component count) whose bit
              // width matches the integer operand's.
              if (const auto* vector = type->AsVector()) {
                // We store the element count in a separate variable since the
                // calls to the FindOrCreate* functions below might invalidate
                // the |vector| pointer.
                const auto element_count = vector->element_count();

                FindOrCreateVectorType(
                    FindOrCreateFloatType(
                        vector->element_type()->AsInteger()->width()),
                    element_count);
              } else {
                FindOrCreateFloatType(type->AsInteger()->width());
              }

              ApplyTransformation(TransformationEquationInstruction(
                  GetFuzzerContext()->GetFreshId(), opcode,
                  {operand->result_id()}, instruction_descriptor));
              return;
            }
            case SpvOpBitcast: {
              // Operands are integer or float scalars or vectors whose
              // component bit width is supported (see
              // GetNumericalInstructions).
              const auto candidate_instructions =
                  GetNumericalInstructions(available_instructions);

              if (!candidate_instructions.empty()) {
                const auto* operand_inst =
                    candidate_instructions[GetFuzzerContext()->RandomIndex(
                        candidate_instructions)];
                const auto* operand_type =
                    GetIRContext()->get_type_mgr()->GetType(
                        operand_inst->type_id());
                assert(operand_type && "Operand instruction has invalid type");

                // Make sure a result type exists in the module.
                // - For an integer operand, this is a float type of the same
                //   width.
                // - For a float operand, this is an integer type of the same
                //   width. Its second FindOrCreateIntegerType argument
                //   (presumably signedness) is chosen at random.
                // Vectors keep their component count.
                //
                // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3539):
                //  The only constraint on the types of OpBitcast's parameters
                //  is that they must have the same number of bits. Consider
                //  improving the code below to support this in full.
                if (const auto* vector = operand_type->AsVector()) {
                  // We store the element count in a separate variable since the
                  // calls to the FindOrCreate* functions below might invalidate
                  // the |vector| pointer.
                  const auto element_count = vector->element_count();

                  uint32_t element_type_id;
                  if (const auto* int_type =
                          vector->element_type()->AsInteger()) {
                    element_type_id = FindOrCreateFloatType(int_type->width());
                  } else {
                    assert(vector->element_type()->AsFloat() &&
                           "Vector must have numerical elements");
                    element_type_id = FindOrCreateIntegerType(
                        vector->element_type()->AsFloat()->width(),
                        GetFuzzerContext()->ChooseEven());
                  }

                  FindOrCreateVectorType(element_type_id, element_count);
                } else if (const auto* int_type = operand_type->AsInteger()) {
                  FindOrCreateFloatType(int_type->width());
                } else {
                  assert(operand_type->AsFloat() &&
                         "Operand is not a scalar of numerical type");
                  FindOrCreateIntegerType(operand_type->AsFloat()->width(),
                                          GetFuzzerContext()->ChooseEven());
                }

                ApplyTransformation(TransformationEquationInstruction(
                    GetFuzzerContext()->GetFreshId(), opcode,
                    {operand_inst->result_id()}, instruction_descriptor));
                return;
              }
            } break;
            case SpvOpIAdd:
            case SpvOpISub: {
              // Instructions of integer (scalar or vector) result type are
              // suitable for these opcodes.
              auto integer_instructions =
                  GetIntegerInstructions(available_instructions);
              if (!integer_instructions.empty()) {
                // There is at least one such instruction, so pick one at random
                // for the LHS of an equation.
                auto lhs = integer_instructions.at(
                    GetFuzzerContext()->RandomIndex(integer_instructions));

                // For the RHS, we can use any instruction with an integer
                // scalar/vector result type of the same number of components
                // and the same bit-width for the underlying integer type.

                // Work out the element count and bit-width.
                auto lhs_type =
                    GetIRContext()->get_type_mgr()->GetType(lhs->type_id());
                uint32_t lhs_element_count;
                uint32_t lhs_bit_width;
                if (lhs_type->AsVector()) {
                  lhs_element_count = lhs_type->AsVector()->element_count();
                  lhs_bit_width = lhs_type->AsVector()
                                      ->element_type()
                                      ->AsInteger()
                                      ->width();
                } else {
                  lhs_element_count = 1;
                  lhs_bit_width = lhs_type->AsInteger()->width();
                }

                // Get all the instructions that match on element count and
                // bit-width.
                auto candidate_rhs_instructions = RestrictToElementBitWidth(
                    RestrictToVectorWidth(integer_instructions,
                                          lhs_element_count),
                    lhs_bit_width);

                // Choose a RHS instruction at random; there is guaranteed to
                // be at least one choice as the LHS will be available.
                auto rhs = candidate_rhs_instructions.at(
                    GetFuzzerContext()->RandomIndex(
                        candidate_rhs_instructions));

                // Add the equation instruction.
                ApplyTransformation(TransformationEquationInstruction(
                    GetFuzzerContext()->GetFreshId(), opcode,
                    {lhs->result_id(), rhs->result_id()},
                    instruction_descriptor));
                return;
              }
              break;
            }
            case SpvOpLogicalNot: {
              // Choose any available instruction of boolean scalar/vector
              // result type and equate its negation with a fresh id.
              auto boolean_instructions =
                  GetBooleanInstructions(available_instructions);
              if (!boolean_instructions.empty()) {
                ApplyTransformation(TransformationEquationInstruction(
                    GetFuzzerContext()->GetFreshId(), opcode,
                    {boolean_instructions
                         .at(GetFuzzerContext()->RandomIndex(
                             boolean_instructions))
                         ->result_id()},
                    instruction_descriptor));
                return;
              }
              break;
            }
            case SpvOpSNegate: {
              // Similar to OpLogicalNot, but for signed integer negation. Any
              // available integer scalar/vector instruction is a candidate.
              auto integer_instructions =
                  GetIntegerInstructions(available_instructions);
              if (!integer_instructions.empty()) {
                ApplyTransformation(TransformationEquationInstruction(
                    GetFuzzerContext()->GetFreshId(), opcode,
                    {integer_instructions
                         .at(GetFuzzerContext()->RandomIndex(
                             integer_instructions))
                         ->result_id()},
                    instruction_descriptor));
                return;
              }
              break;
            }
            default:
              assert(false && "Unexpected opcode.");
              break;
          }
        } while (!candidate_opcodes.empty());
        // Reaching here means that we did not manage to apply any
        // transformation at this point of the module.
      });
}
298 
299 std::vector<opt::Instruction*>
GetIntegerInstructions(const std::vector<opt::Instruction * > & instructions) const300 FuzzerPassAddEquationInstructions::GetIntegerInstructions(
301     const std::vector<opt::Instruction*>& instructions) const {
302   std::vector<opt::Instruction*> result;
303   for (auto& inst : instructions) {
304     auto type = GetIRContext()->get_type_mgr()->GetType(inst->type_id());
305     if (type->AsInteger() ||
306         (type->AsVector() && type->AsVector()->element_type()->AsInteger())) {
307       result.push_back(inst);
308     }
309   }
310   return result;
311 }
312 
313 std::vector<opt::Instruction*>
GetFloatInstructions(const std::vector<opt::Instruction * > & instructions) const314 FuzzerPassAddEquationInstructions::GetFloatInstructions(
315     const std::vector<opt::Instruction*>& instructions) const {
316   std::vector<opt::Instruction*> result;
317   for (auto& inst : instructions) {
318     auto type = GetIRContext()->get_type_mgr()->GetType(inst->type_id());
319     if (type->AsFloat() ||
320         (type->AsVector() && type->AsVector()->element_type()->AsFloat())) {
321       result.push_back(inst);
322     }
323   }
324   return result;
325 }
326 
327 std::vector<opt::Instruction*>
GetBooleanInstructions(const std::vector<opt::Instruction * > & instructions) const328 FuzzerPassAddEquationInstructions::GetBooleanInstructions(
329     const std::vector<opt::Instruction*>& instructions) const {
330   std::vector<opt::Instruction*> result;
331   for (auto& inst : instructions) {
332     auto type = GetIRContext()->get_type_mgr()->GetType(inst->type_id());
333     if (type->AsBool() ||
334         (type->AsVector() && type->AsVector()->element_type()->AsBool())) {
335       result.push_back(inst);
336     }
337   }
338   return result;
339 }
340 
341 std::vector<opt::Instruction*>
RestrictToVectorWidth(const std::vector<opt::Instruction * > & instructions,uint32_t vector_width) const342 FuzzerPassAddEquationInstructions::RestrictToVectorWidth(
343     const std::vector<opt::Instruction*>& instructions,
344     uint32_t vector_width) const {
345   std::vector<opt::Instruction*> result;
346   for (auto& inst : instructions) {
347     auto type = GetIRContext()->get_type_mgr()->GetType(inst->type_id());
348     // Get the vector width of |inst|, which is 1 if |inst| is a scalar and is
349     // otherwise derived from its vector type.
350     uint32_t other_vector_width =
351         type->AsVector() ? type->AsVector()->element_count() : 1;
352     // Keep |inst| if the vector widths match.
353     if (vector_width == other_vector_width) {
354       result.push_back(inst);
355     }
356   }
357   return result;
358 }
359 
360 std::vector<opt::Instruction*>
RestrictToElementBitWidth(const std::vector<opt::Instruction * > & instructions,uint32_t bit_width) const361 FuzzerPassAddEquationInstructions::RestrictToElementBitWidth(
362     const std::vector<opt::Instruction*>& instructions,
363     uint32_t bit_width) const {
364   std::vector<opt::Instruction*> result;
365   for (auto& inst : instructions) {
366     const opt::analysis::Type* type =
367         GetIRContext()->get_type_mgr()->GetType(inst->type_id());
368     if (type->AsVector()) {
369       type = type->AsVector()->element_type();
370     }
371     assert((type->AsInteger() || type->AsFloat()) &&
372            "Precondition: all input instructions must "
373            "have integer or float scalar or vector type.");
374     if ((type->AsInteger() && type->AsInteger()->width() == bit_width) ||
375         (type->AsFloat() && type->AsFloat()->width() == bit_width)) {
376       result.push_back(inst);
377     }
378   }
379   return result;
380 }
381 
382 std::vector<opt::Instruction*>
GetNumericalInstructions(const std::vector<opt::Instruction * > & instructions) const383 FuzzerPassAddEquationInstructions::GetNumericalInstructions(
384     const std::vector<opt::Instruction*>& instructions) const {
385   std::vector<opt::Instruction*> result;
386 
387   for (auto* inst : instructions) {
388     const auto* type = GetIRContext()->get_type_mgr()->GetType(inst->type_id());
389     assert(type && "Instruction has invalid type");
390 
391     if (const auto* vector_type = type->AsVector()) {
392       type = vector_type->element_type();
393     }
394 
395     if (!type->AsInteger() && !type->AsFloat()) {
396       // Only numerical scalars or vectors of numerical components are
397       // supported.
398       continue;
399     }
400 
401     if (!IsBitWidthSupported(GetIRContext(), type->AsInteger()
402                                                  ? type->AsInteger()->width()
403                                                  : type->AsFloat()->width())) {
404       continue;
405     }
406 
407     result.push_back(inst);
408   }
409 
410   return result;
411 }
412 
413 }  // namespace fuzz
414 }  // namespace spvtools
415