1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include "source/fuzz/transformation_replace_boolean_constant_with_constant_binary.h"

#include <cstring>
#include <limits>

#include "gtest/gtest.h"
#include "source/fuzz/fuzzer_util.h"
#include "source/fuzz/id_use_descriptor.h"
#include "source/fuzz/instruction_descriptor.h"
#include "test/fuzz/fuzz_test_util.h"
22 
23 namespace spvtools {
24 namespace fuzz {
25 namespace {
26 
TEST(TransformationReplaceBooleanConstantWithConstantBinaryTest,BasicReplacements)27 TEST(TransformationReplaceBooleanConstantWithConstantBinaryTest,
28      BasicReplacements) {
29   // The test came from the following pseudo-GLSL, where int64 and uint64 denote
30   // 64-bit integer types (they were replaced with int and uint during
31   // translation to SPIR-V, and the generated SPIR-V has been doctored to
32   // accommodate them).
33   //
34   // #version 450
35   //
36   // void main() {
37   //   double d1, d2;
38   //   d1 = 1.0;
39   //   d2 = 2.0;
40   //   float f1, f2;
41   //   f1 = 4.0;
42   //   f2 = 8.0;
43   //   int i1, i2;
44   //   i1 = 100;
45   //   i2 = 200;
46   //
47   //   uint u1, u2;
48   //   u1 = 300u;
49   //   u2 = 400u;
50   //
51   //   int64 i64_1, i64_2;
52   //   i64_1 = 500;
53   //   i64_2 = 600;
54   //
55   //   uint64 u64_1, u64_2;
56   //   u64_1 = 700u;
57   //   u64_2 = 800u;
58   //
59   //   bool b, c, d, e;
60   //   b = true;
61   //   c = false;
62   //   d = true || c;
63   //   c = c && false;
64   // }
65   std::string shader = R"(
66                OpCapability Shader
67                OpCapability Float64
68                OpCapability Int64
69           %1 = OpExtInstImport "GLSL.std.450"
70                OpMemoryModel Logical GLSL450
71                OpEntryPoint Fragment %4 "main"
72                OpExecutionMode %4 OriginUpperLeft
73                OpSource GLSL 450
74                OpName %4 "main"
75                OpName %8 "d1"
76                OpName %10 "d2"
77                OpName %14 "f1"
78                OpName %16 "f2"
79                OpName %20 "i1"
80                OpName %22 "i2"
81                OpName %26 "u1"
82                OpName %28 "u2"
83                OpName %30 "i64_1"
84                OpName %32 "i64_2"
85                OpName %34 "u64_1"
86                OpName %36 "u64_2"
87                OpName %40 "b"
88                OpName %42 "c"
89                OpName %44 "d"
90           %2 = OpTypeVoid
91           %3 = OpTypeFunction %2
92           %6 = OpTypeFloat 64
93           %7 = OpTypePointer Function %6
94           %9 = OpConstant %6 1
95          %11 = OpConstant %6 2
96          %12 = OpTypeFloat 32
97          %13 = OpTypePointer Function %12
98          %15 = OpConstant %12 4
99          %17 = OpConstant %12 8
100          %18 = OpTypeInt 32 1
101          %60 = OpTypeInt 64 1
102          %61 = OpTypePointer Function %60
103          %19 = OpTypePointer Function %18
104          %21 = OpConstant %18 -100
105          %23 = OpConstant %18 200
106          %24 = OpTypeInt 32 0
107          %62 = OpTypeInt 64 0
108          %63 = OpTypePointer Function %62
109          %25 = OpTypePointer Function %24
110          %27 = OpConstant %24 300
111          %29 = OpConstant %24 400
112          %31 = OpConstant %60 -600
113          %33 = OpConstant %60 -500
114          %35 = OpConstant %62 700
115          %37 = OpConstant %62 800
116          %38 = OpTypeBool
117          %39 = OpTypePointer Function %38
118          %41 = OpConstantTrue %38
119          %43 = OpConstantFalse %38
120           %4 = OpFunction %2 None %3
121           %5 = OpLabel
122           %8 = OpVariable %7 Function
123          %10 = OpVariable %7 Function
124          %14 = OpVariable %13 Function
125          %16 = OpVariable %13 Function
126          %20 = OpVariable %19 Function
127          %22 = OpVariable %19 Function
128          %26 = OpVariable %25 Function
129          %28 = OpVariable %25 Function
130          %30 = OpVariable %61 Function
131          %32 = OpVariable %61 Function
132          %34 = OpVariable %63 Function
133          %36 = OpVariable %63 Function
134          %40 = OpVariable %39 Function
135          %42 = OpVariable %39 Function
136          %44 = OpVariable %39 Function
137                OpStore %8 %9
138                OpStore %10 %11
139                OpStore %14 %15
140                OpStore %16 %17
141                OpStore %20 %21
142                OpStore %22 %23
143                OpStore %26 %27
144                OpStore %28 %29
145                OpStore %30 %31
146                OpStore %32 %33
147                OpStore %34 %35
148                OpStore %36 %37
149                OpStore %40 %41
150                OpStore %42 %43
151          %45 = OpLoad %38 %42
152          %46 = OpLogicalOr %38 %41 %45
153                OpStore %44 %46
154          %47 = OpLoad %38 %42
155          %48 = OpLogicalAnd %38 %47 %43
156                OpStore %42 %48
157                OpReturn
158                OpFunctionEnd
159   )";
160 
161   const auto env = SPV_ENV_UNIVERSAL_1_3;
162   const auto consumer = nullptr;
163   const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
164   spvtools::ValidatorOptions validator_options;
165   ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
166                                                kConsoleMessageConsumer));
167   TransformationContext transformation_context(
168       MakeUnique<FactManager>(context.get()), validator_options);
169   std::vector<protobufs::IdUseDescriptor> uses_of_true = {
170       MakeIdUseDescriptor(41, MakeInstructionDescriptor(44, SpvOpStore, 12), 1),
171       MakeIdUseDescriptor(41, MakeInstructionDescriptor(46, SpvOpLogicalOr, 0),
172                           0)};
173 
174   std::vector<protobufs::IdUseDescriptor> uses_of_false = {
175       MakeIdUseDescriptor(43, MakeInstructionDescriptor(44, SpvOpStore, 13), 1),
176       MakeIdUseDescriptor(43, MakeInstructionDescriptor(48, SpvOpLogicalAnd, 0),
177                           1)};
178 
179   const uint32_t fresh_id = 100;
180 
181   std::vector<SpvOp> fp_gt_opcodes = {
182       SpvOpFOrdGreaterThan, SpvOpFOrdGreaterThanEqual, SpvOpFUnordGreaterThan,
183       SpvOpFUnordGreaterThanEqual};
184 
185   std::vector<SpvOp> fp_lt_opcodes = {SpvOpFOrdLessThan, SpvOpFOrdLessThanEqual,
186                                       SpvOpFUnordLessThan,
187                                       SpvOpFUnordLessThanEqual};
188 
189   std::vector<SpvOp> int_gt_opcodes = {SpvOpSGreaterThan,
190                                        SpvOpSGreaterThanEqual};
191 
192   std::vector<SpvOp> int_lt_opcodes = {SpvOpSLessThan, SpvOpSLessThanEqual};
193 
194   std::vector<SpvOp> uint_gt_opcodes = {SpvOpUGreaterThan,
195                                         SpvOpUGreaterThanEqual};
196 
197   std::vector<SpvOp> uint_lt_opcodes = {SpvOpULessThan, SpvOpULessThanEqual};
198 
199 #define CHECK_OPERATOR(USE_DESCRIPTOR, LHS_ID, RHS_ID, OPCODE, FRESH_ID) \
200   ASSERT_TRUE(TransformationReplaceBooleanConstantWithConstantBinary(    \
201                   USE_DESCRIPTOR, LHS_ID, RHS_ID, OPCODE, FRESH_ID)      \
202                   .IsApplicable(context.get(), transformation_context)); \
203   ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary(   \
204                    USE_DESCRIPTOR, RHS_ID, LHS_ID, OPCODE, FRESH_ID)     \
205                    .IsApplicable(context.get(), transformation_context));
206 
207 #define CHECK_TRANSFORMATION_APPLICABILITY(GT_OPCODES, LT_OPCODES, SMALL_ID, \
208                                            LARGE_ID)                         \
209   for (auto gt_opcode : GT_OPCODES) {                                        \
210     for (auto& true_use : uses_of_true) {                                    \
211       CHECK_OPERATOR(true_use, LARGE_ID, SMALL_ID, gt_opcode, fresh_id);     \
212     }                                                                        \
213     for (auto& false_use : uses_of_false) {                                  \
214       CHECK_OPERATOR(false_use, SMALL_ID, LARGE_ID, gt_opcode, fresh_id);    \
215     }                                                                        \
216   }                                                                          \
217   for (auto lt_opcode : LT_OPCODES) {                                        \
218     for (auto& true_use : uses_of_true) {                                    \
219       CHECK_OPERATOR(true_use, SMALL_ID, LARGE_ID, lt_opcode, fresh_id);     \
220     }                                                                        \
221     for (auto& false_use : uses_of_false) {                                  \
222       CHECK_OPERATOR(false_use, LARGE_ID, SMALL_ID, lt_opcode, fresh_id);    \
223     }                                                                        \
224   }
225 
226   // Float
227   { CHECK_TRANSFORMATION_APPLICABILITY(fp_gt_opcodes, fp_lt_opcodes, 15, 17); }
228 
229   // Double
230   { CHECK_TRANSFORMATION_APPLICABILITY(fp_gt_opcodes, fp_lt_opcodes, 9, 11); }
231 
232   // Int32
233   {
234     CHECK_TRANSFORMATION_APPLICABILITY(int_gt_opcodes, int_lt_opcodes, 21, 23);
235   }
236 
237   // Int64
238   {
239     CHECK_TRANSFORMATION_APPLICABILITY(int_gt_opcodes, int_lt_opcodes, 31, 33);
240   }
241 
242   // Uint32
243   {
244     CHECK_TRANSFORMATION_APPLICABILITY(uint_gt_opcodes, uint_lt_opcodes, 27,
245                                        29);
246   }
247 
248   // Uint64
249   {
250     CHECK_TRANSFORMATION_APPLICABILITY(uint_gt_opcodes, uint_lt_opcodes, 35,
251                                        37);
252   }
253 
254   // Target id is not fresh
255   ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary(
256                    uses_of_true[0], 15, 17, SpvOpFOrdLessThan, 15)
257                    .IsApplicable(context.get(), transformation_context));
258 
259   // LHS id does not exist
260   ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary(
261                    uses_of_true[0], 300, 17, SpvOpFOrdLessThan, 200)
262                    .IsApplicable(context.get(), transformation_context));
263 
264   // RHS id does not exist
265   ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary(
266                    uses_of_true[0], 15, 300, SpvOpFOrdLessThan, 200)
267                    .IsApplicable(context.get(), transformation_context));
268 
269   // LHS and RHS ids do not match type
270   ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary(
271                    uses_of_true[0], 11, 17, SpvOpFOrdLessThan, 200)
272                    .IsApplicable(context.get(), transformation_context));
273 
274   // Opcode not appropriate
275   ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary(
276                    uses_of_true[0], 15, 17, SpvOpFDiv, 200)
277                    .IsApplicable(context.get(), transformation_context));
278 
279   auto replace_true_with_double_comparison =
280       TransformationReplaceBooleanConstantWithConstantBinary(
281           uses_of_true[0], 11, 9, SpvOpFUnordGreaterThan, 100);
282   auto replace_true_with_uint32_comparison =
283       TransformationReplaceBooleanConstantWithConstantBinary(
284           uses_of_true[1], 27, 29, SpvOpULessThanEqual, 101);
285   auto replace_false_with_float_comparison =
286       TransformationReplaceBooleanConstantWithConstantBinary(
287           uses_of_false[0], 17, 15, SpvOpFOrdLessThan, 102);
288   auto replace_false_with_sint64_comparison =
289       TransformationReplaceBooleanConstantWithConstantBinary(
290           uses_of_false[1], 33, 31, SpvOpSLessThan, 103);
291 
292   ASSERT_TRUE(replace_true_with_double_comparison.IsApplicable(
293       context.get(), transformation_context));
294   ApplyAndCheckFreshIds(replace_true_with_double_comparison, context.get(),
295                         &transformation_context);
296   ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
297                                                kConsoleMessageConsumer));
298   ASSERT_TRUE(replace_true_with_uint32_comparison.IsApplicable(
299       context.get(), transformation_context));
300   ApplyAndCheckFreshIds(replace_true_with_uint32_comparison, context.get(),
301                         &transformation_context);
302   ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
303                                                kConsoleMessageConsumer));
304   ASSERT_TRUE(replace_false_with_float_comparison.IsApplicable(
305       context.get(), transformation_context));
306   ApplyAndCheckFreshIds(replace_false_with_float_comparison, context.get(),
307                         &transformation_context);
308   ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
309                                                kConsoleMessageConsumer));
310   ASSERT_TRUE(replace_false_with_sint64_comparison.IsApplicable(
311       context.get(), transformation_context));
312   ApplyAndCheckFreshIds(replace_false_with_sint64_comparison, context.get(),
313                         &transformation_context);
314   ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
315                                                kConsoleMessageConsumer));
316 
317   std::string after = R"(
318                OpCapability Shader
319                OpCapability Float64
320                OpCapability Int64
321           %1 = OpExtInstImport "GLSL.std.450"
322                OpMemoryModel Logical GLSL450
323                OpEntryPoint Fragment %4 "main"
324                OpExecutionMode %4 OriginUpperLeft
325                OpSource GLSL 450
326                OpName %4 "main"
327                OpName %8 "d1"
328                OpName %10 "d2"
329                OpName %14 "f1"
330                OpName %16 "f2"
331                OpName %20 "i1"
332                OpName %22 "i2"
333                OpName %26 "u1"
334                OpName %28 "u2"
335                OpName %30 "i64_1"
336                OpName %32 "i64_2"
337                OpName %34 "u64_1"
338                OpName %36 "u64_2"
339                OpName %40 "b"
340                OpName %42 "c"
341                OpName %44 "d"
342           %2 = OpTypeVoid
343           %3 = OpTypeFunction %2
344           %6 = OpTypeFloat 64
345           %7 = OpTypePointer Function %6
346           %9 = OpConstant %6 1
347          %11 = OpConstant %6 2
348          %12 = OpTypeFloat 32
349          %13 = OpTypePointer Function %12
350          %15 = OpConstant %12 4
351          %17 = OpConstant %12 8
352          %18 = OpTypeInt 32 1
353          %60 = OpTypeInt 64 1
354          %61 = OpTypePointer Function %60
355          %19 = OpTypePointer Function %18
356          %21 = OpConstant %18 -100
357          %23 = OpConstant %18 200
358          %24 = OpTypeInt 32 0
359          %62 = OpTypeInt 64 0
360          %63 = OpTypePointer Function %62
361          %25 = OpTypePointer Function %24
362          %27 = OpConstant %24 300
363          %29 = OpConstant %24 400
364          %31 = OpConstant %60 -600
365          %33 = OpConstant %60 -500
366          %35 = OpConstant %62 700
367          %37 = OpConstant %62 800
368          %38 = OpTypeBool
369          %39 = OpTypePointer Function %38
370          %41 = OpConstantTrue %38
371          %43 = OpConstantFalse %38
372           %4 = OpFunction %2 None %3
373           %5 = OpLabel
374           %8 = OpVariable %7 Function
375          %10 = OpVariable %7 Function
376          %14 = OpVariable %13 Function
377          %16 = OpVariable %13 Function
378          %20 = OpVariable %19 Function
379          %22 = OpVariable %19 Function
380          %26 = OpVariable %25 Function
381          %28 = OpVariable %25 Function
382          %30 = OpVariable %61 Function
383          %32 = OpVariable %61 Function
384          %34 = OpVariable %63 Function
385          %36 = OpVariable %63 Function
386          %40 = OpVariable %39 Function
387          %42 = OpVariable %39 Function
388          %44 = OpVariable %39 Function
389                OpStore %8 %9
390                OpStore %10 %11
391                OpStore %14 %15
392                OpStore %16 %17
393                OpStore %20 %21
394                OpStore %22 %23
395                OpStore %26 %27
396                OpStore %28 %29
397                OpStore %30 %31
398                OpStore %32 %33
399                OpStore %34 %35
400                OpStore %36 %37
401         %100 = OpFUnordGreaterThan %38 %11 %9
402                OpStore %40 %100
403         %102 = OpFOrdLessThan %38 %17 %15
404                OpStore %42 %102
405          %45 = OpLoad %38 %42
406         %101 = OpULessThanEqual %38 %27 %29
407          %46 = OpLogicalOr %38 %101 %45
408                OpStore %44 %46
409          %47 = OpLoad %38 %42
410         %103 = OpSLessThan %38 %33 %31
411          %48 = OpLogicalAnd %38 %47 %103
412                OpStore %42 %48
413                OpReturn
414                OpFunctionEnd
415   )";
416   ASSERT_TRUE(IsEqual(env, after, context.get()));
417 
418   if (std::numeric_limits<double>::has_quiet_NaN) {
419     double quiet_nan_double = std::numeric_limits<double>::quiet_NaN();
420     uint32_t words[2];
421     memcpy(words, &quiet_nan_double, sizeof(double));
422     opt::Instruction::OperandList operands = {
423         {SPV_OPERAND_TYPE_LITERAL_INTEGER, {words[0]}},
424         {SPV_OPERAND_TYPE_LITERAL_INTEGER, {words[1]}}};
425     context->module()->AddGlobalValue(MakeUnique<opt::Instruction>(
426         context.get(), SpvOpConstant, 6, 200, operands));
427     fuzzerutil::UpdateModuleIdBound(context.get(), 200);
428     ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(
429         context.get(), validator_options, kConsoleMessageConsumer));
430     // The transformation is not applicable because %200 is NaN.
431     ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary(
432                      uses_of_true[0], 11, 200, SpvOpFOrdLessThan, 300)
433                      .IsApplicable(context.get(), transformation_context));
434   }
435   if (std::numeric_limits<double>::has_infinity) {
436     double positive_infinity_double = std::numeric_limits<double>::infinity();
437     uint32_t words[2];
438     memcpy(words, &positive_infinity_double, sizeof(double));
439     opt::Instruction::OperandList operands = {
440         {SPV_OPERAND_TYPE_LITERAL_INTEGER, {words[0]}},
441         {SPV_OPERAND_TYPE_LITERAL_INTEGER, {words[1]}}};
442     context->module()->AddGlobalValue(MakeUnique<opt::Instruction>(
443         context.get(), SpvOpConstant, 6, 201, operands));
444     fuzzerutil::UpdateModuleIdBound(context.get(), 201);
445     ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(
446         context.get(), validator_options, kConsoleMessageConsumer));
447     // Even though the double constant %11 is less than the infinity %201, the
448     // transformation is restricted to only apply to finite values.
449     ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary(
450                      uses_of_true[0], 11, 201, SpvOpFOrdLessThan, 300)
451                      .IsApplicable(context.get(), transformation_context));
452   }
453   if (std::numeric_limits<float>::has_infinity) {
454     float positive_infinity_float = std::numeric_limits<float>::infinity();
455     float negative_infinity_float = -1 * positive_infinity_float;
456     uint32_t words_positive_infinity[1];
457     uint32_t words_negative_infinity[1];
458     memcpy(words_positive_infinity, &positive_infinity_float, sizeof(float));
459     memcpy(words_negative_infinity, &negative_infinity_float, sizeof(float));
460     opt::Instruction::OperandList operands_positive_infinity = {
461         {SPV_OPERAND_TYPE_LITERAL_INTEGER, {words_positive_infinity[0]}}};
462     context->module()->AddGlobalValue(MakeUnique<opt::Instruction>(
463         context.get(), SpvOpConstant, 12, 202, operands_positive_infinity));
464     fuzzerutil::UpdateModuleIdBound(context.get(), 202);
465     opt::Instruction::OperandList operands = {
466         {SPV_OPERAND_TYPE_LITERAL_INTEGER, {words_negative_infinity[0]}}};
467     context->module()->AddGlobalValue(MakeUnique<opt::Instruction>(
468         context.get(), SpvOpConstant, 12, 203, operands));
469     fuzzerutil::UpdateModuleIdBound(context.get(), 203);
470     ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(
471         context.get(), validator_options, kConsoleMessageConsumer));
472     // Even though the negative infinity at %203 is less than the positive
473     // infinity %202, the transformation is restricted to only apply to finite
474     // values.
475     ASSERT_FALSE(TransformationReplaceBooleanConstantWithConstantBinary(
476                      uses_of_true[0], 203, 202, SpvOpFOrdLessThan, 300)
477                      .IsApplicable(context.get(), transformation_context));
478   }
479 }
480 
TEST(TransformationReplaceBooleanConstantWithConstantBinaryTest,
     MergeInstructions) {
  // Replaces boolean constants that serve as the condition of an
  // OpBranchConditional preceded by a merge instruction (OpSelectionMerge for
  // an if, OpLoopMerge for a while). The expected output places each new
  // comparison before the merge instruction.
  //
  // The test came from the following GLSL:
  //
  // void main() {
  //   int x = 1;
  //   int y = 2;
  //   if (true) {
  //     x = 2;
  //   }
  //   while(false) {
  //     y = 2;
  //   }
  // }

  std::string shader = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
               OpSource GLSL 450
               OpName %4 "main"
               OpName %8 "x"
               OpName %10 "y"
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeInt 32 1
          %7 = OpTypePointer Function %6
          %9 = OpConstant %6 1
         %11 = OpConstant %6 2
         %12 = OpTypeBool
         %13 = OpConstantTrue %12
         %21 = OpConstantFalse %12
          %4 = OpFunction %2 None %3
          %5 = OpLabel
          %8 = OpVariable %7 Function
         %10 = OpVariable %7 Function
               OpStore %8 %9
               OpStore %10 %11
               OpSelectionMerge %15 None
               OpBranchConditional %13 %14 %15
         %14 = OpLabel
               OpStore %8 %11
               OpBranch %15
         %15 = OpLabel
               OpBranch %16
         %16 = OpLabel
               OpLoopMerge %18 %19 None
               OpBranchConditional %21 %17 %18
         %17 = OpLabel
               OpStore %10 %11
               OpBranch %19
         %19 = OpLabel
               OpBranch %16
         %18 = OpLabel
               OpReturn
               OpFunctionEnd
  )";

  const auto env = SPV_ENV_UNIVERSAL_1_3;
  const auto consumer = nullptr;
  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
  spvtools::ValidatorOptions validator_options;
  ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
                                               kConsoleMessageConsumer));
  TransformationContext transformation_context(
      MakeUnique<FactManager>(context.get()), validator_options);

  // %13 (true): the condition of the first OpBranchConditional after %10, i.e.
  // the branch guarded by OpSelectionMerge.
  const auto true_condition_of_selection = MakeIdUseDescriptor(
      13, MakeInstructionDescriptor(10, SpvOpBranchConditional, 0), 0);
  // %21 (false): the condition of the loop header %16's OpBranchConditional.
  const auto false_condition_of_loop = MakeIdUseDescriptor(
      21, MakeInstructionDescriptor(16, SpvOpBranchConditional, 0), 0);

  // %9 (1) < %11 (2) is true.
  auto replace_selection_condition =
      TransformationReplaceBooleanConstantWithConstantBinary(
          true_condition_of_selection, 9, 11, SpvOpSLessThan, 100);
  // %9 (1) >= %11 (2) is false.
  auto replace_loop_condition =
      TransformationReplaceBooleanConstantWithConstantBinary(
          false_condition_of_loop, 9, 11, SpvOpSGreaterThanEqual, 101);

  // Apply both, in order, checking applicability beforehand and module
  // validity afterwards.
  for (auto* replacement :
       {&replace_selection_condition, &replace_loop_condition}) {
    ASSERT_TRUE(
        replacement->IsApplicable(context.get(), transformation_context));
    ApplyAndCheckFreshIds(*replacement, context.get(), &transformation_context);
    ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(
        context.get(), validator_options, kConsoleMessageConsumer));
  }

  std::string after = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
               OpSource GLSL 450
               OpName %4 "main"
               OpName %8 "x"
               OpName %10 "y"
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeInt 32 1
          %7 = OpTypePointer Function %6
          %9 = OpConstant %6 1
         %11 = OpConstant %6 2
         %12 = OpTypeBool
         %13 = OpConstantTrue %12
         %21 = OpConstantFalse %12
          %4 = OpFunction %2 None %3
          %5 = OpLabel
          %8 = OpVariable %7 Function
         %10 = OpVariable %7 Function
               OpStore %8 %9
               OpStore %10 %11
        %100 = OpSLessThan %12 %9 %11
               OpSelectionMerge %15 None
               OpBranchConditional %100 %14 %15
         %14 = OpLabel
               OpStore %8 %11
               OpBranch %15
         %15 = OpLabel
               OpBranch %16
         %16 = OpLabel
        %101 = OpSGreaterThanEqual %12 %9 %11
               OpLoopMerge %18 %19 None
               OpBranchConditional %101 %17 %18
         %17 = OpLabel
               OpStore %10 %11
               OpBranch %19
         %19 = OpLabel
               OpBranch %16
         %18 = OpLabel
               OpReturn
               OpFunctionEnd
  )";

  ASSERT_TRUE(IsEqual(env, after, context.get()));
}
620 
TEST(TransformationReplaceBooleanConstantWithConstantBinaryTest, OpPhi) {
  // Hand-written SPIR-V checking that a boolean constant used as an OpPhi
  // argument can be replaced. In the expected output, the new comparison sits
  // in the predecessor block associated with that argument (%11), ahead of
  // the block's OpSelectionMerge.

  std::string reference_shader = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Vertex %10 "main"

; Types
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %4 = OpTypeInt 32 0
          %5 = OpTypeBool

; Constants
          %6 = OpConstant %4 0
          %7 = OpConstant %4 1
          %8 = OpConstantTrue %5
          %9 = OpConstantFalse %5

; main function
         %10 = OpFunction %2 None %3
         %11 = OpLabel
               OpSelectionMerge %13 None
               OpBranchConditional %8 %12 %13
         %12 = OpLabel
               OpBranch %13
         %13 = OpLabel
         %14 = OpPhi %5 %8 %11 %9 %12
               OpReturn
               OpFunctionEnd
  )";

  const auto env = SPV_ENV_UNIVERSAL_1_3;
  const auto consumer = nullptr;
  const auto context =
      BuildModule(env, consumer, reference_shader, kFuzzAssembleOption);
  spvtools::ValidatorOptions validator_options;
  ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
                                               kConsoleMessageConsumer));
  TransformationContext transformation_context(
      MakeUnique<FactManager>(context.get()), validator_options);

  // %8 (true) is in-operand 0 of %14 = OpPhi, the value incoming from %11.
  // Replace it with %6 (0) < %7 (1), which is true, defined by fresh id %15.
  auto replace_phi_argument =
      TransformationReplaceBooleanConstantWithConstantBinary(
          MakeIdUseDescriptor(8, MakeInstructionDescriptor(14, SpvOpPhi, 0),
                              0),
          6, 7, SpvOpULessThan, 15);
  ASSERT_TRUE(
      replace_phi_argument.IsApplicable(context.get(), transformation_context));
  ApplyAndCheckFreshIds(replace_phi_argument, context.get(),
                        &transformation_context);
  ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
                                               kConsoleMessageConsumer));

  std::string variant_shader = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Vertex %10 "main"

