1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/constants.h"
16 
#include <algorithm>
#include <unordered_map>
#include <vector>
19 
20 #include "source/opt/ir_context.h"
21 
22 namespace spvtools {
23 namespace opt {
24 namespace analysis {
25 
GetFloat() const26 float Constant::GetFloat() const {
27   assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 32);
28 
29   if (const FloatConstant* fc = AsFloatConstant()) {
30     return fc->GetFloatValue();
31   } else {
32     assert(AsNullConstant() && "Must be a floating point constant.");
33     return 0.0f;
34   }
35 }
36 
GetDouble() const37 double Constant::GetDouble() const {
38   assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 64);
39 
40   if (const FloatConstant* fc = AsFloatConstant()) {
41     return fc->GetDoubleValue();
42   } else {
43     assert(AsNullConstant() && "Must be a floating point constant.");
44     return 0.0;
45   }
46 }
47 
GetValueAsDouble() const48 double Constant::GetValueAsDouble() const {
49   assert(type()->AsFloat() != nullptr);
50   if (type()->AsFloat()->width() == 32) {
51     return GetFloat();
52   } else {
53     assert(type()->AsFloat()->width() == 64);
54     return GetDouble();
55   }
56 }
57 
GetU32() const58 uint32_t Constant::GetU32() const {
59   assert(type()->AsInteger() != nullptr);
60   assert(type()->AsInteger()->width() == 32);
61 
62   if (const IntConstant* ic = AsIntConstant()) {
63     return ic->GetU32BitValue();
64   } else {
65     assert(AsNullConstant() && "Must be an integer constant.");
66     return 0u;
67   }
68 }
69 
GetU64() const70 uint64_t Constant::GetU64() const {
71   assert(type()->AsInteger() != nullptr);
72   assert(type()->AsInteger()->width() == 64);
73 
74   if (const IntConstant* ic = AsIntConstant()) {
75     return ic->GetU64BitValue();
76   } else {
77     assert(AsNullConstant() && "Must be an integer constant.");
78     return 0u;
79   }
80 }
81 
GetS32() const82 int32_t Constant::GetS32() const {
83   assert(type()->AsInteger() != nullptr);
84   assert(type()->AsInteger()->width() == 32);
85 
86   if (const IntConstant* ic = AsIntConstant()) {
87     return ic->GetS32BitValue();
88   } else {
89     assert(AsNullConstant() && "Must be an integer constant.");
90     return 0;
91   }
92 }
93 
GetS64() const94 int64_t Constant::GetS64() const {
95   assert(type()->AsInteger() != nullptr);
96   assert(type()->AsInteger()->width() == 64);
97 
98   if (const IntConstant* ic = AsIntConstant()) {
99     return ic->GetS64BitValue();
100   } else {
101     assert(AsNullConstant() && "Must be an integer constant.");
102     return 0;
103   }
104 }
105 
GetZeroExtendedValue() const106 uint64_t Constant::GetZeroExtendedValue() const {
107   const auto* int_type = type()->AsInteger();
108   assert(int_type != nullptr);
109   const auto width = int_type->width();
110   assert(width <= 64);
111 
112   uint64_t value = 0;
113   if (const IntConstant* ic = AsIntConstant()) {
114     if (width <= 32) {
115       value = ic->GetU32BitValue();
116     } else {
117       value = ic->GetU64BitValue();
118     }
119   } else {
120     assert(AsNullConstant() && "Must be an integer constant.");
121   }
122   return value;
123 }
124 
GetSignExtendedValue() const125 int64_t Constant::GetSignExtendedValue() const {
126   const auto* int_type = type()->AsInteger();
127   assert(int_type != nullptr);
128   const auto width = int_type->width();
129   assert(width <= 64);
130 
131   int64_t value = 0;
132   if (const IntConstant* ic = AsIntConstant()) {
133     if (width <= 32) {
134       // Let the C++ compiler do the sign extension.
135       value = int64_t(ic->GetS32BitValue());
136     } else {
137       value = ic->GetS64BitValue();
138     }
139   } else {
140     assert(AsNullConstant() && "Must be an integer constant.");
141   }
142   return value;
143 }
144 
ConstantManager(IRContext * ctx)145 ConstantManager::ConstantManager(IRContext* ctx) : ctx_(ctx) {
146   // Populate the constant table with values from constant declarations in the
147   // module.  The values of each OpConstant declaration is the identity
148   // assignment (i.e., each constant is its own value).
149   for (const auto& inst : ctx_->module()->GetConstants()) {
150     MapInst(inst);
151   }
152 }
153 
GetType(const Instruction * inst) const154 Type* ConstantManager::GetType(const Instruction* inst) const {
155   return context()->get_type_mgr()->GetType(inst->type_id());
156 }
157 
GetOperandConstants(const Instruction * inst) const158 std::vector<const Constant*> ConstantManager::GetOperandConstants(
159     const Instruction* inst) const {
160   std::vector<const Constant*> constants;
161   for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
162     const Operand* operand = &inst->GetInOperand(i);
163     if (operand->type != SPV_OPERAND_TYPE_ID) {
164       constants.push_back(nullptr);
165     } else {
166       uint32_t id = operand->words[0];
167       const analysis::Constant* constant = FindDeclaredConstant(id);
168       constants.push_back(constant);
169     }
170   }
171   return constants;
172 }
173 
FindDeclaredConstant(const Constant * c,uint32_t type_id) const174 uint32_t ConstantManager::FindDeclaredConstant(const Constant* c,
175                                                uint32_t type_id) const {
176   c = FindConstant(c);
177   if (c == nullptr) {
178     return 0;
179   }
180 
181   for (auto range = const_val_to_id_.equal_range(c);
182        range.first != range.second; ++range.first) {
183     Instruction* const_def =
184         context()->get_def_use_mgr()->GetDef(range.first->second);
185     if (type_id == 0 || const_def->type_id() == type_id) {
186       return range.first->second;
187     }
188   }
189   return 0;
190 }
191 
GetConstantsFromIds(const std::vector<uint32_t> & ids) const192 std::vector<const Constant*> ConstantManager::GetConstantsFromIds(
193     const std::vector<uint32_t>& ids) const {
194   std::vector<const Constant*> constants;
195   for (uint32_t id : ids) {
196     if (const Constant* c = FindDeclaredConstant(id)) {
197       constants.push_back(c);
198     } else {
199       return {};
200     }
201   }
202   return constants;
203 }
204 
BuildInstructionAndAddToModule(const Constant * new_const,Module::inst_iterator * pos,uint32_t type_id)205 Instruction* ConstantManager::BuildInstructionAndAddToModule(
206     const Constant* new_const, Module::inst_iterator* pos, uint32_t type_id) {
207   // TODO(1841): Handle id overflow.
208   uint32_t new_id = context()->TakeNextId();
209   if (new_id == 0) {
210     return nullptr;
211   }
212 
213   auto new_inst = CreateInstruction(new_id, new_const, type_id);
214   if (!new_inst) {
215     return nullptr;
216   }
217   auto* new_inst_ptr = new_inst.get();
218   *pos = pos->InsertBefore(std::move(new_inst));
219   ++(*pos);
220   context()->get_def_use_mgr()->AnalyzeInstDefUse(new_inst_ptr);
221   MapConstantToInst(new_const, new_inst_ptr);
222   return new_inst_ptr;
223 }
224 
GetDefiningInstruction(const Constant * c,uint32_t type_id,Module::inst_iterator * pos)225 Instruction* ConstantManager::GetDefiningInstruction(
226     const Constant* c, uint32_t type_id, Module::inst_iterator* pos) {
227   uint32_t decl_id = FindDeclaredConstant(c, type_id);
228   if (decl_id == 0) {
229     auto iter = context()->types_values_end();
230     if (pos == nullptr) pos = &iter;
231     return BuildInstructionAndAddToModule(c, pos, type_id);
232   } else {
233     auto def = context()->get_def_use_mgr()->GetDef(decl_id);
234     assert(def != nullptr);
235     assert((type_id == 0 || def->type_id() == type_id) &&
236            "This constant already has an instruction with a different type.");
237     return def;
238   }
239 }
240 
CreateConstant(const Type * type,const std::vector<uint32_t> & literal_words_or_ids) const241 std::unique_ptr<Constant> ConstantManager::CreateConstant(
242     const Type* type, const std::vector<uint32_t>& literal_words_or_ids) const {
243   if (literal_words_or_ids.size() == 0) {
244     // Constant declared with OpConstantNull
245     return MakeUnique<NullConstant>(type);
246   } else if (auto* bt = type->AsBool()) {
247     assert(literal_words_or_ids.size() == 1 &&
248            "Bool constant should be declared with one operand");
249     return MakeUnique<BoolConstant>(bt, literal_words_or_ids.front());
250   } else if (auto* it = type->AsInteger()) {
251     return MakeUnique<IntConstant>(it, literal_words_or_ids);
252   } else if (auto* ft = type->AsFloat()) {
253     return MakeUnique<FloatConstant>(ft, literal_words_or_ids);
254   } else if (auto* vt = type->AsVector()) {
255     auto components = GetConstantsFromIds(literal_words_or_ids);
256     if (components.empty()) return nullptr;
257     // All components of VectorConstant must be of type Bool, Integer or Float.
258     if (!std::all_of(components.begin(), components.end(),
259                      [](const Constant* c) {
260                        if (c->type()->AsBool() || c->type()->AsInteger() ||
261                            c->type()->AsFloat()) {
262                          return true;
263                        } else {
264                          return false;
265                        }
266                      }))
267       return nullptr;
268     // All components of VectorConstant must be in the same type.
269     const auto* component_type = components.front()->type();
270     if (!std::all_of(components.begin(), components.end(),
271                      [&component_type](const Constant* c) {
272                        if (c->type() == component_type) return true;
273                        return false;
274                      }))
275       return nullptr;
276     return MakeUnique<VectorConstant>(vt, components);
277   } else if (auto* mt = type->AsMatrix()) {
278     auto components = GetConstantsFromIds(literal_words_or_ids);
279     if (components.empty()) return nullptr;
280     return MakeUnique<MatrixConstant>(mt, components);
281   } else if (auto* st = type->AsStruct()) {
282     auto components = GetConstantsFromIds(literal_words_or_ids);
283     if (components.empty()) return nullptr;
284     return MakeUnique<StructConstant>(st, components);
285   } else if (auto* at = type->AsArray()) {
286     auto components = GetConstantsFromIds(literal_words_or_ids);
287     if (components.empty()) return nullptr;
288     return MakeUnique<ArrayConstant>(at, components);
289   } else {
290     return nullptr;
291   }
292 }
293 
GetConstantFromInst(const Instruction * inst)294 const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) {
295   std::vector<uint32_t> literal_words_or_ids;
296 
297   // Collect the constant defining literals or component ids.
298   for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
299     literal_words_or_ids.insert(literal_words_or_ids.end(),
300                                 inst->GetInOperand(i).words.begin(),
301                                 inst->GetInOperand(i).words.end());
302   }
303 
304   switch (inst->opcode()) {
305     // OpConstant{True|False} have the value embedded in the opcode. So they
306     // are not handled by the for-loop above. Here we add the value explicitly.
307     case SpvOp::SpvOpConstantTrue:
308       literal_words_or_ids.push_back(true);
309       break;
310     case SpvOp::SpvOpConstantFalse:
311       literal_words_or_ids.push_back(false);
312       break;
313     case SpvOp::SpvOpConstantNull:
314     case SpvOp::SpvOpConstant:
315     case SpvOp::SpvOpConstantComposite:
316     case SpvOp::SpvOpSpecConstantComposite:
317       break;
318     default:
319       return nullptr;
320   }
321 
322   return GetConstant(GetType(inst), literal_words_or_ids);
323 }
324 
CreateInstruction(uint32_t id,const Constant * c,uint32_t type_id) const325 std::unique_ptr<Instruction> ConstantManager::CreateInstruction(
326     uint32_t id, const Constant* c, uint32_t type_id) const {
327   uint32_t type =
328       (type_id == 0) ? context()->get_type_mgr()->GetId(c->type()) : type_id;
329   if (c->AsNullConstant()) {
330     return MakeUnique<Instruction>(context(), SpvOp::SpvOpConstantNull, type,
331                                    id, std::initializer_list<Operand>{});
332   } else if (const BoolConstant* bc = c->AsBoolConstant()) {
333     return MakeUnique<Instruction>(
334         context(),
335         bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse,
336         type, id, std::initializer_list<Operand>{});
337   } else if (const IntConstant* ic = c->AsIntConstant()) {
338     return MakeUnique<Instruction>(
339         context(), SpvOp::SpvOpConstant, type, id,
340         std::initializer_list<Operand>{
341             Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
342                     ic->words())});
343   } else if (const FloatConstant* fc = c->AsFloatConstant()) {
344     return MakeUnique<Instruction>(
345         context(), SpvOp::SpvOpConstant, type, id,
346         std::initializer_list<Operand>{
347             Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
348                     fc->words())});
349   } else if (const CompositeConstant* cc = c->AsCompositeConstant()) {
350     return CreateCompositeInstruction(id, cc, type_id);
351   } else {
352     return nullptr;
353   }
354 }
355 
CreateCompositeInstruction(uint32_t result_id,const CompositeConstant * cc,uint32_t type_id) const356 std::unique_ptr<Instruction> ConstantManager::CreateCompositeInstruction(
357     uint32_t result_id, const CompositeConstant* cc, uint32_t type_id) const {
358   std::vector<Operand> operands;
359   Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id);
360   uint32_t component_index = 0;
361   for (const Constant* component_const : cc->GetComponents()) {
362     uint32_t component_type_id = 0;
363     if (type_inst && type_inst->opcode() == SpvOpTypeStruct) {
364       component_type_id = type_inst->GetSingleWordInOperand(component_index);
365     } else if (type_inst && type_inst->opcode() == SpvOpTypeArray) {
366       component_type_id = type_inst->GetSingleWordInOperand(0);
367     }
368     uint32_t id = FindDeclaredConstant(component_const, component_type_id);
369 
370     if (id == 0) {
371       // Cannot get the id of the component constant, while all components
372       // should have been added to the module prior to the composite constant.
373       // Cannot create OpConstantComposite instruction in this case.
374       return nullptr;
375     }
376     operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
377                           std::initializer_list<uint32_t>{id});
378     component_index++;
379   }
380   uint32_t type =
381       (type_id == 0) ? context()->get_type_mgr()->GetId(cc->type()) : type_id;
382   return MakeUnique<Instruction>(context(), SpvOp::SpvOpConstantComposite, type,
383                                  result_id, std::move(operands));
384 }
385 
GetConstant(const Type * type,const std::vector<uint32_t> & literal_words_or_ids)386 const Constant* ConstantManager::GetConstant(
387     const Type* type, const std::vector<uint32_t>& literal_words_or_ids) {
388   auto cst = CreateConstant(type, literal_words_or_ids);
389   return cst ? RegisterConstant(std::move(cst)) : nullptr;
390 }
391 
GetFloatConst(float val)392 uint32_t ConstantManager::GetFloatConst(float val) {
393   Type* float_type = context()->get_type_mgr()->GetFloatType();
394   utils::FloatProxy<float> v(val);
395   const Constant* c = GetConstant(float_type, v.GetWords());
396   return GetDefiningInstruction(c)->result_id();
397 }
398 
GetSIntConst(int32_t val)399 uint32_t ConstantManager::GetSIntConst(int32_t val) {
400   Type* sint_type = context()->get_type_mgr()->GetSIntType();
401   const Constant* c = GetConstant(sint_type, {static_cast<uint32_t>(val)});
402   return GetDefiningInstruction(c)->result_id();
403 }
404 
GetVectorComponents(analysis::ConstantManager * const_mgr) const405 std::vector<const analysis::Constant*> Constant::GetVectorComponents(
406     analysis::ConstantManager* const_mgr) const {
407   std::vector<const analysis::Constant*> components;
408   const analysis::VectorConstant* a = this->AsVectorConstant();
409   const analysis::Vector* vector_type = this->type()->AsVector();
410   assert(vector_type != nullptr);
411   if (a != nullptr) {
412     for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
413       components.push_back(a->GetComponents()[i]);
414     }
415   } else {
416     const analysis::Type* element_type = vector_type->element_type();
417     const analysis::Constant* element_null_const =
418         const_mgr->GetConstant(element_type, {});
419     for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
420       components.push_back(element_null_const);
421     }
422   }
423   return components;
424 }
425 
426 }  // namespace analysis
427 }  // namespace opt
428 }  // namespace spvtools
429