1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/const_folding_rules.h"
16 
17 #include "source/opt/ir_context.h"
18 
19 namespace spvtools {
20 namespace opt {
21 namespace {
22 
// Index into the folding |constants| of an OpCompositeExtract at which the
// composite operand being extracted from is found.
constexpr uint32_t kExtractCompositeIdInIdx = 0;
24 
25 // Returns true if |type| is Float or a vector of Float.
HasFloatingPoint(const analysis::Type * type)26 bool HasFloatingPoint(const analysis::Type* type) {
27   if (type->AsFloat()) {
28     return true;
29   } else if (const analysis::Vector* vec_type = type->AsVector()) {
30     return vec_type->element_type()->AsFloat() != nullptr;
31   }
32 
33   return false;
34 }
35 
36 // Folds an OpcompositeExtract where input is a composite constant.
FoldExtractWithConstants()37 ConstantFoldingRule FoldExtractWithConstants() {
38   return [](IRContext* context, Instruction* inst,
39             const std::vector<const analysis::Constant*>& constants)
40              -> const analysis::Constant* {
41     const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
42     if (c == nullptr) {
43       return nullptr;
44     }
45 
46     for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
47       uint32_t element_index = inst->GetSingleWordInOperand(i);
48       if (c->AsNullConstant()) {
49         // Return Null for the return type.
50         analysis::ConstantManager* const_mgr = context->get_constant_mgr();
51         analysis::TypeManager* type_mgr = context->get_type_mgr();
52         return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
53       }
54 
55       auto cc = c->AsCompositeConstant();
56       assert(cc != nullptr);
57       auto components = cc->GetComponents();
58       // Protect against invalid IR.  Refuse to fold if the index is out
59       // of bounds.
60       if (element_index >= components.size()) return nullptr;
61       c = components[element_index];
62     }
63     return c;
64   };
65 }
66 
FoldVectorShuffleWithConstants()67 ConstantFoldingRule FoldVectorShuffleWithConstants() {
68   return [](IRContext* context, Instruction* inst,
69             const std::vector<const analysis::Constant*>& constants)
70              -> const analysis::Constant* {
71     assert(inst->opcode() == SpvOpVectorShuffle);
72     const analysis::Constant* c1 = constants[0];
73     const analysis::Constant* c2 = constants[1];
74     if (c1 == nullptr || c2 == nullptr) {
75       return nullptr;
76     }
77 
78     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
79     const analysis::Type* element_type = c1->type()->AsVector()->element_type();
80 
81     std::vector<const analysis::Constant*> c1_components;
82     if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
83       c1_components = vec_const->GetComponents();
84     } else {
85       assert(c1->AsNullConstant());
86       const analysis::Constant* element =
87           const_mgr->GetConstant(element_type, {});
88       c1_components.resize(c1->type()->AsVector()->element_count(), element);
89     }
90     std::vector<const analysis::Constant*> c2_components;
91     if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
92       c2_components = vec_const->GetComponents();
93     } else {
94       assert(c2->AsNullConstant());
95       const analysis::Constant* element =
96           const_mgr->GetConstant(element_type, {});
97       c2_components.resize(c2->type()->AsVector()->element_count(), element);
98     }
99 
100     std::vector<uint32_t> ids;
101     const uint32_t undef_literal_value = 0xffffffff;
102     for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
103       uint32_t index = inst->GetSingleWordInOperand(i);
104       if (index == undef_literal_value) {
105         // Don't fold shuffle with undef literal value.
106         return nullptr;
107       } else if (index < c1_components.size()) {
108         Instruction* member_inst =
109             const_mgr->GetDefiningInstruction(c1_components[index]);
110         ids.push_back(member_inst->result_id());
111       } else {
112         Instruction* member_inst = const_mgr->GetDefiningInstruction(
113             c2_components[index - c1_components.size()]);
114         ids.push_back(member_inst->result_id());
115       }
116     }
117 
118     analysis::TypeManager* type_mgr = context->get_type_mgr();
119     return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
120   };
121 }
122 
FoldVectorTimesScalar()123 ConstantFoldingRule FoldVectorTimesScalar() {
124   return [](IRContext* context, Instruction* inst,
125             const std::vector<const analysis::Constant*>& constants)
126              -> const analysis::Constant* {
127     assert(inst->opcode() == SpvOpVectorTimesScalar);
128     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
129     analysis::TypeManager* type_mgr = context->get_type_mgr();
130 
131     if (!inst->IsFloatingPointFoldingAllowed()) {
132       if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
133         return nullptr;
134       }
135     }
136 
137     const analysis::Constant* c1 = constants[0];
138     const analysis::Constant* c2 = constants[1];
139 
140     if (c1 && c1->IsZero()) {
141       return c1;
142     }
143 
144     if (c2 && c2->IsZero()) {
145       // Get or create the NullConstant for this type.
146       std::vector<uint32_t> ids;
147       return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
148     }
149 
150     if (c1 == nullptr || c2 == nullptr) {
151       return nullptr;
152     }
153 
154     // Check result type.
155     const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
156     const analysis::Vector* vector_type = result_type->AsVector();
157     assert(vector_type != nullptr);
158     const analysis::Type* element_type = vector_type->element_type();
159     assert(element_type != nullptr);
160     const analysis::Float* float_type = element_type->AsFloat();
161     assert(float_type != nullptr);
162 
163     // Check types of c1 and c2.
164     assert(c1->type()->AsVector() == vector_type);
165     assert(c1->type()->AsVector()->element_type() == element_type &&
166            c2->type() == element_type);
167 
168     // Get a float vector that is the result of vector-times-scalar.
169     std::vector<const analysis::Constant*> c1_components =
170         c1->GetVectorComponents(const_mgr);
171     std::vector<uint32_t> ids;
172     if (float_type->width() == 32) {
173       float scalar = c2->GetFloat();
174       for (uint32_t i = 0; i < c1_components.size(); ++i) {
175         utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
176         std::vector<uint32_t> words = result.GetWords();
177         const analysis::Constant* new_elem =
178             const_mgr->GetConstant(float_type, words);
179         ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
180       }
181       return const_mgr->GetConstant(vector_type, ids);
182     } else if (float_type->width() == 64) {
183       double scalar = c2->GetDouble();
184       for (uint32_t i = 0; i < c1_components.size(); ++i) {
185         utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
186                                          scalar);
187         std::vector<uint32_t> words = result.GetWords();
188         const analysis::Constant* new_elem =
189             const_mgr->GetConstant(float_type, words);
190         ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
191       }
192       return const_mgr->GetConstant(vector_type, ids);
193     }
194     return nullptr;
195   };
196 }
197 
FoldCompositeWithConstants()198 ConstantFoldingRule FoldCompositeWithConstants() {
199   // Folds an OpCompositeConstruct where all of the inputs are constants to a
200   // constant.  A new constant is created if necessary.
201   return [](IRContext* context, Instruction* inst,
202             const std::vector<const analysis::Constant*>& constants)
203              -> const analysis::Constant* {
204     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
205     analysis::TypeManager* type_mgr = context->get_type_mgr();
206     const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
207     Instruction* type_inst =
208         context->get_def_use_mgr()->GetDef(inst->type_id());
209 
210     std::vector<uint32_t> ids;
211     for (uint32_t i = 0; i < constants.size(); ++i) {
212       const analysis::Constant* element_const = constants[i];
213       if (element_const == nullptr) {
214         return nullptr;
215       }
216 
217       uint32_t component_type_id = 0;
218       if (type_inst->opcode() == SpvOpTypeStruct) {
219         component_type_id = type_inst->GetSingleWordInOperand(i);
220       } else if (type_inst->opcode() == SpvOpTypeArray) {
221         component_type_id = type_inst->GetSingleWordInOperand(0);
222       }
223 
224       uint32_t element_id =
225           const_mgr->FindDeclaredConstant(element_const, component_type_id);
226       if (element_id == 0) {
227         return nullptr;
228       }
229       ids.push_back(element_id);
230     }
231     return const_mgr->GetConstant(new_type, ids);
232   };
233 }
234 
235 // The interface for a function that returns the result of applying a scalar
236 // floating-point binary operation on |a| and |b|.  The type of the return value
237 // will be |type|.  The input constants must also be of type |type|.
238 using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
239     const analysis::Type* result_type, const analysis::Constant* a,
240     analysis::ConstantManager*)>;
241 
242 // The interface for a function that returns the result of applying a scalar
243 // floating-point binary operation on |a| and |b|.  The type of the return value
244 // will be |type|.  The input constants must also be of type |type|.
245 using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
246     const analysis::Type* result_type, const analysis::Constant* a,
247     const analysis::Constant* b, analysis::ConstantManager*)>;
248 
249 // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
250 // using |scalar_rule| and unary float point vectors ops by applying
251 // |scalar_rule| to the elements of the vector.  The |ConstantFoldingRule|
252 // that is returned assumes that |constants| contains 1 entry.  If they are
253 // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
254 // whose element type is |Float| or |Integer|.
FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule)255 ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
256   return [scalar_rule](IRContext* context, Instruction* inst,
257                        const std::vector<const analysis::Constant*>& constants)
258              -> const analysis::Constant* {
259     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
260     analysis::TypeManager* type_mgr = context->get_type_mgr();
261     const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
262     const analysis::Vector* vector_type = result_type->AsVector();
263 
264     if (!inst->IsFloatingPointFoldingAllowed()) {
265       return nullptr;
266     }
267 
268     const analysis::Constant* arg =
269         (inst->opcode() == SpvOpExtInst) ? constants[1] : constants[0];
270 
271     if (arg == nullptr) {
272       return nullptr;
273     }
274 
275     if (vector_type != nullptr) {
276       std::vector<const analysis::Constant*> a_components;
277       std::vector<const analysis::Constant*> results_components;
278 
279       a_components = arg->GetVectorComponents(const_mgr);
280 
281       // Fold each component of the vector.
282       for (uint32_t i = 0; i < a_components.size(); ++i) {
283         results_components.push_back(scalar_rule(vector_type->element_type(),
284                                                  a_components[i], const_mgr));
285         if (results_components[i] == nullptr) {
286           return nullptr;
287         }
288       }
289 
290       // Build the constant object and return it.
291       std::vector<uint32_t> ids;
292       for (const analysis::Constant* member : results_components) {
293         ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
294       }
295       return const_mgr->GetConstant(vector_type, ids);
296     } else {
297       return scalar_rule(result_type, arg, const_mgr);
298     }
299   };
300 }
301 
302 // Returns the result of folding the constants in |constants| according the
303 // |scalar_rule|.  If |result_type| is a vector, then |scalar_rule| is applied
304 // per component.
FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule,uint32_t result_type_id,const std::vector<const analysis::Constant * > & constants,IRContext * context)305 const analysis::Constant* FoldFPBinaryOp(
306     BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id,
307     const std::vector<const analysis::Constant*>& constants,
308     IRContext* context) {
309   analysis::ConstantManager* const_mgr = context->get_constant_mgr();
310   analysis::TypeManager* type_mgr = context->get_type_mgr();
311   const analysis::Type* result_type = type_mgr->GetType(result_type_id);
312   const analysis::Vector* vector_type = result_type->AsVector();
313 
314   if (constants[0] == nullptr || constants[1] == nullptr) {
315     return nullptr;
316   }
317 
318   if (vector_type != nullptr) {
319     std::vector<const analysis::Constant*> a_components;
320     std::vector<const analysis::Constant*> b_components;
321     std::vector<const analysis::Constant*> results_components;
322 
323     a_components = constants[0]->GetVectorComponents(const_mgr);
324     b_components = constants[1]->GetVectorComponents(const_mgr);
325 
326     // Fold each component of the vector.
327     for (uint32_t i = 0; i < a_components.size(); ++i) {
328       results_components.push_back(scalar_rule(vector_type->element_type(),
329                                                a_components[i], b_components[i],
330                                                const_mgr));
331       if (results_components[i] == nullptr) {
332         return nullptr;
333       }
334     }
335 
336     // Build the constant object and return it.
337     std::vector<uint32_t> ids;
338     for (const analysis::Constant* member : results_components) {
339       ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
340     }
341     return const_mgr->GetConstant(vector_type, ids);
342   } else {
343     return scalar_rule(result_type, constants[0], constants[1], const_mgr);
344   }
345 }
346 
347 // Returns a |ConstantFoldingRule| that folds floating point scalars using
348 // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
349 // elements of the vector.  The |ConstantFoldingRule| that is returned assumes
350 // that |constants| contains 2 entries.  If they are not |nullptr|, then their
351 // type is either |Float| or a |Vector| whose element type is |Float|.
FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule)352 ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
353   return [scalar_rule](IRContext* context, Instruction* inst,
354                        const std::vector<const analysis::Constant*>& constants)
355              -> const analysis::Constant* {
356     if (!inst->IsFloatingPointFoldingAllowed()) {
357       return nullptr;
358     }
359     if (inst->opcode() == SpvOpExtInst) {
360       return FoldFPBinaryOp(scalar_rule, inst->type_id(),
361                             {constants[1], constants[2]}, context);
362     }
363     return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context);
364   };
365 }
366 
367 // This macro defines a |UnaryScalarFoldingRule| that performs float to
368 // integer conversion.
369 // TODO(greg-lunarg): Support for 64-bit integer types.
FoldFToIOp()370 UnaryScalarFoldingRule FoldFToIOp() {
371   return [](const analysis::Type* result_type, const analysis::Constant* a,
372             analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
373     assert(result_type != nullptr && a != nullptr);
374     const analysis::Integer* integer_type = result_type->AsInteger();
375     const analysis::Float* float_type = a->type()->AsFloat();
376     assert(float_type != nullptr);
377     assert(integer_type != nullptr);
378     if (integer_type->width() != 32) return nullptr;
379     if (float_type->width() == 32) {
380       float fa = a->GetFloat();
381       uint32_t result = integer_type->IsSigned()
382                             ? static_cast<uint32_t>(static_cast<int32_t>(fa))
383                             : static_cast<uint32_t>(fa);
384       std::vector<uint32_t> words = {result};
385       return const_mgr->GetConstant(result_type, words);
386     } else if (float_type->width() == 64) {
387       double fa = a->GetDouble();
388       uint32_t result = integer_type->IsSigned()
389                             ? static_cast<uint32_t>(static_cast<int32_t>(fa))
390                             : static_cast<uint32_t>(fa);
391       std::vector<uint32_t> words = {result};
392       return const_mgr->GetConstant(result_type, words);
393     }
394     return nullptr;
395   };
396 }
397 
398 // This function defines a |UnaryScalarFoldingRule| that performs integer to
399 // float conversion.
400 // TODO(greg-lunarg): Support for 64-bit integer types.
FoldIToFOp()401 UnaryScalarFoldingRule FoldIToFOp() {
402   return [](const analysis::Type* result_type, const analysis::Constant* a,
403             analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
404     assert(result_type != nullptr && a != nullptr);
405     const analysis::Integer* integer_type = a->type()->AsInteger();
406     const analysis::Float* float_type = result_type->AsFloat();
407     assert(float_type != nullptr);
408     assert(integer_type != nullptr);
409     if (integer_type->width() != 32) return nullptr;
410     uint32_t ua = a->GetU32();
411     if (float_type->width() == 32) {
412       float result_val = integer_type->IsSigned()
413                              ? static_cast<float>(static_cast<int32_t>(ua))
414                              : static_cast<float>(ua);
415       utils::FloatProxy<float> result(result_val);
416       std::vector<uint32_t> words = {result.data()};
417       return const_mgr->GetConstant(result_type, words);
418     } else if (float_type->width() == 64) {
419       double result_val = integer_type->IsSigned()
420                               ? static_cast<double>(static_cast<int32_t>(ua))
421                               : static_cast<double>(ua);
422       utils::FloatProxy<double> result(result_val);
423       std::vector<uint32_t> words = result.GetWords();
424       return const_mgr->GetConstant(result_type, words);
425     }
426     return nullptr;
427   };
428 }
429 
430 // This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|.
FoldQuantizeToF16Scalar()431 UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
432   return [](const analysis::Type* result_type, const analysis::Constant* a,
433             analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
434     assert(result_type != nullptr && a != nullptr);
435     const analysis::Float* float_type = a->type()->AsFloat();
436     assert(float_type != nullptr);
437     if (float_type->width() != 32) {
438       return nullptr;
439     }
440 
441     float fa = a->GetFloat();
442     utils::HexFloat<utils::FloatProxy<float>> orignal(fa);
443     utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0);
444     utils::HexFloat<utils::FloatProxy<float>> result(0.0f);
445     orignal.castTo(quantized, utils::round_direction::kToZero);
446     quantized.castTo(result, utils::round_direction::kToZero);
447     std::vector<uint32_t> words = {result.getBits()};
448     return const_mgr->GetConstant(result_type, words);
449   };
450 }
451 
// Expands to a |BinaryScalarFoldingRule| (a non-capturing lambda expression)
// that computes "a op b".  |a|, |b| and the result all share one
// floating-point type.  |op| must be a binary operator valid for both float
// and double.  Only 32- and 64-bit floats are folded; other widths yield
// nullptr.  The "_in_macro" names are presumably chosen to avoid clashing
// with locals at expansion sites.
#define FOLD_FPARITH_OP(op)                                                   \
  [](const analysis::Type* result_type_in_macro, const analysis::Constant* a, \
     const analysis::Constant* b,                                             \
     analysis::ConstantManager* const_mgr_in_macro)                           \
      -> const analysis::Constant* {                                          \
    assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr);  \
    assert(result_type_in_macro == a->type() &&                               \
           result_type_in_macro == b->type());                                \
    const analysis::Float* float_type_in_macro =                              \
        result_type_in_macro->AsFloat();                                      \
    assert(float_type_in_macro != nullptr);                                   \
    std::vector<uint32_t> words_in_macro;                                     \
    switch (float_type_in_macro->width()) {                                   \
      case 32:                                                                \
        words_in_macro =                                                      \
            utils::FloatProxy<float>(a->GetFloat() op b->GetFloat())          \
                .GetWords();                                                  \
        break;                                                                \
      case 64:                                                                \
        words_in_macro =                                                      \
            utils::FloatProxy<double>(a->GetDouble() op b->GetDouble())       \
                .GetWords();                                                  \
        break;                                                                \
      default:                                                                \
        return nullptr;                                                       \
    }                                                                         \
    return const_mgr_in_macro->GetConstant(result_type_in_macro,              \
                                           words_in_macro);                   \
  }