; Types
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %4 = OpTypeInt 32 0
          %5 = OpTypeBool

; Constants
          %6 = OpConstant %4 0
          %7 = OpConstant %4 1
          %8 = OpConstantTrue %5
          %9 = OpConstantFalse %5

; main function
         %10 = OpFunction %2 None %3
         %11 = OpLabel
         %15 = OpULessThan %5 %6 %7
               OpSelectionMerge %13 None
               OpBranchConditional %8 %12 %13
         %12 = OpLabel
               OpBranch %13
         %13 = OpLabel
         %14 = OpPhi %5 %15 %11 %9 %12
               OpReturn
               OpFunctionEnd
  )";

  ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
}
709 
TEST(TransformationReplaceBooleanConstantWithConstantBinaryTest,
     DoNotReplaceVariableInitializer) {
  // %9 (true) is the initializer of %50 = OpVariable. %13 (0) < %15 (1) does
  // evaluate to true, but the replacement must still be rejected: SPIR-V
  // requires an OpVariable initializer to be a constant or a module-scope
  // variable, not the result of a computation.
  std::string shader = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
               OpName %4 "main"
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeBool
          %7 = OpTypePointer Function %6
          %9 = OpConstantTrue %6
         %10 = OpTypeInt 32 1
         %13 = OpConstant %10 0
         %15 = OpConstant %10 1
          %4 = OpFunction %2 None %3
          %5 = OpLabel
         %50 = OpVariable %7 Function %9
               OpReturn
               OpFunctionEnd
  )";

  const auto env = SPV_ENV_UNIVERSAL_1_3;
  const auto consumer = nullptr;
  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
  spvtools::ValidatorOptions validator_options;
  ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
                                               kConsoleMessageConsumer));
  TransformationContext transformation_context(
      MakeUnique<FactManager>(context.get()), validator_options);

  // In-operand 1 of OpVariable is the initializer (in-operand 0 is the storage
  // class).
  const auto use_as_initializer = MakeIdUseDescriptor(
      9, MakeInstructionDescriptor(50, SpvOpVariable, 0), 1);
  auto replace_initializer =
      TransformationReplaceBooleanConstantWithConstantBinary(
          use_as_initializer, 13, 15, SpvOpSLessThan, 100);
  ASSERT_FALSE(
      replace_initializer.IsApplicable(context.get(), transformation_context));
}
749 
750 }  // namespace
751 }  // namespace fuzz
752 }  // namespace spvtools
753