1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/fold.h"
16 
17 #include <cassert>
18 #include <cstdint>
19 #include <vector>
20 
21 #include "source/opt/const_folding_rules.h"
22 #include "source/opt/def_use_manager.h"
23 #include "source/opt/folding_rules.h"
24 #include "source/opt/ir_builder.h"
25 #include "source/opt/ir_context.h"
26 
27 namespace spvtools {
28 namespace opt {
namespace {

// Fallback definitions of the 32-bit integer limits, used only when the
// standard headers have not already provided them as macros.

#ifndef INT32_MIN
// Spelled as (-2147483647 - 1) on purpose: the literal 2147483648 does not fit
// in a 32-bit int, so (-2147483648) would be a negated wider (or unsigned)
// value instead of an int.
#define INT32_MIN (-2147483647 - 1)
#endif

#ifndef INT32_MAX
#define INT32_MAX 2147483647
#endif

#ifndef UINT32_MAX
#define UINT32_MAX 0xffffffff /* 4294967295U */
#endif

}  // namespace
44 
UnaryOperate(SpvOp opcode,uint32_t operand) const45 uint32_t InstructionFolder::UnaryOperate(SpvOp opcode, uint32_t operand) const {
46   switch (opcode) {
47     // Arthimetics
48     case SpvOp::SpvOpSNegate: {
49       int32_t s_operand = static_cast<int32_t>(operand);
50       if (s_operand == std::numeric_limits<int32_t>::min()) {
51         return s_operand;
52       }
53       return -s_operand;
54     }
55     case SpvOp::SpvOpNot:
56       return ~operand;
57     case SpvOp::SpvOpLogicalNot:
58       return !static_cast<bool>(operand);
59     case SpvOp::SpvOpUConvert:
60       return operand;
61     case SpvOp::SpvOpSConvert:
62       return operand;
63     default:
64       assert(false &&
65              "Unsupported unary operation for OpSpecConstantOp instruction");
66       return 0u;
67   }
68 }
69 
BinaryOperate(SpvOp opcode,uint32_t a,uint32_t b) const70 uint32_t InstructionFolder::BinaryOperate(SpvOp opcode, uint32_t a,
71                                           uint32_t b) const {
72   switch (opcode) {
73     // Arthimetics
74     case SpvOp::SpvOpIAdd:
75       return a + b;
76     case SpvOp::SpvOpISub:
77       return a - b;
78     case SpvOp::SpvOpIMul:
79       return a * b;
80     case SpvOp::SpvOpUDiv:
81       if (b != 0) {
82         return a / b;
83       } else {
84         // Dividing by 0 is undefined, so we will just pick 0.
85         return 0;
86       }
87     case SpvOp::SpvOpSDiv:
88       if (b != 0u) {
89         return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b));
90       } else {
91         // Dividing by 0 is undefined, so we will just pick 0.
92         return 0;
93       }
94     case SpvOp::SpvOpSRem: {
95       // The sign of non-zero result comes from the first operand: a. This is
96       // guaranteed by C++11 rules for integer division operator. The division
97       // result is rounded toward zero, so the result of '%' has the sign of
98       // the first operand.
99       if (b != 0u) {
100         return static_cast<int32_t>(a) % static_cast<int32_t>(b);
101       } else {
102         // Remainder when dividing with 0 is undefined, so we will just pick 0.
103         return 0;
104       }
105     }
106     case SpvOp::SpvOpSMod: {
107       // The sign of non-zero result comes from the second operand: b
108       if (b != 0u) {
109         int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b);
110         int32_t b_prim = static_cast<int32_t>(b);
111         return (rem + b_prim) % b_prim;
112       } else {
113         // Mod with 0 is undefined, so we will just pick 0.
114         return 0;
115       }
116     }
117     case SpvOp::SpvOpUMod:
118       if (b != 0u) {
119         return (a % b);
120       } else {
121         // Mod with 0 is undefined, so we will just pick 0.
122         return 0;
123       }
124 
125     // Shifting
126     case SpvOp::SpvOpShiftRightLogical:
127       if (b >= 32) {
128         // This is undefined behaviour when |b| > 32.  Choose 0 for consistency.
129         // When |b| == 32, doing the shift in C++ in undefined, but the result
130         // will be 0, so just return that value.
131         return 0;
132       }
133       return a >> b;
134     case SpvOp::SpvOpShiftRightArithmetic:
135       if (b > 32) {
136         // This is undefined behaviour.  Choose 0 for consistency.
137         return 0;
138       }
139       if (b == 32) {
140         // Doing the shift in C++ is undefined, but the result is defined in the
141         // spir-v spec.  Find that value another way.
142         if (static_cast<int32_t>(a) >= 0) {
143           return 0;
144         } else {
145           return static_cast<uint32_t>(-1);
146         }
147       }
148       return (static_cast<int32_t>(a)) >> b;
149     case SpvOp::SpvOpShiftLeftLogical:
150       if (b >= 32) {
151         // This is undefined behaviour when |b| > 32.  Choose 0 for consistency.
152         // When |b| == 32, doing the shift in C++ in undefined, but the result
153         // will be 0, so just return that value.
154         return 0;
155       }
156       return a << b;
157 
158     // Bitwise operations
159     case SpvOp::SpvOpBitwiseOr:
160       return a | b;
161     case SpvOp::SpvOpBitwiseAnd:
162       return a & b;
163     case SpvOp::SpvOpBitwiseXor:
164       return a ^ b;
165 
166     // Logical
167     case SpvOp::SpvOpLogicalEqual:
168       return (static_cast<bool>(a)) == (static_cast<bool>(b));
169     case SpvOp::SpvOpLogicalNotEqual:
170       return (static_cast<bool>(a)) != (static_cast<bool>(b));
171     case SpvOp::SpvOpLogicalOr:
172       return (static_cast<bool>(a)) || (static_cast<bool>(b));
173     case SpvOp::SpvOpLogicalAnd:
174       return (static_cast<bool>(a)) && (static_cast<bool>(b));
175 
176     // Comparison
177     case SpvOp::SpvOpIEqual:
178       return a == b;
179     case SpvOp::SpvOpINotEqual:
180       return a != b;
181     case SpvOp::SpvOpULessThan:
182       return a < b;
183     case SpvOp::SpvOpSLessThan:
184       return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
185     case SpvOp::SpvOpUGreaterThan:
186       return a > b;
187     case SpvOp::SpvOpSGreaterThan:
188       return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
189     case SpvOp::SpvOpULessThanEqual:
190       return a <= b;
191     case SpvOp::SpvOpSLessThanEqual:
192       return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
193     case SpvOp::SpvOpUGreaterThanEqual:
194       return a >= b;
195     case SpvOp::SpvOpSGreaterThanEqual:
196       return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
197     default:
198       assert(false &&
199              "Unsupported binary operation for OpSpecConstantOp instruction");
200       return 0u;
201   }
202 }
203 
TernaryOperate(SpvOp opcode,uint32_t a,uint32_t b,uint32_t c) const204 uint32_t InstructionFolder::TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b,
205                                            uint32_t c) const {
206   switch (opcode) {
207     case SpvOp::SpvOpSelect:
208       return (static_cast<bool>(a)) ? b : c;
209     default:
210       assert(false &&
211              "Unsupported ternary operation for OpSpecConstantOp instruction");
212       return 0u;
213   }
214 }
215 
OperateWords(SpvOp opcode,const std::vector<uint32_t> & operand_words) const216 uint32_t InstructionFolder::OperateWords(
217     SpvOp opcode, const std::vector<uint32_t>& operand_words) const {
218   switch (operand_words.size()) {
219     case 1:
220       return UnaryOperate(opcode, operand_words.front());
221     case 2:
222       return BinaryOperate(opcode, operand_words.front(), operand_words.back());
223     case 3:
224       return TernaryOperate(opcode, operand_words[0], operand_words[1],
225                             operand_words[2]);
226     default:
227       assert(false && "Invalid number of operands");
228       return 0;
229   }
230 }
231 
FoldInstructionInternal(Instruction * inst) const232 bool InstructionFolder::FoldInstructionInternal(Instruction* inst) const {
233   auto identity_map = [](uint32_t id) { return id; };
234   Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map);
235   if (folded_inst != nullptr) {
236     inst->SetOpcode(SpvOpCopyObject);
237     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}});
238     return true;
239   }
240 
241   analysis::ConstantManager* const_manager = context_->get_constant_mgr();
242   std::vector<const analysis::Constant*> constants =
243       const_manager->GetOperandConstants(inst);
244 
245   for (const FoldingRule& rule :
246        GetFoldingRules().GetRulesForInstruction(inst)) {
247     if (rule(context_, inst, constants)) {
248       return true;
249     }
250   }
251   return false;
252 }
253 
254 // Returns the result of performing an operation on scalar constant operands.
255 // This function extracts the operand values as 32 bit words and returns the
256 // result in 32 bit word. Scalar constants with longer than 32-bit width are
257 // not accepted in this function.
FoldScalars(SpvOp opcode,const std::vector<const analysis::Constant * > & operands) const258 uint32_t InstructionFolder::FoldScalars(
259     SpvOp opcode,
260     const std::vector<const analysis::Constant*>& operands) const {
261   assert(IsFoldableOpcode(opcode) &&
262          "Unhandled instruction opcode in FoldScalars");
263   std::vector<uint32_t> operand_values_in_raw_words;
264   for (const auto& operand : operands) {
265     if (const analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
266       const auto& scalar_words = scalar->words();
267       assert(scalar_words.size() == 1 &&
268              "Scalar constants with longer than 32-bit width are not allowed "
269              "in FoldScalars()");
270       operand_values_in_raw_words.push_back(scalar_words.front());
271     } else if (operand->AsNullConstant()) {
272       operand_values_in_raw_words.push_back(0u);
273     } else {
274       assert(false &&
275              "FoldScalars() only accepts ScalarConst or NullConst type of "
276              "constant");
277     }
278   }
279   return OperateWords(opcode, operand_values_in_raw_words);
280 }
281 
FoldBinaryIntegerOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const282 bool InstructionFolder::FoldBinaryIntegerOpToConstant(
283     Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
284     uint32_t* result) const {
285   SpvOp opcode = inst->opcode();
286   analysis::ConstantManager* const_manger = context_->get_constant_mgr();
287 
288   uint32_t ids[2];
289   const analysis::IntConstant* constants[2];
290   for (uint32_t i = 0; i < 2; i++) {
291     const Operand* operand = &inst->GetInOperand(i);
292     if (operand->type != SPV_OPERAND_TYPE_ID) {
293       return false;
294     }
295     ids[i] = id_map(operand->words[0]);
296     const analysis::Constant* constant =
297         const_manger->FindDeclaredConstant(ids[i]);
298     constants[i] = (constant != nullptr ? constant->AsIntConstant() : nullptr);
299   }
300 
301   switch (opcode) {
302     // Arthimetics
303     case SpvOp::SpvOpIMul:
304       for (uint32_t i = 0; i < 2; i++) {
305         if (constants[i] != nullptr && constants[i]->IsZero()) {
306           *result = 0;
307           return true;
308         }
309       }
310       break;
311     case SpvOp::SpvOpUDiv:
312     case SpvOp::SpvOpSDiv:
313     case SpvOp::SpvOpSRem:
314     case SpvOp::SpvOpSMod:
315     case SpvOp::SpvOpUMod:
316       // This changes undefined behaviour (ie divide by 0) into a 0.
317       for (uint32_t i = 0; i < 2; i++) {
318         if (constants[i] != nullptr && constants[i]->IsZero()) {
319           *result = 0;
320           return true;
321         }
322       }
323       break;
324 
325     // Shifting
326     case SpvOp::SpvOpShiftRightLogical:
327     case SpvOp::SpvOpShiftLeftLogical:
328       if (constants[1] != nullptr) {
329         // When shifting by a value larger than the size of the result, the
330         // result is undefined.  We are setting the undefined behaviour to a
331         // result of 0.  If the shift amount is the same as the size of the
332         // result, then the result is defined, and it 0.
333         uint32_t shift_amount = constants[1]->GetU32BitValue();
334         if (shift_amount >= 32) {
335           *result = 0;
336           return true;
337         }
338       }
339       break;
340 
341     // Bitwise operations
342     case SpvOp::SpvOpBitwiseOr:
343       for (uint32_t i = 0; i < 2; i++) {
344         if (constants[i] != nullptr) {
345           // TODO: Change the mask against a value based on the bit width of the
346           // instruction result type.  This way we can handle say 16-bit values
347           // as well.
348           uint32_t mask = constants[i]->GetU32BitValue();
349           if (mask == 0xFFFFFFFF) {
350             *result = 0xFFFFFFFF;
351             return true;
352           }
353         }
354       }
355       break;
356     case SpvOp::SpvOpBitwiseAnd:
357       for (uint32_t i = 0; i < 2; i++) {
358         if (constants[i] != nullptr) {
359           if (constants[i]->IsZero()) {
360             *result = 0;
361             return true;
362           }
363         }
364       }
365       break;
366 
367     // Comparison
368     case SpvOp::SpvOpULessThan:
369       if (constants[0] != nullptr &&
370           constants[0]->GetU32BitValue() == UINT32_MAX) {
371         *result = false;
372         return true;
373       }
374       if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
375         *result = false;
376         return true;
377       }
378       break;
379     case SpvOp::SpvOpSLessThan:
380       if (constants[0] != nullptr &&
381           constants[0]->GetS32BitValue() == INT32_MAX) {
382         *result = false;
383         return true;
384       }
385       if (constants[1] != nullptr &&
386           constants[1]->GetS32BitValue() == INT32_MIN) {
387         *result = false;
388         return true;
389       }
390       break;
391     case SpvOp::SpvOpUGreaterThan:
392       if (constants[0] != nullptr && constants[0]->IsZero()) {
393         *result = false;
394         return true;
395       }
396       if (constants[1] != nullptr &&
397           constants[1]->GetU32BitValue() == UINT32_MAX) {
398         *result = false;
399         return true;
400       }
401       break;
402     case SpvOp::SpvOpSGreaterThan:
403       if (constants[0] != nullptr &&
404           constants[0]->GetS32BitValue() == INT32_MIN) {
405         *result = false;
406         return true;
407       }
408       if (constants[1] != nullptr &&
409           constants[1]->GetS32BitValue() == INT32_MAX) {
410         *result = false;
411         return true;
412       }
413       break;
414     case SpvOp::SpvOpULessThanEqual:
415       if (constants[0] != nullptr && constants[0]->IsZero()) {
416         *result = true;
417         return true;
418       }
419       if (constants[1] != nullptr &&
420           constants[1]->GetU32BitValue() == UINT32_MAX) {
421         *result = true;
422         return true;
423       }
424       break;
425     case SpvOp::SpvOpSLessThanEqual:
426       if (constants[0] != nullptr &&
427           constants[0]->GetS32BitValue() == INT32_MIN) {
428         *result = true;
429         return true;
430       }
431       if (constants[1] != nullptr &&
432           constants[1]->GetS32BitValue() == INT32_MAX) {
433         *result = true;
434         return true;
435       }
436       break;
437     case SpvOp::SpvOpUGreaterThanEqual:
438       if (constants[0] != nullptr &&
439           constants[0]->GetU32BitValue() == UINT32_MAX) {
440         *result = true;
441         return true;
442       }
443       if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
444         *result = true;
445         return true;
446       }
447       break;
448     case SpvOp::SpvOpSGreaterThanEqual:
449       if (constants[0] != nullptr &&
450           constants[0]->GetS32BitValue() == INT32_MAX) {
451         *result = true;
452         return true;
453       }
454       if (constants[1] != nullptr &&
455           constants[1]->GetS32BitValue() == INT32_MIN) {
456         *result = true;
457         return true;
458       }
459       break;
460     default:
461       break;
462   }
463   return false;
464 }
465 
FoldBinaryBooleanOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const466 bool InstructionFolder::FoldBinaryBooleanOpToConstant(
467     Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
468     uint32_t* result) const {
469   SpvOp opcode = inst->opcode();
470   analysis::ConstantManager* const_manger = context_->get_constant_mgr();
471 
472   uint32_t ids[2];
473   const analysis::BoolConstant* constants[2];
474   for (uint32_t i = 0; i < 2; i++) {
475     const Operand* operand = &inst->GetInOperand(i);
476     if (operand->type != SPV_OPERAND_TYPE_ID) {
477       return false;
478     }
479     ids[i] = id_map(operand->words[0]);
480     const analysis::Constant* constant =
481         const_manger->FindDeclaredConstant(ids[i]);
482     constants[i] = (constant != nullptr ? constant->AsBoolConstant() : nullptr);
483   }
484 
485   switch (opcode) {
486     // Logical
487     case SpvOp::SpvOpLogicalOr:
488       for (uint32_t i = 0; i < 2; i++) {
489         if (constants[i] != nullptr) {
490           if (constants[i]->value()) {
491             *result = true;
492             return true;
493           }
494         }
495       }
496       break;
497     case SpvOp::SpvOpLogicalAnd:
498       for (uint32_t i = 0; i < 2; i++) {
499         if (constants[i] != nullptr) {
500           if (!constants[i]->value()) {
501             *result = false;
502             return true;
503           }
504         }
505       }
506       break;
507 
508     default:
509       break;
510   }
511   return false;
512 }
513 
FoldIntegerOpToConstant(Instruction * inst,const std::function<uint32_t (uint32_t)> & id_map,uint32_t * result) const514 bool InstructionFolder::FoldIntegerOpToConstant(
515     Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
516     uint32_t* result) const {
517   assert(IsFoldableOpcode(inst->opcode()) &&
518          "Unhandled instruction opcode in FoldScalars");
519   switch (inst->NumInOperands()) {
520     case 2:
521       return FoldBinaryIntegerOpToConstant(inst, id_map, result) ||
522              FoldBinaryBooleanOpToConstant(inst, id_map, result);
523     default:
524       return false;
525   }
526 }
527 
FoldVectors(SpvOp opcode,uint32_t num_dims,const std::vector<const analysis::Constant * > & operands) const528 std::vector<uint32_t> InstructionFolder::FoldVectors(
529     SpvOp opcode, uint32_t num_dims,
530     const std::vector<const analysis::Constant*>& operands) const {
531   assert(IsFoldableOpcode(opcode) &&
532          "Unhandled instruction opcode in FoldVectors");
533   std::vector<uint32_t> result;
534   for (uint32_t d = 0; d < num_dims; d++) {
535     std::vector<uint32_t> operand_values_for_one_dimension;
536     for (const auto& operand : operands) {
537       if (const analysis::VectorConstant* vector_operand =
538               operand->AsVectorConstant()) {
539         // Extract the raw value of the scalar component constants
540         // in 32-bit words here. The reason of not using FoldScalars() here
541         // is that we do not create temporary null constants as components
542         // when the vector operand is a NullConstant because Constant creation
543         // may need extra checks for the validity and that is not manageed in
544         // here.
545         if (const analysis::ScalarConstant* scalar_component =
546                 vector_operand->GetComponents().at(d)->AsScalarConstant()) {
547           const auto& scalar_words = scalar_component->words();
548           assert(
549               scalar_words.size() == 1 &&
550               "Vector components with longer than 32-bit width are not allowed "
551               "in FoldVectors()");
552           operand_values_for_one_dimension.push_back(scalar_words.front());
553         } else if (operand->AsNullConstant()) {
554           operand_values_for_one_dimension.push_back(0u);
555         } else {
556           assert(false &&
557                  "VectorConst should only has ScalarConst or NullConst as "
558                  "components");
559         }
560       } else if (operand->AsNullConstant()) {
561         operand_values_for_one_dimension.push_back(0u);
562       } else {
563         assert(false &&
564                "FoldVectors() only accepts VectorConst or NullConst type of "
565                "constant");
566       }
567     }
568     result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
569   }
570   return result;
571 }
572 
IsFoldableOpcode(SpvOp opcode) const573 bool InstructionFolder::IsFoldableOpcode(SpvOp opcode) const {
574   // NOTE: Extend to more opcodes as new cases are handled in the folder
575   // functions.
576   switch (opcode) {
577     case SpvOp::SpvOpBitwiseAnd:
578     case SpvOp::SpvOpBitwiseOr:
579     case SpvOp::SpvOpBitwiseXor:
580     case SpvOp::SpvOpIAdd:
581     case SpvOp::SpvOpIEqual:
582     case SpvOp::SpvOpIMul:
583     case SpvOp::SpvOpINotEqual:
584     case SpvOp::SpvOpISub:
585     case SpvOp::SpvOpLogicalAnd:
586     case SpvOp::SpvOpLogicalEqual:
587     case SpvOp::SpvOpLogicalNot:
588     case SpvOp::SpvOpLogicalNotEqual:
589     case SpvOp::SpvOpLogicalOr:
590     case SpvOp::SpvOpNot:
591     case SpvOp::SpvOpSDiv:
592     case SpvOp::SpvOpSelect:
593     case SpvOp::SpvOpSGreaterThan:
594     case SpvOp::SpvOpSGreaterThanEqual:
595     case SpvOp::SpvOpShiftLeftLogical:
596     case SpvOp::SpvOpShiftRightArithmetic:
597     case SpvOp::SpvOpShiftRightLogical:
598     case SpvOp::SpvOpSLessThan:
599     case SpvOp::SpvOpSLessThanEqual:
600     case SpvOp::SpvOpSMod:
601     case SpvOp::SpvOpSNegate:
602     case SpvOp::SpvOpSRem:
603     case SpvOp::SpvOpSConvert:
604     case SpvOp::SpvOpUConvert:
605     case SpvOp::SpvOpUDiv:
606     case SpvOp::SpvOpUGreaterThan:
607     case SpvOp::SpvOpUGreaterThanEqual:
608     case SpvOp::SpvOpULessThan:
609     case SpvOp::SpvOpULessThanEqual:
610     case SpvOp::SpvOpUMod:
611       return true;
612     default:
613       return false;
614   }
615 }
616 
IsFoldableConstant(const analysis::Constant * cst) const617 bool InstructionFolder::IsFoldableConstant(
618     const analysis::Constant* cst) const {
619   // Currently supported constants are 32-bit values or null constants.
620   if (const analysis::ScalarConstant* scalar = cst->AsScalarConstant())
621     return scalar->words().size() == 1;
622   else
623     return cst->AsNullConstant() != nullptr;
624 }
625 
FoldInstructionToConstant(Instruction * inst,std::function<uint32_t (uint32_t)> id_map) const626 Instruction* InstructionFolder::FoldInstructionToConstant(
627     Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const {
628   analysis::ConstantManager* const_mgr = context_->get_constant_mgr();
629 
630   if (!inst->IsFoldableByFoldScalar() &&
631       !GetConstantFoldingRules().HasFoldingRule(inst)) {
632     return nullptr;
633   }
634   // Collect the values of the constant parameters.
635   std::vector<const analysis::Constant*> constants;
636   bool missing_constants = false;
637   inst->ForEachInId([&constants, &missing_constants, const_mgr,
638                      &id_map](uint32_t* op_id) {
639     uint32_t id = id_map(*op_id);
640     const analysis::Constant* const_op = const_mgr->FindDeclaredConstant(id);
641     if (!const_op) {
642       constants.push_back(nullptr);
643       missing_constants = true;
644     } else {
645       constants.push_back(const_op);
646     }
647   });
648 
649   const analysis::Constant* folded_const = nullptr;
650   for (auto rule : GetConstantFoldingRules().GetRulesForInstruction(inst)) {
651     folded_const = rule(context_, inst, constants);
652     if (folded_const != nullptr) {
653       Instruction* const_inst =
654           const_mgr->GetDefiningInstruction(folded_const, inst->type_id());
655       if (const_inst == nullptr) {
656         return nullptr;
657       }
658       assert(const_inst->type_id() == inst->type_id());
659       // May be a new instruction that needs to be analysed.
660       context_->UpdateDefUse(const_inst);
661       return const_inst;
662     }
663   }
664 
665   uint32_t result_val = 0;
666   bool successful = false;
667   // If all parameters are constant, fold the instruction to a constant.
668   if (!missing_constants && inst->IsFoldableByFoldScalar()) {
669     result_val = FoldScalars(inst->opcode(), constants);
670     successful = true;
671   }
672 
673   if (!successful && inst->IsFoldableByFoldScalar()) {
674     successful = FoldIntegerOpToConstant(inst, id_map, &result_val);
675   }
676 
677   if (successful) {
678     const analysis::Constant* result_const =
679         const_mgr->GetConstant(const_mgr->GetType(inst), {result_val});
680     Instruction* folded_inst =
681         const_mgr->GetDefiningInstruction(result_const, inst->type_id());
682     return folded_inst;
683   }
684   return nullptr;
685 }
686 
IsFoldableType(Instruction * type_inst) const687 bool InstructionFolder::IsFoldableType(Instruction* type_inst) const {
688   // Support 32-bit integers.
689   if (type_inst->opcode() == SpvOpTypeInt) {
690     return type_inst->GetSingleWordInOperand(0) == 32;
691   }
692   // Support booleans.
693   if (type_inst->opcode() == SpvOpTypeBool) {
694     return true;
695   }
696   // Nothing else yet.
697   return false;
698 }
699 
FoldInstruction(Instruction * inst) const700 bool InstructionFolder::FoldInstruction(Instruction* inst) const {
701   bool modified = false;
702   Instruction* folded_inst(inst);
703   while (folded_inst->opcode() != SpvOpCopyObject &&
704          FoldInstructionInternal(&*folded_inst)) {
705     modified = true;
706   }
707   return modified;
708 }
709 
710 }  // namespace opt
711 }  // namespace spvtools
712