1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/fuzz/transformation_replace_boolean_constant_with_constant_binary.h"
16
17 #include <cmath>
18
19 #include "source/fuzz/fuzzer_util.h"
20 #include "source/fuzz/id_use_descriptor.h"
21
22 namespace spvtools {
23 namespace fuzz {
24
25 namespace {
26
27 // Given floating-point values |lhs| and |rhs|, and a floating-point binary
28 // operator |binop|, returns true if it is certain that 'lhs binop rhs'
29 // evaluates to |required_value|.
30 template <typename T>
float_binop_evaluates_to(T lhs,T rhs,SpvOp binop,bool required_value)31 bool float_binop_evaluates_to(T lhs, T rhs, SpvOp binop, bool required_value) {
32 // Infinity and NaN values are conservatively treated as out of scope.
33 if (!std::isfinite(lhs) || !std::isfinite(rhs)) {
34 return false;
35 }
36 bool binop_result;
37 // The following captures the binary operators that spirv-fuzz can actually
38 // generate when turning a boolean constant into a binary expression.
39 switch (binop) {
40 case SpvOpFOrdGreaterThanEqual:
41 case SpvOpFUnordGreaterThanEqual:
42 binop_result = (lhs >= rhs);
43 break;
44 case SpvOpFOrdGreaterThan:
45 case SpvOpFUnordGreaterThan:
46 binop_result = (lhs > rhs);
47 break;
48 case SpvOpFOrdLessThanEqual:
49 case SpvOpFUnordLessThanEqual:
50 binop_result = (lhs <= rhs);
51 break;
52 case SpvOpFOrdLessThan:
53 case SpvOpFUnordLessThan:
54 binop_result = (lhs < rhs);
55 break;
56 default:
57 return false;
58 }
59 return binop_result == required_value;
60 }
61
62 // Analogous to 'float_binop_evaluates_to', but for signed int values.
63 template <typename T>
signed_int_binop_evaluates_to(T lhs,T rhs,SpvOp binop,bool required_value)64 bool signed_int_binop_evaluates_to(T lhs, T rhs, SpvOp binop,
65 bool required_value) {
66 bool binop_result;
67 switch (binop) {
68 case SpvOpSGreaterThanEqual:
69 binop_result = (lhs >= rhs);
70 break;
71 case SpvOpSGreaterThan:
72 binop_result = (lhs > rhs);
73 break;
74 case SpvOpSLessThanEqual:
75 binop_result = (lhs <= rhs);
76 break;
77 case SpvOpSLessThan:
78 binop_result = (lhs < rhs);
79 break;
80 default:
81 return false;
82 }
83 return binop_result == required_value;
84 }
85
86 // Analogous to 'float_binop_evaluates_to', but for unsigned int values.
87 template <typename T>
unsigned_int_binop_evaluates_to(T lhs,T rhs,SpvOp binop,bool required_value)88 bool unsigned_int_binop_evaluates_to(T lhs, T rhs, SpvOp binop,
89 bool required_value) {
90 bool binop_result;
91 switch (binop) {
92 case SpvOpUGreaterThanEqual:
93 binop_result = (lhs >= rhs);
94 break;
95 case SpvOpUGreaterThan:
96 binop_result = (lhs > rhs);
97 break;
98 case SpvOpULessThanEqual:
99 binop_result = (lhs <= rhs);
100 break;
101 case SpvOpULessThan:
102 binop_result = (lhs < rhs);
103 break;
104 default:
105 return false;
106 }
107 return binop_result == required_value;
108 }
109
110 } // namespace
111
112 TransformationReplaceBooleanConstantWithConstantBinary::
TransformationReplaceBooleanConstantWithConstantBinary(protobufs::TransformationReplaceBooleanConstantWithConstantBinary message)113 TransformationReplaceBooleanConstantWithConstantBinary(
114 protobufs::TransformationReplaceBooleanConstantWithConstantBinary
115 message)
116 : message_(std::move(message)) {}
117
118 TransformationReplaceBooleanConstantWithConstantBinary::
TransformationReplaceBooleanConstantWithConstantBinary(const protobufs::IdUseDescriptor & id_use_descriptor,uint32_t lhs_id,uint32_t rhs_id,SpvOp comparison_opcode,uint32_t fresh_id_for_binary_operation)119 TransformationReplaceBooleanConstantWithConstantBinary(
120 const protobufs::IdUseDescriptor& id_use_descriptor, uint32_t lhs_id,
121 uint32_t rhs_id, SpvOp comparison_opcode,
122 uint32_t fresh_id_for_binary_operation) {
123 *message_.mutable_id_use_descriptor() = id_use_descriptor;
124 message_.set_lhs_id(lhs_id);
125 message_.set_rhs_id(rhs_id);
126 message_.set_opcode(comparison_opcode);
127 message_.set_fresh_id_for_binary_operation(fresh_id_for_binary_operation);
128 }
129
IsApplicable(opt::IRContext * ir_context,const TransformationContext &) const130 bool TransformationReplaceBooleanConstantWithConstantBinary::IsApplicable(
131 opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
132 // The id for the binary result must be fresh
133 if (!fuzzerutil::IsFreshId(ir_context,
134 message_.fresh_id_for_binary_operation())) {
135 return false;
136 }
137
138 // The used id must be for a boolean constant
139 auto boolean_constant = ir_context->get_def_use_mgr()->GetDef(
140 message_.id_use_descriptor().id_of_interest());
141 if (!boolean_constant) {
142 return false;
143 }
144 if (!(boolean_constant->opcode() == SpvOpConstantFalse ||
145 boolean_constant->opcode() == SpvOpConstantTrue)) {
146 return false;
147 }
148
149 // The left-hand-side id must correspond to a constant instruction.
150 auto lhs_constant_inst =
151 ir_context->get_def_use_mgr()->GetDef(message_.lhs_id());
152 if (!lhs_constant_inst) {
153 return false;
154 }
155 if (lhs_constant_inst->opcode() != SpvOpConstant) {
156 return false;
157 }
158
159 // The right-hand-side id must correspond to a constant instruction.
160 auto rhs_constant_inst =
161 ir_context->get_def_use_mgr()->GetDef(message_.rhs_id());
162 if (!rhs_constant_inst) {
163 return false;
164 }
165 if (rhs_constant_inst->opcode() != SpvOpConstant) {
166 return false;
167 }
168
169 // The left- and right-hand side instructions must have the same type.
170 if (lhs_constant_inst->type_id() != rhs_constant_inst->type_id()) {
171 return false;
172 }
173
174 // The expression 'LHS opcode RHS' must evaluate to the boolean constant.
175 auto lhs_constant =
176 ir_context->get_constant_mgr()->FindDeclaredConstant(message_.lhs_id());
177 auto rhs_constant =
178 ir_context->get_constant_mgr()->FindDeclaredConstant(message_.rhs_id());
179 bool expected_result = (boolean_constant->opcode() == SpvOpConstantTrue);
180
181 const auto binary_opcode = static_cast<SpvOp>(message_.opcode());
182
183 // We consider the floating point, signed and unsigned integer cases
184 // separately. In each case the logic is very similar.
185 if (lhs_constant->AsFloatConstant()) {
186 assert(rhs_constant->AsFloatConstant() &&
187 "Both constants should be of the same type.");
188 if (lhs_constant->type()->AsFloat()->width() == 32) {
189 if (!float_binop_evaluates_to(lhs_constant->GetFloat(),
190 rhs_constant->GetFloat(), binary_opcode,
191 expected_result)) {
192 return false;
193 }
194 } else {
195 assert(lhs_constant->type()->AsFloat()->width() == 64);
196 if (!float_binop_evaluates_to(lhs_constant->GetDouble(),
197 rhs_constant->GetDouble(), binary_opcode,
198 expected_result)) {
199 return false;
200 }
201 }
202 } else {
203 assert(lhs_constant->AsIntConstant() && "Constants should be in or float.");
204 assert(rhs_constant->AsIntConstant() &&
205 "Both constants should be of the same type.");
206 if (lhs_constant->type()->AsInteger()->IsSigned()) {
207 if (lhs_constant->type()->AsInteger()->width() == 32) {
208 if (!signed_int_binop_evaluates_to(lhs_constant->GetS32(),
209 rhs_constant->GetS32(),
210 binary_opcode, expected_result)) {
211 return false;
212 }
213 } else {
214 assert(lhs_constant->type()->AsInteger()->width() == 64);
215 if (!signed_int_binop_evaluates_to(lhs_constant->GetS64(),
216 rhs_constant->GetS64(),
217 binary_opcode, expected_result)) {
218 return false;
219 }
220 }
221 } else {
222 if (lhs_constant->type()->AsInteger()->width() == 32) {
223 if (!unsigned_int_binop_evaluates_to(lhs_constant->GetU32(),
224 rhs_constant->GetU32(),
225 binary_opcode, expected_result)) {
226 return false;
227 }
228 } else {
229 assert(lhs_constant->type()->AsInteger()->width() == 64);
230 if (!unsigned_int_binop_evaluates_to(lhs_constant->GetU64(),
231 rhs_constant->GetU64(),
232 binary_opcode, expected_result)) {
233 return false;
234 }
235 }
236 }
237 }
238
239 // The id use descriptor must identify some instruction
240 auto instruction =
241 FindInstructionContainingUse(message_.id_use_descriptor(), ir_context);
242 if (instruction == nullptr) {
243 return false;
244 }
245
246 // The instruction must not be an OpVariable, because (a) we cannot insert
247 // a binary operator before an OpVariable, but in any case (b) the
248 // constant we would be replacing is the initializer constant of the
249 // OpVariable, and this cannot be the result of a binary operation.
250 if (instruction->opcode() == SpvOpVariable) {
251 return false;
252 }
253
254 return true;
255 }
256
Apply(opt::IRContext * ir_context,TransformationContext * transformation_context) const257 void TransformationReplaceBooleanConstantWithConstantBinary::Apply(
258 opt::IRContext* ir_context,
259 TransformationContext* transformation_context) const {
260 ApplyWithResult(ir_context, transformation_context);
261 }
262
263 opt::Instruction*
ApplyWithResult(opt::IRContext * ir_context,TransformationContext *) const264 TransformationReplaceBooleanConstantWithConstantBinary::ApplyWithResult(
265 opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
266 opt::analysis::Bool bool_type;
267 opt::Instruction::OperandList operands = {
268 {SPV_OPERAND_TYPE_ID, {message_.lhs_id()}},
269 {SPV_OPERAND_TYPE_ID, {message_.rhs_id()}}};
270 auto binary_instruction = MakeUnique<opt::Instruction>(
271 ir_context, static_cast<SpvOp>(message_.opcode()),
272 ir_context->get_type_mgr()->GetId(&bool_type),
273 message_.fresh_id_for_binary_operation(), operands);
274 opt::Instruction* result = binary_instruction.get();
275 auto instruction_containing_constant_use =
276 FindInstructionContainingUse(message_.id_use_descriptor(), ir_context);
277 auto instruction_before_which_to_insert = instruction_containing_constant_use;
278
279 // If |instruction_before_which_to_insert| is an OpPhi instruction,
280 // then |binary_instruction| will be inserted into the parent block associated
281 // with the OpPhi variable operand.
282 if (instruction_containing_constant_use->opcode() == SpvOpPhi) {
283 instruction_before_which_to_insert =
284 ir_context->cfg()
285 ->block(instruction_containing_constant_use->GetSingleWordInOperand(
286 message_.id_use_descriptor().in_operand_index() + 1))
287 ->terminator();
288 }
289
290 // We want to insert the new instruction before the instruction that contains
291 // the use of the boolean, but we need to go backwards one more instruction if
292 // the using instruction is preceded by a merge instruction.
293 {
294 opt::Instruction* previous_node =
295 instruction_before_which_to_insert->PreviousNode();
296 if (previous_node && (previous_node->opcode() == SpvOpLoopMerge ||
297 previous_node->opcode() == SpvOpSelectionMerge)) {
298 instruction_before_which_to_insert = previous_node;
299 }
300 }
301
302 instruction_before_which_to_insert->InsertBefore(
303 std::move(binary_instruction));
304 instruction_containing_constant_use->SetInOperand(
305 message_.id_use_descriptor().in_operand_index(),
306 {message_.fresh_id_for_binary_operation()});
307 fuzzerutil::UpdateModuleIdBound(ir_context,
308 message_.fresh_id_for_binary_operation());
309 ir_context->InvalidateAnalysesExceptFor(
310 opt::IRContext::Analysis::kAnalysisNone);
311 return result;
312 }
313
314 protobufs::Transformation
ToMessage() const315 TransformationReplaceBooleanConstantWithConstantBinary::ToMessage() const {
316 protobufs::Transformation result;
317 *result.mutable_replace_boolean_constant_with_constant_binary() = message_;
318 return result;
319 }
320
321 std::unordered_set<uint32_t>
GetFreshIds() const322 TransformationReplaceBooleanConstantWithConstantBinary::GetFreshIds() const {
323 return {message_.fresh_id_for_binary_operation()};
324 }
325
326 } // namespace fuzz
327 } // namespace spvtools
328