1 // Copyright (c) 2016 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/fold_spec_constant_op_and_composite_pass.h"
16 
17 #include <algorithm>
18 #include <initializer_list>
19 #include <tuple>
20 
21 #include "source/opt/constants.h"
22 #include "source/opt/fold.h"
23 #include "source/opt/ir_context.h"
24 #include "source/util/make_unique.h"
25 
26 namespace spvtools {
27 namespace opt {
28 
Process()29 Pass::Status FoldSpecConstantOpAndCompositePass::Process() {
30   bool modified = false;
31   // Traverse through all the constant defining instructions. For Normal
32   // Constants whose values are determined and do not depend on OpUndef
33   // instructions, records their values in two internal maps: id_to_const_val_
34   // and const_val_to_id_ so that we can use them to infer the value of Spec
35   // Constants later.
36   // For Spec Constants defined with OpSpecConstantComposite instructions, if
37   // all of their components are Normal Constants, they will be turned into
38   // Normal Constants too. For Spec Constants defined with OpSpecConstantOp
39   // instructions, we check if they only depends on Normal Constants and fold
40   // them when possible. The two maps for Normal Constants: id_to_const_val_
41   // and const_val_to_id_ will be updated along the traversal so that the new
42   // Normal Constants generated from folding can be used to fold following Spec
43   // Constants.
44   // This algorithm depends on the SSA property of SPIR-V when
45   // defining constants. The dependent constants must be defined before the
46   // dependee constants. So a dependent Spec Constant must be defined and
47   // will be processed before its dependee Spec Constant. When we encounter
48   // the dependee Spec Constants, all its dependent constants must have been
49   // processed and all its dependent Spec Constants should have been folded if
50   // possible.
51   Module::inst_iterator next_inst = context()->types_values_begin();
52   for (Module::inst_iterator inst_iter = next_inst;
53        // Need to re-evaluate the end iterator since we may modify the list of
54        // instructions in this section of the module as the process goes.
55        inst_iter != context()->types_values_end(); inst_iter = next_inst) {
56     ++next_inst;
57     Instruction* inst = &*inst_iter;
58     // Collect constant values of normal constants and process the
59     // OpSpecConstantOp and OpSpecConstantComposite instructions if possible.
60     // The constant values will be stored in analysis::Constant instances.
61     // OpConstantSampler instruction is not collected here because it cannot be
62     // used in OpSpecConstant{Composite|Op} instructions.
63     // TODO(qining): If the constant or its type has decoration, we may need
64     // to skip it.
65     if (context()->get_constant_mgr()->GetType(inst) &&
66         !context()->get_constant_mgr()->GetType(inst)->decoration_empty())
67       continue;
68     switch (SpvOp opcode = inst->opcode()) {
69       // Records the values of Normal Constants.
70       case SpvOp::SpvOpConstantTrue:
71       case SpvOp::SpvOpConstantFalse:
72       case SpvOp::SpvOpConstant:
73       case SpvOp::SpvOpConstantNull:
74       case SpvOp::SpvOpConstantComposite:
75       case SpvOp::SpvOpSpecConstantComposite: {
76         // A Constant instance will be created if the given instruction is a
77         // Normal Constant whose value(s) are fixed. Note that for a composite
78         // Spec Constant defined with OpSpecConstantComposite instruction, if
79         // all of its components are Normal Constants already, the Spec
80         // Constant will be turned in to a Normal Constant. In that case, a
81         // Constant instance should also be created successfully and recorded
82         // in the id_to_const_val_ and const_val_to_id_ mapps.
83         if (auto const_value =
84                 context()->get_constant_mgr()->GetConstantFromInst(inst)) {
85           // Need to replace the OpSpecConstantComposite instruction with a
86           // corresponding OpConstantComposite instruction.
87           if (opcode == SpvOp::SpvOpSpecConstantComposite) {
88             inst->SetOpcode(SpvOp::SpvOpConstantComposite);
89             modified = true;
90           }
91           context()->get_constant_mgr()->MapConstantToInst(const_value, inst);
92         }
93         break;
94       }
95       // For a Spec Constants defined with OpSpecConstantOp instruction, check
96       // if it only depends on Normal Constants. If so, the Spec Constant will
97       // be folded. The original Spec Constant defining instruction will be
98       // replaced by Normal Constant defining instructions, and the new Normal
99       // Constants will be added to id_to_const_val_ and const_val_to_id_ so
100       // that we can use the new Normal Constants when folding following Spec
101       // Constants.
102       case SpvOp::SpvOpSpecConstantOp:
103         modified |= ProcessOpSpecConstantOp(&inst_iter);
104         break;
105       default:
106         break;
107     }
108   }
109   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
110 }
111 
ProcessOpSpecConstantOp(Module::inst_iterator * pos)112 bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp(
113     Module::inst_iterator* pos) {
114   Instruction* inst = &**pos;
115   Instruction* folded_inst = nullptr;
116   assert(inst->GetInOperand(0).type ==
117              SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER &&
118          "The first in-operand of OpSpecContantOp instruction must be of "
119          "SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER type");
120 
121   switch (static_cast<SpvOp>(inst->GetSingleWordInOperand(0))) {
122     case SpvOp::SpvOpCompositeExtract:
123     case SpvOp::SpvOpVectorShuffle:
124     case SpvOp::SpvOpCompositeInsert:
125     case SpvOp::SpvOpQuantizeToF16:
126       folded_inst = FoldWithInstructionFolder(pos);
127       break;
128     default:
129       // TODO: This should use the instruction folder as well, but some folding
130       // rules are missing.
131 
132       // Component-wise operations.
133       folded_inst = DoComponentWiseOperation(pos);
134       break;
135   }
136   if (!folded_inst) return false;
137 
138   // Replace the original constant with the new folded constant, kill the
139   // original constant.
140   uint32_t new_id = folded_inst->result_id();
141   uint32_t old_id = inst->result_id();
142   context()->ReplaceAllUsesWith(old_id, new_id);
143   context()->KillDef(old_id);
144   return true;
145 }
146 
GetTypeComponent(uint32_t typeId,uint32_t element) const147 uint32_t FoldSpecConstantOpAndCompositePass::GetTypeComponent(
148     uint32_t typeId, uint32_t element) const {
149   Instruction* type = context()->get_def_use_mgr()->GetDef(typeId);
150   uint32_t subtype = type->GetTypeComponent(element);
151   assert(subtype != 0);
152 
153   return subtype;
154 }
155 
FoldWithInstructionFolder(Module::inst_iterator * inst_iter_ptr)156 Instruction* FoldSpecConstantOpAndCompositePass::FoldWithInstructionFolder(
157     Module::inst_iterator* inst_iter_ptr) {
158   // If one of operands to the instruction is not a
159   // constant, then we cannot fold this spec constant.
160   for (uint32_t i = 1; i < (*inst_iter_ptr)->NumInOperands(); i++) {
161     const Operand& operand = (*inst_iter_ptr)->GetInOperand(i);
162     if (operand.type != SPV_OPERAND_TYPE_ID &&
163         operand.type != SPV_OPERAND_TYPE_OPTIONAL_ID) {
164       continue;
165     }
166     uint32_t id = operand.words[0];
167     if (context()->get_constant_mgr()->FindDeclaredConstant(id) == nullptr) {
168       return nullptr;
169     }
170   }
171 
172   // All of the operands are constant.  Construct a regular version of the
173   // instruction and pass it to the instruction folder.
174   std::unique_ptr<Instruction> inst((*inst_iter_ptr)->Clone(context()));
175   inst->SetOpcode(
176       static_cast<SpvOp>((*inst_iter_ptr)->GetSingleWordInOperand(0)));
177   inst->RemoveOperand(2);
178 
179   // We want the current instruction to be replaced by an |OpConstant*|
180   // instruction in the same position. We need to keep track of which constants
181   // the instruction folder creates, so we can move them into the correct place.
182   auto last_type_value_iter = (context()->types_values_end());
183   --last_type_value_iter;
184   Instruction* last_type_value = &*last_type_value_iter;
185 
186   auto identity_map = [](uint32_t id) { return id; };
187   Instruction* new_const_inst =
188       context()->get_instruction_folder().FoldInstructionToConstant(
189           inst.get(), identity_map);
190   assert(new_const_inst != nullptr &&
191          "Failed to fold instruction that must be folded.");
192 
193   // Get the instruction before |pos| to insert after.  |pos| cannot be the
194   // first instruction in the list because its type has to come first.
195   Instruction* insert_pos = (*inst_iter_ptr)->PreviousNode();
196   assert(insert_pos != nullptr &&
197          "pos is the first instruction in the types and values.");
198   bool need_to_clone = true;
199   for (Instruction* i = last_type_value->NextNode(); i != nullptr;
200        i = last_type_value->NextNode()) {
201     if (i == new_const_inst) {
202       need_to_clone = false;
203     }
204     i->InsertAfter(insert_pos);
205     insert_pos = insert_pos->NextNode();
206   }
207 
208   if (need_to_clone) {
209     new_const_inst = new_const_inst->Clone(context());
210     new_const_inst->SetResultId(TakeNextId());
211     new_const_inst->InsertAfter(insert_pos);
212     get_def_use_mgr()->AnalyzeInstDefUse(new_const_inst);
213   }
214   return new_const_inst;
215 }
216 
DoVectorShuffle(Module::inst_iterator * pos)217 Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle(
218     Module::inst_iterator* pos) {
219   Instruction* inst = &**pos;
220   analysis::Vector* result_vec_type =
221       context()->get_constant_mgr()->GetType(inst)->AsVector();
222   assert(inst->NumInOperands() - 1 > 2 &&
223          "OpSpecConstantOp DoVectorShuffle instruction requires more than 2 "
224          "operands (2 vector ids and at least one literal operand");
225   assert(result_vec_type &&
226          "The result of VectorShuffle must be of type vector");
227 
228   // A temporary null constants that can be used as the components of the result
229   // vector. This is needed when any one of the vector operands are null
230   // constant.
231   const analysis::Constant* null_component_constants = nullptr;
232 
233   // Get a concatenated vector of scalar constants. The vector should be built
234   // with the components from the first and the second operand of VectorShuffle.
235   std::vector<const analysis::Constant*> concatenated_components;
236   // Note that for OpSpecConstantOp, the second in-operand is the first id
237   // operand. The first in-operand is the spec opcode.
238   for (uint32_t i : {1, 2}) {
239     assert(inst->GetInOperand(i).type == SPV_OPERAND_TYPE_ID &&
240            "The vector operand must have a SPV_OPERAND_TYPE_ID type");
241     uint32_t operand_id = inst->GetSingleWordInOperand(i);
242     auto operand_const =
243         context()->get_constant_mgr()->FindDeclaredConstant(operand_id);
244     if (!operand_const) return nullptr;
245     const analysis::Type* operand_type = operand_const->type();
246     assert(operand_type->AsVector() &&
247            "The first two operand of VectorShuffle must be of vector type");
248     if (auto vec_const = operand_const->AsVectorConstant()) {
249       // case 1: current operand is a non-null vector constant.
250       concatenated_components.insert(concatenated_components.end(),
251                                      vec_const->GetComponents().begin(),
252                                      vec_const->GetComponents().end());
253     } else if (operand_const->AsNullConstant()) {
254       // case 2: current operand is a null vector constant. Create a temporary
255       // null scalar constant as the component.
256       if (!null_component_constants) {
257         const analysis::Type* component_type =
258             operand_type->AsVector()->element_type();
259         null_component_constants =
260             context()->get_constant_mgr()->GetConstant(component_type, {});
261       }
262       // Append the null scalar consts to the concatenated components
263       // vector.
264       concatenated_components.insert(concatenated_components.end(),
265                                      operand_type->AsVector()->element_count(),
266                                      null_component_constants);
267     } else {
268       // no other valid cases
269       return nullptr;
270     }
271   }
272   // Create null component constants if there are any. The component constants
273   // must be added to the module before the dependee composite constants to
274   // satisfy SSA def-use dominance.
275   if (null_component_constants) {
276     context()->get_constant_mgr()->BuildInstructionAndAddToModule(
277         null_component_constants, pos);
278   }
279   // Create the new vector constant with the selected components.
280   std::vector<const analysis::Constant*> selected_components;
281   for (uint32_t i = 3; i < inst->NumInOperands(); i++) {
282     assert(inst->GetInOperand(i).type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
283            "The literal operand must of type SPV_OPERAND_TYPE_LITERAL_INTEGER");
284     uint32_t literal = inst->GetSingleWordInOperand(i);
285     assert(literal < concatenated_components.size() &&
286            "Literal index out of bound of the concatenated vector");
287     selected_components.push_back(concatenated_components[literal]);
288   }
289   auto new_vec_const = MakeUnique<analysis::VectorConstant>(
290       result_vec_type, selected_components);
291   auto reg_vec_const =
292       context()->get_constant_mgr()->RegisterConstant(std::move(new_vec_const));
293   return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
294       reg_vec_const, pos);
295 }
296 
297 namespace {
298 // A helper function to check the type for component wise operations. Returns
299 // true if the type:
300 //  1) is bool type;
301 //  2) is 32-bit int type;
302 //  3) is vector of bool type;
303 //  4) is vector of 32-bit integer type.
304 // Otherwise returns false.
IsValidTypeForComponentWiseOperation(const analysis::Type * type)305 bool IsValidTypeForComponentWiseOperation(const analysis::Type* type) {
306   if (type->AsBool()) {
307     return true;
308   } else if (auto* it = type->AsInteger()) {
309     if (it->width() == 32) return true;
310   } else if (auto* vt = type->AsVector()) {
311     if (vt->element_type()->AsBool()) {
312       return true;
313     } else if (auto* vit = vt->element_type()->AsInteger()) {
314       if (vit->width() == 32) return true;
315     }
316   }
317   return false;
318 }
319 
320 // Encodes the integer |value| of in a word vector format appropriate for
321 // representing this value as a operands for a constant definition. Performs
322 // zero-extension/sign-extension/truncation when needed, based on the signess of
323 // the given target type.
324 //
325 // Note: type |type| argument must be either Integer or Bool.
EncodeIntegerAsWords(const analysis::Type & type,uint32_t value)326 utils::SmallVector<uint32_t, 2> EncodeIntegerAsWords(const analysis::Type& type,
327                                                      uint32_t value) {
328   const uint32_t all_ones = ~0;
329   uint32_t bit_width = 0;
330   uint32_t pad_value = 0;
331   bool result_type_signed = false;
332   if (auto* int_ty = type.AsInteger()) {
333     bit_width = int_ty->width();
334     result_type_signed = int_ty->IsSigned();
335     if (result_type_signed && static_cast<int32_t>(value) < 0) {
336       pad_value = all_ones;
337     }
338   } else if (type.AsBool()) {
339     bit_width = 1;
340   } else {
341     assert(false && "type must be Integer or Bool");
342   }
343 
344   assert(bit_width > 0);
345   uint32_t first_word = value;
346   const uint32_t bits_per_word = 32;
347 
348   // Truncate first_word if the |type| has width less than uint32.
349   if (bit_width < bits_per_word) {
350     const uint32_t num_high_bits_to_mask = bits_per_word - bit_width;
351     const bool is_negative_after_truncation =
352         result_type_signed &&
353         utils::IsBitAtPositionSet(first_word, bit_width - 1);
354 
355     if (is_negative_after_truncation) {
356       // Truncate and sign-extend |first_word|. No padding words will be
357       // added and |pad_value| can be left as-is.
358       first_word = utils::SetHighBits(first_word, num_high_bits_to_mask);
359     } else {
360       first_word = utils::ClearHighBits(first_word, num_high_bits_to_mask);
361     }
362   }
363 
364   utils::SmallVector<uint32_t, 2> words = {first_word};
365   for (uint32_t current_bit = bits_per_word; current_bit < bit_width;
366        current_bit += bits_per_word) {
367     words.push_back(pad_value);
368   }
369 
370   return words;
371 }
372 }  // namespace
373 
DoComponentWiseOperation(Module::inst_iterator * pos)374 Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
375     Module::inst_iterator* pos) {
376   const Instruction* inst = &**pos;
377   const analysis::Type* result_type =
378       context()->get_constant_mgr()->GetType(inst);
379   SpvOp spec_opcode = static_cast<SpvOp>(inst->GetSingleWordInOperand(0));
380   // Check and collect operands.
381   std::vector<const analysis::Constant*> operands;
382 
383   if (!std::all_of(
384           inst->cbegin(), inst->cend(), [&operands, this](const Operand& o) {
385             // skip the operands that is not an id.
386             if (o.type != spv_operand_type_t::SPV_OPERAND_TYPE_ID) return true;
387             uint32_t id = o.words.front();
388             if (auto c =
389                     context()->get_constant_mgr()->FindDeclaredConstant(id)) {
390               if (IsValidTypeForComponentWiseOperation(c->type())) {
391                 operands.push_back(c);
392                 return true;
393               }
394             }
395             return false;
396           }))
397     return nullptr;
398 
399   if (result_type->AsInteger() || result_type->AsBool()) {
400     // Scalar operation
401     const uint32_t result_val =
402         context()->get_instruction_folder().FoldScalars(spec_opcode, operands);
403     auto result_const = context()->get_constant_mgr()->GetConstant(
404         result_type, EncodeIntegerAsWords(*result_type, result_val));
405     return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
406         result_const, pos);
407   } else if (result_type->AsVector()) {
408     // Vector operation
409     const analysis::Type* element_type =
410         result_type->AsVector()->element_type();
411     uint32_t num_dims = result_type->AsVector()->element_count();
412     std::vector<uint32_t> result_vec =
413         context()->get_instruction_folder().FoldVectors(spec_opcode, num_dims,
414                                                         operands);
415     std::vector<const analysis::Constant*> result_vector_components;
416     for (const uint32_t r : result_vec) {
417       if (auto rc = context()->get_constant_mgr()->GetConstant(
418               element_type, EncodeIntegerAsWords(*element_type, r))) {
419         result_vector_components.push_back(rc);
420         if (!context()->get_constant_mgr()->BuildInstructionAndAddToModule(
421                 rc, pos)) {
422           assert(false &&
423                  "Failed to build and insert constant declaring instruction "
424                  "for the given vector component constant");
425         }
426       } else {
427         assert(false && "Failed to create constants with 32-bit word");
428       }
429     }
430     auto new_vec_const = MakeUnique<analysis::VectorConstant>(
431         result_type->AsVector(), result_vector_components);
432     auto reg_vec_const = context()->get_constant_mgr()->RegisterConstant(
433         std::move(new_vec_const));
434     return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
435         reg_vec_const, pos);
436   } else {
437     // Cannot process invalid component wise operation. The result of component
438     // wise operation must be of integer or bool scalar or vector of
439     // integer/bool type.
440     return nullptr;
441   }
442 }
443 
444 }  // namespace opt
445 }  // namespace spvtools
446