482 
483 // Define the folding rule for conversion between floating point and integer
FoldFToI()484 ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
FoldIToF()485 ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
FoldQuantizeToF16()486 ConstantFoldingRule FoldQuantizeToF16() {
487   return FoldFPUnaryOp(FoldQuantizeToF16Scalar());
488 }
489 
490 // Define the folding rules for subtraction, addition, multiplication, and
491 // division for floating point values.
FoldFSub()492 ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
FoldFAdd()493 ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
FoldFMul()494 ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
FoldFDiv()495 ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FOLD_FPARITH_OP(/)); }
496 
// Combines the raw comparison result |op_result| with NaN handling.
// |op_unordered| is true when at least one operand is NaN.  Ordered
// comparisons (|need_ordered| true) are false whenever an operand is NaN.
// Unordered comparisons are true whenever an operand is NaN.  When no operand
// is NaN, both kinds return |op_result|.
bool CompareFloatingPoint(bool op_result, bool op_unordered,
                          bool need_ordered) {
  if (op_unordered) {
    return !need_ordered;
  }
  return op_result;
}
507 
// Expands to a |BinaryScalarFoldingRule| (a non-capturing lambda expression)
// that folds the floating-point comparison "a op b" to a bool constant.
// |a| and |b| share one float type, and |result_type| must be bool.  |ord|
// is true for an ordered comparison (false if any operand is NaN) and false
// for an unordered one (true if any operand is NaN).  Only 32- and 64-bit
// operands are folded; other widths yield nullptr.
#define FOLD_FPCMP_OP(op, ord)                                               \
  [](const analysis::Type* result_type, const analysis::Constant* a,         \
     const analysis::Constant* b,                                            \
     analysis::ConstantManager* const_mgr) -> const analysis::Constant* {    \
    assert(result_type != nullptr && a != nullptr && b != nullptr);          \
    assert(result_type->AsBool());                                           \
    assert(a->type() == b->type());                                          \
    const analysis::Float* operand_type = a->type()->AsFloat();              \
    assert(operand_type != nullptr);                                         \
    bool comparison_holds = false;                                           \
    switch (operand_type->width()) {                                         \
      case 32: {                                                             \
        const float lhs = a->GetFloat();                                     \
        const float rhs = b->GetFloat();                                     \
        comparison_holds = CompareFloatingPoint(                             \
            lhs op rhs, std::isnan(lhs) || std::isnan(rhs), ord);            \
        break;                                                               \
      }                                                                      \
      case 64: {                                                             \
        const double lhs = a->GetDouble();                                   \
        const double rhs = b->GetDouble();                                   \
        comparison_holds = CompareFloatingPoint(                             \
            lhs op rhs, std::isnan(lhs) || std::isnan(rhs), ord);            \
        break;                                                               \
      }                                                                      \
      default:                                                               \
        return nullptr;                                                      \
    }                                                                        \
    std::vector<uint32_t> words = {static_cast<uint32_t>(comparison_holds)}; \
    return const_mgr->GetConstant(result_type, words);                       \
  }
536 
537 // Define the folding rules for ordered and unordered comparison for floating
538 // point values.
FoldFOrdEqual()539 ConstantFoldingRule FoldFOrdEqual() {
540   return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
541 }
FoldFUnordEqual()542 ConstantFoldingRule FoldFUnordEqual() {
543   return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
544 }
FoldFOrdNotEqual()545 ConstantFoldingRule FoldFOrdNotEqual() {
546   return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
547 }
FoldFUnordNotEqual()548 ConstantFoldingRule FoldFUnordNotEqual() {
549   return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
550 }
FoldFOrdLessThan()551 ConstantFoldingRule FoldFOrdLessThan() {
552   return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
553 }
FoldFUnordLessThan()554 ConstantFoldingRule FoldFUnordLessThan() {
555   return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
556 }
FoldFOrdGreaterThan()557 ConstantFoldingRule FoldFOrdGreaterThan() {
558   return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
559 }
FoldFUnordGreaterThan()560 ConstantFoldingRule FoldFUnordGreaterThan() {
561   return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
562 }
FoldFOrdLessThanEqual()563 ConstantFoldingRule FoldFOrdLessThanEqual() {
564   return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
565 }
FoldFUnordLessThanEqual()566 ConstantFoldingRule FoldFUnordLessThanEqual() {
567   return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
568 }
FoldFOrdGreaterThanEqual()569 ConstantFoldingRule FoldFOrdGreaterThanEqual() {
570   return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
571 }
FoldFUnordGreaterThanEqual()572 ConstantFoldingRule FoldFUnordGreaterThanEqual() {
573   return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
574 }
575 
576 // Folds an OpDot where all of the inputs are constants to a
577 // constant.  A new constant is created if necessary.
FoldOpDotWithConstants()578 ConstantFoldingRule FoldOpDotWithConstants() {
579   return [](IRContext* context, Instruction* inst,
580             const std::vector<const analysis::Constant*>& constants)
581              -> const analysis::Constant* {
582     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
583     analysis::TypeManager* type_mgr = context->get_type_mgr();
584     const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
585     assert(new_type->AsFloat() && "OpDot should have a float return type.");
586     const analysis::Float* float_type = new_type->AsFloat();
587 
588     if (!inst->IsFloatingPointFoldingAllowed()) {
589       return nullptr;
590     }
591 
592     // If one of the operands is 0, then the result is 0.
593     bool has_zero_operand = false;
594 
595     for (int i = 0; i < 2; ++i) {
596       if (constants[i]) {
597         if (constants[i]->AsNullConstant() ||
598             constants[i]->AsVectorConstant()->IsZero()) {
599           has_zero_operand = true;
600           break;
601         }
602       }
603     }
604 
605     if (has_zero_operand) {
606       if (float_type->width() == 32) {
607         utils::FloatProxy<float> result(0.0f);
608         std::vector<uint32_t> words = result.GetWords();
609         return const_mgr->GetConstant(float_type, words);
610       }
611       if (float_type->width() == 64) {
612         utils::FloatProxy<double> result(0.0);
613         std::vector<uint32_t> words = result.GetWords();
614         return const_mgr->GetConstant(float_type, words);
615       }
616       return nullptr;
617     }
618 
619     if (constants[0] == nullptr || constants[1] == nullptr) {
620       return nullptr;
621     }
622 
623     std::vector<const analysis::Constant*> a_components;
624     std::vector<const analysis::Constant*> b_components;
625 
626     a_components = constants[0]->GetVectorComponents(const_mgr);
627     b_components = constants[1]->GetVectorComponents(const_mgr);
628 
629     utils::FloatProxy<double> result(0.0);
630     std::vector<uint32_t> words = result.GetWords();
631     const analysis::Constant* result_const =
632         const_mgr->GetConstant(float_type, words);
633     for (uint32_t i = 0; i < a_components.size() && result_const != nullptr;
634          ++i) {
635       if (a_components[i] == nullptr || b_components[i] == nullptr) {
636         return nullptr;
637       }
638 
639       const analysis::Constant* component = FOLD_FPARITH_OP(*)(
640           new_type, a_components[i], b_components[i], const_mgr);
641       if (component == nullptr) {
642         return nullptr;
643       }
644       result_const =
645           FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
646     }
647     return result_const;
648   };
649 }
650 
651 // This function defines a |UnaryScalarFoldingRule| that subtracts the constant
652 // from zero.
FoldFNegateOp()653 UnaryScalarFoldingRule FoldFNegateOp() {
654   return [](const analysis::Type* result_type, const analysis::Constant* a,
655             analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
656     assert(result_type != nullptr && a != nullptr);
657     assert(result_type == a->type());
658     const analysis::Float* float_type = result_type->AsFloat();
659     assert(float_type != nullptr);
660     if (float_type->width() == 32) {
661       float fa = a->GetFloat();
662       utils::FloatProxy<float> result(-fa);
663       std::vector<uint32_t> words = result.GetWords();
664       return const_mgr->GetConstant(result_type, words);
665     } else if (float_type->width() == 64) {
666       double da = a->GetDouble();
667       utils::FloatProxy<double> result(-da);
668       std::vector<uint32_t> words = result.GetWords();
669       return const_mgr->GetConstant(result_type, words);
670     }
671     return nullptr;
672   };
673 }
674 
FoldFNegate()675 ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); }
676 
FoldFClampFeedingCompare(uint32_t cmp_opcode)677 ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) {
678   return [cmp_opcode](IRContext* context, Instruction* inst,
679                       const std::vector<const analysis::Constant*>& constants)
680              -> const analysis::Constant* {
681     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
682     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
683 
684     if (!inst->IsFloatingPointFoldingAllowed()) {
685       return nullptr;
686     }
687 
688     uint32_t non_const_idx = (constants[0] ? 1 : 0);
689     uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
690     Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
691 
692     analysis::TypeManager* type_mgr = context->get_type_mgr();
693     const analysis::Type* operand_type =
694         type_mgr->GetType(operand_inst->type_id());
695 
696     if (!operand_type->AsFloat()) {
697       return nullptr;
698     }
699 
700     if (operand_type->AsFloat()->width() != 32 &&
701         operand_type->AsFloat()->width() != 64) {
702       return nullptr;
703     }
704 
705     if (operand_inst->opcode() != SpvOpExtInst) {
706       return nullptr;
707     }
708 
709     if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
710       return nullptr;
711     }
712 
713     if (constants[1] == nullptr && constants[0] == nullptr) {
714       return nullptr;
715     }
716 
717     uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
718     const analysis::Constant* max_const =
719         const_mgr->FindDeclaredConstant(max_id);
720 
721     uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
722     const analysis::Constant* min_const =
723         const_mgr->FindDeclaredConstant(min_id);
724 
725     bool found_result = false;
726     bool result = false;
727 
728     switch (cmp_opcode) {
729       case SpvOpFOrdLessThan:
730       case SpvOpFUnordLessThan:
731       case SpvOpFOrdGreaterThanEqual:
732       case SpvOpFUnordGreaterThanEqual:
733         if (constants[0]) {
734           if (min_const) {
735             if (constants[0]->GetValueAsDouble() <
736                 min_const->GetValueAsDouble()) {
737               found_result = true;
738               result = (cmp_opcode == SpvOpFOrdLessThan ||
739                         cmp_opcode == SpvOpFUnordLessThan);
740             }
741           }
742           if (max_const) {
743             if (constants[0]->GetValueAsDouble() >=
744                 max_const->GetValueAsDouble()) {
745               found_result = true;
746               result = !(cmp_opcode == SpvOpFOrdLessThan ||
747                          cmp_opcode == SpvOpFUnordLessThan);
748             }
749           }
750         }
751 
752         if (constants[1]) {
753           if (max_const) {
754             if (max_const->GetValueAsDouble() <
755                 constants[1]->GetValueAsDouble()) {
756               found_result = true;
757               result = (cmp_opcode == SpvOpFOrdLessThan ||
758                         cmp_opcode == SpvOpFUnordLessThan);
759             }
760           }
761 
762           if (min_const) {
763             if (min_const->GetValueAsDouble() >=
764                 constants[1]->GetValueAsDouble()) {
765               found_result = true;
766               result = !(cmp_opcode == SpvOpFOrdLessThan ||
767                          cmp_opcode == SpvOpFUnordLessThan);
768             }
769           }
770         }
771         break;
772       case SpvOpFOrdGreaterThan:
773       case SpvOpFUnordGreaterThan:
774       case SpvOpFOrdLessThanEqual:
775       case SpvOpFUnordLessThanEqual:
776         if (constants[0]) {
777           if (min_const) {
778             if (constants[0]->GetValueAsDouble() <=
779                 min_const->GetValueAsDouble()) {
780               found_result = true;
781               result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
782                         cmp_opcode == SpvOpFUnordLessThanEqual);
783             }
784           }
785           if (max_const) {
786             if (constants[0]->GetValueAsDouble() >
787                 max_const->GetValueAsDouble()) {
788               found_result = true;
789               result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
790                          cmp_opcode == SpvOpFUnordLessThanEqual);
791             }
792           }
793         }
794 
795         if (constants[1]) {
796           if (max_const) {
797             if (max_const->GetValueAsDouble() <=
798                 constants[1]->GetValueAsDouble()) {
799               found_result = true;
800               result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
801                         cmp_opcode == SpvOpFUnordLessThanEqual);
802             }
803           }
804 
805           if (min_const) {
806             if (min_const->GetValueAsDouble() >
807                 constants[1]->GetValueAsDouble()) {
808               found_result = true;
809               result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
810                          cmp_opcode == SpvOpFUnordLessThanEqual);
811             }
812           }
813         }
814         break;
815       default:
816         return nullptr;
817     }
818 
819     if (!found_result) {
820       return nullptr;
821     }
822 
823     const analysis::Type* bool_type =
824         context->get_type_mgr()->GetType(inst->type_id());
825     const analysis::Constant* result_const =
826         const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
827     assert(result_const);
828     return result_const;
829   };
830 }
831 
FoldFMix()832 ConstantFoldingRule FoldFMix() {
833   return [](IRContext* context, Instruction* inst,
834             const std::vector<const analysis::Constant*>& constants)
835              -> const analysis::Constant* {
836     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
837     assert(inst->opcode() == SpvOpExtInst &&
838            "Expecting an extended instruction.");
839     assert(inst->GetSingleWordInOperand(0) ==
840                context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
841            "Expecting a GLSLstd450 extended instruction.");
842     assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix &&
843            "Expecting and FMix instruction.");
844 
845     if (!inst->IsFloatingPointFoldingAllowed()) {
846       return nullptr;
847     }
848 
849     // Make sure all FMix operands are constants.
850     for (uint32_t i = 1; i < 4; i++) {
851       if (constants[i] == nullptr) {
852         return nullptr;
853       }
854     }
855 
856     const analysis::Constant* one;
857     bool is_vector = false;
858     const analysis::Type* result_type = constants[1]->type();
859     const analysis::Type* base_type = result_type;
860     if (base_type->AsVector()) {
861       is_vector = true;
862       base_type = base_type->AsVector()->element_type();
863     }
864     assert(base_type->AsFloat() != nullptr &&
865            "FMix is suppose to act on floats or vectors of floats.");
866 
867     if (base_type->AsFloat()->width() == 32) {
868       one = const_mgr->GetConstant(base_type,
869                                    utils::FloatProxy<float>(1.0f).GetWords());
870     } else {
871       one = const_mgr->GetConstant(base_type,
872                                    utils::FloatProxy<double>(1.0).GetWords());
873     }
874 
875     if (is_vector) {
876       uint32_t one_id = const_mgr->GetDefiningInstruction(one)->result_id();
877       one =
878           const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id));
879     }
880 
881     const analysis::Constant* temp1 = FoldFPBinaryOp(
882         FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context);
883     if (temp1 == nullptr) {
884       return nullptr;
885     }
886 
887     const analysis::Constant* temp2 = FoldFPBinaryOp(
888         FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context);
889     if (temp2 == nullptr) {
890       return nullptr;
891     }
892     const analysis::Constant* temp3 =
893         FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(),
894                        {constants[2], constants[3]}, context);
895     if (temp3 == nullptr) {
896       return nullptr;
897     }
898     return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3},
899                           context);
900   };
901 }
902 
// Clamps |x| into [|min_val|, |max_val|]. The lower bound is applied first and
// the upper bound second, so if |min_val| > |max_val| the result is |max_val|.
template <class IntType>
IntType FoldIClamp(IntType x, IntType min_val, IntType max_val) {
  const IntType at_least_min = (x < min_val) ? min_val : x;
  return (at_least_min > max_val) ? max_val : at_least_min;
}
913 
// Returns whichever of |a| or |b| holds the smaller value, reading both as
// |result_type|: a signed/unsigned 32- or 64-bit integer, or a 32- or 64-bit
// float. |a| is returned only when it is strictly smaller. Ties, and float
// comparisons that are false because of a NaN, return |b|. Returns nullptr for
// any other type or width.
// NOTE(review): this is also registered for GLSLstd450 SMin/UMin, but the
// signed-vs-unsigned choice comes from |result_type|, not from the opcode.
// Confirm this is intended when operand signedness differs.
const analysis::Constant* FoldMin(const analysis::Type* result_type,
                                  const analysis::Constant* a,
                                  const analysis::Constant* b,
                                  analysis::ConstantManager*) {
  const auto choose = [a, b](bool a_is_smaller) { return a_is_smaller ? a : b; };

  if (const analysis::Float* float_type = result_type->AsFloat()) {
    switch (float_type->width()) {
      case 32:
        return choose(a->GetFloat() < b->GetFloat());
      case 64:
        return choose(a->GetDouble() < b->GetDouble());
      default:
        return nullptr;
    }
  }

  const analysis::Integer* int_type = result_type->AsInteger();
  if (int_type == nullptr) {
    return nullptr;
  }

  const bool is_signed = int_type->IsSigned();
  switch (int_type->width()) {
    case 32:
      return is_signed ? choose(a->GetS32() < b->GetS32())
                       : choose(a->GetU32() < b->GetU32());
    case 64:
      return is_signed ? choose(a->GetS64() < b->GetS64())
                       : choose(a->GetU64() < b->GetU64());
    default:
      return nullptr;
  }
}
953 
// Returns whichever of |a| or |b| holds the larger value, reading both as
// |result_type|: a signed/unsigned 32- or 64-bit integer, or a 32- or 64-bit
// float. |a| is returned only when it is strictly larger. Ties, and float
// comparisons that are false because of a NaN, return |b|. Returns nullptr for
// any other type or width.
// NOTE(review): this is also registered for GLSLstd450 SMax/UMax, but the
// signed-vs-unsigned choice comes from |result_type|, not from the opcode.
// Confirm this is intended when operand signedness differs.
const analysis::Constant* FoldMax(const analysis::Type* result_type,
                                  const analysis::Constant* a,
                                  const analysis::Constant* b,
                                  analysis::ConstantManager*) {
  const auto choose = [a, b](bool a_is_larger) { return a_is_larger ? a : b; };

  if (const analysis::Float* float_type = result_type->AsFloat()) {
    switch (float_type->width()) {
      case 32:
        return choose(a->GetFloat() > b->GetFloat());
      case 64:
        return choose(a->GetDouble() > b->GetDouble());
      default:
        return nullptr;
    }
  }

  const analysis::Integer* int_type = result_type->AsInteger();
  if (int_type == nullptr) {
    return nullptr;
  }

  const bool is_signed = int_type->IsSigned();
  switch (int_type->width()) {
    case 32:
      return is_signed ? choose(a->GetS32() > b->GetS32())
                       : choose(a->GetU32() > b->GetU32());
    case 64:
      return is_signed ? choose(a->GetS64() > b->GetS64())
                       : choose(a->GetU64() > b->GetU64());
    default:
      return nullptr;
  }
}
993 
994 // Fold an clamp instruction when all three operands are constant.
FoldClamp1(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > & constants)995 const analysis::Constant* FoldClamp1(
996     IRContext* context, Instruction* inst,
997     const std::vector<const analysis::Constant*>& constants) {
998   assert(inst->opcode() == SpvOpExtInst &&
999          "Expecting an extended instruction.");
1000   assert(inst->GetSingleWordInOperand(0) ==
1001              context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1002          "Expecting a GLSLstd450 extended instruction.");
1003 
1004   // Make sure all Clamp operands are constants.
1005   for (uint32_t i = 1; i < 3; i++) {
1006     if (constants[i] == nullptr) {
1007       return nullptr;
1008     }
1009   }
1010 
1011   const analysis::Constant* temp = FoldFPBinaryOp(
1012       FoldMax, inst->type_id(), {constants[1], constants[2]}, context);
1013   if (temp == nullptr) {
1014     return nullptr;
1015   }
1016   return FoldFPBinaryOp(FoldMin, inst->type_id(), {temp, constants[3]},
1017                         context);
1018 }
1019 
1020 // Fold a clamp instruction when |x >= min_val|.
FoldClamp2(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > & constants)1021 const analysis::Constant* FoldClamp2(
1022     IRContext* context, Instruction* inst,
1023     const std::vector<const analysis::Constant*>& constants) {
1024   assert(inst->opcode() == SpvOpExtInst &&
1025          "Expecting an extended instruction.");
1026   assert(inst->GetSingleWordInOperand(0) ==
1027              context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1028          "Expecting a GLSLstd450 extended instruction.");
1029 
1030   const analysis::Constant* x = constants[1];
1031   const analysis::Constant* min_val = constants[2];
1032 
1033   if (x == nullptr || min_val == nullptr) {
1034     return nullptr;
1035   }
1036 
1037   const analysis::Constant* temp =
1038       FoldFPBinaryOp(FoldMax, inst->type_id(), {x, min_val}, context);
1039   if (temp == min_val) {
1040     // We can assume that |min_val| is less than |max_val|.  Therefore, if the
1041     // result of the max operation is |min_val|, we know the result of the min
1042     // operation, even if |max_val| is not a constant.
1043     return min_val;
1044   }
1045   return nullptr;
1046 }
1047 
1048 // Fold a clamp instruction when |x >= max_val|.
FoldClamp3(IRContext * context,Instruction * inst,const std::vector<const analysis::Constant * > & constants)1049 const analysis::Constant* FoldClamp3(
1050     IRContext* context, Instruction* inst,
1051     const std::vector<const analysis::Constant*>& constants) {
1052   assert(inst->opcode() == SpvOpExtInst &&
1053          "Expecting an extended instruction.");
1054   assert(inst->GetSingleWordInOperand(0) ==
1055              context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1056          "Expecting a GLSLstd450 extended instruction.");
1057 
1058   const analysis::Constant* x = constants[1];
1059   const analysis::Constant* max_val = constants[3];
1060 
1061   if (x == nullptr || max_val == nullptr) {
1062     return nullptr;
1063   }
1064 
1065   const analysis::Constant* temp =
1066       FoldFPBinaryOp(FoldMin, inst->type_id(), {x, max_val}, context);
1067   if (temp == max_val) {
1068     // We can assume that |min_val| is less than |max_val|.  Therefore, if the
1069     // result of the max operation is |min_val|, we know the result of the min
1070     // operation, even if |max_val| is not a constant.
1071     return max_val;
1072   }
1073   return nullptr;
1074 }
1075 
FoldFTranscendentalUnary(double (* fp)(double))1076 UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) {
1077   return
1078       [fp](const analysis::Type* result_type, const analysis::Constant* a,
1079            analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1080         assert(result_type != nullptr && a != nullptr);
1081         const analysis::Float* float_type = a->type()->AsFloat();
1082         assert(float_type != nullptr);
1083         assert(float_type == result_type->AsFloat());
1084         if (float_type->width() == 32) {
1085           float fa = a->GetFloat();
1086           float res = static_cast<float>(fp(fa));
1087           utils::FloatProxy<float> result(res);
1088           std::vector<uint32_t> words = result.GetWords();
1089           return const_mgr->GetConstant(result_type, words);
1090         } else if (float_type->width() == 64) {
1091           double fa = a->GetDouble();
1092           double res = fp(fa);
1093           utils::FloatProxy<double> result(res);
1094           std::vector<uint32_t> words = result.GetWords();
1095           return const_mgr->GetConstant(result_type, words);
1096         }
1097         return nullptr;
1098       };
1099 }
1100 
FoldFTranscendentalBinary(double (* fp)(double,double))1101 BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
1102                                                                double)) {
1103   return
1104       [fp](const analysis::Type* result_type, const analysis::Constant* a,
1105            const analysis::Constant* b,
1106            analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1107         assert(result_type != nullptr && a != nullptr);
1108         const analysis::Float* float_type = a->type()->AsFloat();
1109         assert(float_type != nullptr);
1110         assert(float_type == result_type->AsFloat());
1111         assert(float_type == b->type()->AsFloat());
1112         if (float_type->width() == 32) {
1113           float fa = a->GetFloat();
1114           float fb = b->GetFloat();
1115           float res = static_cast<float>(fp(fa, fb));
1116           utils::FloatProxy<float> result(res);
1117           std::vector<uint32_t> words = result.GetWords();
1118           return const_mgr->GetConstant(result_type, words);
1119         } else if (float_type->width() == 64) {
1120           double fa = a->GetDouble();
1121           double fb = b->GetDouble();
1122           double res = fp(fa, fb);
1123           utils::FloatProxy<double> result(res);
1124           std::vector<uint32_t> words = result.GetWords();
1125           return const_mgr->GetConstant(result_type, words);
1126         }
1127         return nullptr;
1128       };
1129 }
1130 }  // namespace
1131 
AddFoldingRules()1132 void ConstantFoldingRules::AddFoldingRules() {
1133   // Add all folding rules to the list for the opcodes to which they apply.
1134   // Note that the order in which rules are added to the list matters. If a rule
1135   // applies to the instruction, the rest of the rules will not be attempted.
1136   // Take that into consideration.
1137 
1138   rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants());
1139 
1140   rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants());
1141 
1142   rules_[SpvOpConvertFToS].push_back(FoldFToI());
1143   rules_[SpvOpConvertFToU].push_back(FoldFToI());
1144   rules_[SpvOpConvertSToF].push_back(FoldIToF());
1145   rules_[SpvOpConvertUToF].push_back(FoldIToF());
1146 
1147   rules_[SpvOpDot].push_back(FoldOpDotWithConstants());
1148   rules_[SpvOpFAdd].push_back(FoldFAdd());
1149   rules_[SpvOpFDiv].push_back(FoldFDiv());
1150   rules_[SpvOpFMul].push_back(FoldFMul());
1151   rules_[SpvOpFSub].push_back(FoldFSub());
1152 
1153   rules_[SpvOpFOrdEqual].push_back(FoldFOrdEqual());
1154 
1155   rules_[SpvOpFUnordEqual].push_back(FoldFUnordEqual());
1156 
1157   rules_[SpvOpFOrdNotEqual].push_back(FoldFOrdNotEqual());
1158 
1159   rules_[SpvOpFUnordNotEqual].push_back(FoldFUnordNotEqual());
1160 
1161   rules_[SpvOpFOrdLessThan].push_back(FoldFOrdLessThan());
1162   rules_[SpvOpFOrdLessThan].push_back(
1163       FoldFClampFeedingCompare(SpvOpFOrdLessThan));
1164 
1165   rules_[SpvOpFUnordLessThan].push_back(FoldFUnordLessThan());
1166   rules_[SpvOpFUnordLessThan].push_back(
1167       FoldFClampFeedingCompare(SpvOpFUnordLessThan));
1168 
1169   rules_[SpvOpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
1170   rules_[SpvOpFOrdGreaterThan].push_back(
1171       FoldFClampFeedingCompare(SpvOpFOrdGreaterThan));
1172 
1173   rules_[SpvOpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
1174   rules_[SpvOpFUnordGreaterThan].push_back(
1175       FoldFClampFeedingCompare(SpvOpFUnordGreaterThan));
1176 
1177   rules_[SpvOpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
1178   rules_[SpvOpFOrdLessThanEqual].push_back(
1179       FoldFClampFeedingCompare(SpvOpFOrdLessThanEqual));
1180 
1181   rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
1182   rules_[SpvOpFUnordLessThanEqual].push_back(
1183       FoldFClampFeedingCompare(SpvOpFUnordLessThanEqual));
1184 
1185   rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
1186   rules_[SpvOpFOrdGreaterThanEqual].push_back(
1187       FoldFClampFeedingCompare(SpvOpFOrdGreaterThanEqual));
1188 
1189   rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual());
1190   rules_[SpvOpFUnordGreaterThanEqual].push_back(
1191       FoldFClampFeedingCompare(SpvOpFUnordGreaterThanEqual));
1192 
1193   rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
1194   rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar());
1195 
1196   rules_[SpvOpFNegate].push_back(FoldFNegate());
1197   rules_[SpvOpQuantizeToF16].push_back(FoldQuantizeToF16());
1198 
1199   // Add rules for GLSLstd450
1200   FeatureManager* feature_manager = context_->get_feature_mgr();
1201   uint32_t ext_inst_glslstd450_id =
1202       feature_manager->GetExtInstImportId_GLSLstd450();
1203   if (ext_inst_glslstd450_id != 0) {
1204     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix());
1205     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMin}].push_back(
1206         FoldFPBinaryOp(FoldMin));
1207     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMin}].push_back(
1208         FoldFPBinaryOp(FoldMin));
1209     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMin}].push_back(
1210         FoldFPBinaryOp(FoldMin));
1211     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMax}].push_back(
1212         FoldFPBinaryOp(FoldMax));
1213     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMax}].push_back(
1214         FoldFPBinaryOp(FoldMax));
1215     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMax}].push_back(
1216         FoldFPBinaryOp(FoldMax));
1217     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1218         FoldClamp1);
1219     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1220         FoldClamp2);
1221     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1222         FoldClamp3);
1223     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1224         FoldClamp1);
1225     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1226         FoldClamp2);
1227     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1228         FoldClamp3);
1229     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1230         FoldClamp1);
1231     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1232         FoldClamp2);
1233     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1234         FoldClamp3);
1235     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back(
1236         FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin)));
1237     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back(
1238         FoldFPUnaryOp(FoldFTranscendentalUnary(std::cos)));
1239     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Tan}].push_back(
1240         FoldFPUnaryOp(FoldFTranscendentalUnary(std::tan)));
1241     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Asin}].push_back(
1242         FoldFPUnaryOp(FoldFTranscendentalUnary(std::asin)));
1243     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Acos}].push_back(
1244         FoldFPUnaryOp(FoldFTranscendentalUnary(std::acos)));
1245     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan}].push_back(
1246         FoldFPUnaryOp(FoldFTranscendentalUnary(std::atan)));
1247     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp}].push_back(
1248         FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp)));
1249     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log}].push_back(
1250         FoldFPUnaryOp(FoldFTranscendentalUnary(std::log)));
1251 
1252 #ifdef __ANDROID__
1253     // Android NDK r15c tageting ABI 15 doesn't have full support for C++11
1254     // (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't
1255     // available up until ABI 18 so we use a shim
1256     auto log2_shim = [](double v) -> double { return log(v) / log(2.0); };
1257     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
1258         FoldFPUnaryOp(FoldFTranscendentalUnary(::exp2)));
1259     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
1260         FoldFPUnaryOp(FoldFTranscendentalUnary(log2_shim)));
1261 #else
1262     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
1263         FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp2)));
1264     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
1265         FoldFPUnaryOp(FoldFTranscendentalUnary(std::log2)));
1266 #endif
1267 
1268     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sqrt}].push_back(
1269         FoldFPUnaryOp(FoldFTranscendentalUnary(std::sqrt)));
1270     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan2}].push_back(
1271         FoldFPBinaryOp(FoldFTranscendentalBinary(std::atan2)));
1272     ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Pow}].push_back(
1273         FoldFPBinaryOp(FoldFTranscendentalBinary(std::pow)));
1274   }
1275 }
1276 }  // namespace opt
1277 }  // namespace spvtools
1278