1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "fix_storage_class.h"
16 
17 #include <set>
18 
19 #include "source/opt/instruction.h"
20 #include "source/opt/ir_context.h"
21 
22 namespace spvtools {
23 namespace opt {
24 
Process()25 Pass::Status FixStorageClass::Process() {
26   bool modified = false;
27 
28   get_module()->ForEachInst([this, &modified](Instruction* inst) {
29     if (inst->opcode() == SpvOpVariable) {
30       std::set<uint32_t> seen;
31       std::vector<std::pair<Instruction*, uint32_t>> uses;
32       get_def_use_mgr()->ForEachUse(inst,
33                                     [&uses](Instruction* use, uint32_t op_idx) {
34                                       uses.push_back({use, op_idx});
35                                     });
36 
37       for (auto& use : uses) {
38         modified |= PropagateStorageClass(
39             use.first,
40             static_cast<SpvStorageClass>(inst->GetSingleWordInOperand(0)),
41             &seen);
42         assert(seen.empty() && "Seen was not properly reset.");
43         modified |=
44             PropagateType(use.first, inst->type_id(), use.second, &seen);
45         assert(seen.empty() && "Seen was not properly reset.");
46       }
47     }
48   });
49   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
50 }
51 
PropagateStorageClass(Instruction * inst,SpvStorageClass storage_class,std::set<uint32_t> * seen)52 bool FixStorageClass::PropagateStorageClass(Instruction* inst,
53                                             SpvStorageClass storage_class,
54                                             std::set<uint32_t>* seen) {
55   if (!IsPointerResultType(inst)) {
56     return false;
57   }
58 
59   if (IsPointerToStorageClass(inst, storage_class)) {
60     if (inst->opcode() == SpvOpPhi) {
61       if (!seen->insert(inst->result_id()).second) {
62         return false;
63       }
64     }
65 
66     bool modified = false;
67     std::vector<Instruction*> uses;
68     get_def_use_mgr()->ForEachUser(
69         inst, [&uses](Instruction* use) { uses.push_back(use); });
70     for (Instruction* use : uses) {
71       modified |= PropagateStorageClass(use, storage_class, seen);
72     }
73 
74     if (inst->opcode() == SpvOpPhi) {
75       seen->erase(inst->result_id());
76     }
77     return modified;
78   }
79 
80   switch (inst->opcode()) {
81     case SpvOpAccessChain:
82     case SpvOpPtrAccessChain:
83     case SpvOpInBoundsAccessChain:
84     case SpvOpCopyObject:
85     case SpvOpPhi:
86     case SpvOpSelect:
87       FixInstructionStorageClass(inst, storage_class, seen);
88       return true;
89     case SpvOpFunctionCall:
90       // We cannot be sure of the actual connection between the storage class
91       // of the parameter and the storage class of the result, so we should not
92       // do anything.  If the result type needs to be fixed, the function call
93       // should be inlined.
94       return false;
95     case SpvOpImageTexelPointer:
96     case SpvOpLoad:
97     case SpvOpStore:
98     case SpvOpCopyMemory:
99     case SpvOpCopyMemorySized:
100     case SpvOpVariable:
101     case SpvOpBitcast:
102       // Nothing to change for these opcode.  The result type is the same
103       // regardless of the storage class of the operand.
104       return false;
105     default:
106       assert(false &&
107              "Not expecting instruction to have a pointer result type.");
108       return false;
109   }
110 }
111 
FixInstructionStorageClass(Instruction * inst,SpvStorageClass storage_class,std::set<uint32_t> * seen)112 void FixStorageClass::FixInstructionStorageClass(Instruction* inst,
113                                                  SpvStorageClass storage_class,
114                                                  std::set<uint32_t>* seen) {
115   assert(IsPointerResultType(inst) &&
116          "The result type of the instruction must be a pointer.");
117 
118   ChangeResultStorageClass(inst, storage_class);
119 
120   std::vector<Instruction*> uses;
121   get_def_use_mgr()->ForEachUser(
122       inst, [&uses](Instruction* use) { uses.push_back(use); });
123   for (Instruction* use : uses) {
124     PropagateStorageClass(use, storage_class, seen);
125   }
126 }
127 
ChangeResultStorageClass(Instruction * inst,SpvStorageClass storage_class) const128 void FixStorageClass::ChangeResultStorageClass(
129     Instruction* inst, SpvStorageClass storage_class) const {
130   analysis::TypeManager* type_mgr = context()->get_type_mgr();
131   Instruction* result_type_inst = get_def_use_mgr()->GetDef(inst->type_id());
132   assert(result_type_inst->opcode() == SpvOpTypePointer);
133   uint32_t pointee_type_id = result_type_inst->GetSingleWordInOperand(1);
134   uint32_t new_result_type_id =
135       type_mgr->FindPointerToType(pointee_type_id, storage_class);
136   inst->SetResultType(new_result_type_id);
137   context()->UpdateDefUse(inst);
138 }
139 
IsPointerResultType(Instruction * inst)140 bool FixStorageClass::IsPointerResultType(Instruction* inst) {
141   if (inst->type_id() == 0) {
142     return false;
143   }
144   const analysis::Type* ret_type =
145       context()->get_type_mgr()->GetType(inst->type_id());
146   return ret_type->AsPointer() != nullptr;
147 }
148 
IsPointerToStorageClass(Instruction * inst,SpvStorageClass storage_class)149 bool FixStorageClass::IsPointerToStorageClass(Instruction* inst,
150                                               SpvStorageClass storage_class) {
151   analysis::TypeManager* type_mgr = context()->get_type_mgr();
152   analysis::Type* pType = type_mgr->GetType(inst->type_id());
153   const analysis::Pointer* result_type = pType->AsPointer();
154 
155   if (result_type == nullptr) {
156     return false;
157   }
158 
159   return (result_type->storage_class() == storage_class);
160 }
161 
ChangeResultType(Instruction * inst,uint32_t new_type_id)162 bool FixStorageClass::ChangeResultType(Instruction* inst,
163                                        uint32_t new_type_id) {
164   if (inst->type_id() == new_type_id) {
165     return false;
166   }
167 
168   context()->ForgetUses(inst);
169   inst->SetResultType(new_type_id);
170   context()->AnalyzeUses(inst);
171   return true;
172 }
173 
PropagateType(Instruction * inst,uint32_t type_id,uint32_t op_idx,std::set<uint32_t> * seen)174 bool FixStorageClass::PropagateType(Instruction* inst, uint32_t type_id,
175                                     uint32_t op_idx, std::set<uint32_t>* seen) {
176   assert(type_id != 0 && "Not given a valid type in PropagateType");
177   bool modified = false;
178 
179   // If the type of operand |op_idx| forces the result type of |inst| to a
180   // particular type, then we want find that type.
181   uint32_t new_type_id = 0;
182   switch (inst->opcode()) {
183     case SpvOpAccessChain:
184     case SpvOpPtrAccessChain:
185     case SpvOpInBoundsAccessChain:
186     case SpvOpInBoundsPtrAccessChain:
187       if (op_idx == 2) {
188         new_type_id = WalkAccessChainType(inst, type_id);
189       }
190       break;
191     case SpvOpCopyObject:
192       new_type_id = type_id;
193       break;
194     case SpvOpPhi:
195       if (seen->insert(inst->result_id()).second) {
196         new_type_id = type_id;
197       }
198       break;
199     case SpvOpSelect:
200       if (op_idx > 2) {
201         new_type_id = type_id;
202       }
203       break;
204     case SpvOpFunctionCall:
205       // We cannot be sure of the actual connection between the type
206       // of the parameter and the type of the result, so we should not
207       // do anything.  If the result type needs to be fixed, the function call
208       // should be inlined.
209       return false;
210     case SpvOpLoad: {
211       Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
212       new_type_id = type_inst->GetSingleWordInOperand(1);
213       break;
214     }
215     case SpvOpStore: {
216       uint32_t obj_id = inst->GetSingleWordInOperand(1);
217       Instruction* obj_inst = get_def_use_mgr()->GetDef(obj_id);
218       uint32_t obj_type_id = obj_inst->type_id();
219 
220       uint32_t ptr_id = inst->GetSingleWordInOperand(0);
221       Instruction* ptr_inst = get_def_use_mgr()->GetDef(ptr_id);
222       uint32_t pointee_type_id = GetPointeeTypeId(ptr_inst);
223 
224       if (obj_type_id != pointee_type_id) {
225         uint32_t copy_id = GenerateCopy(obj_inst, pointee_type_id, inst);
226         inst->SetInOperand(1, {copy_id});
227         context()->UpdateDefUse(inst);
228       }
229     } break;
230     case SpvOpCopyMemory:
231     case SpvOpCopyMemorySized:
232       // TODO: May need to expand the copy as we do with the stores.
233       break;
234     case SpvOpCompositeConstruct:
235     case SpvOpCompositeExtract:
236     case SpvOpCompositeInsert:
237       // TODO: DXC does not seem to generate code that will require changes to
238       // these opcode.  The can be implemented when they come up.
239       break;
240     case SpvOpImageTexelPointer:
241     case SpvOpBitcast:
242       // Nothing to change for these opcode.  The result type is the same
243       // regardless of the type of the operand.
244       return false;
245     default:
246       // I expect the remaining instructions to act on types that are guaranteed
247       // to be unique, so no change will be necessary.
248       break;
249   }
250 
251   // If the operand forces the result type, then make sure the result type
252   // matches, and update the uses of |inst|.  We do not have to check the uses
253   // of |inst| in the result type is not forced because we are only looking for
254   // issue that come from mismatches between function formal and actual
255   // parameters after the function has been inlined.  These parameters are
256   // pointers. Once the type no longer depends on the type of the parameter,
257   // then the types should have be correct.
258   if (new_type_id != 0) {
259     modified = ChangeResultType(inst, new_type_id);
260 
261     std::vector<std::pair<Instruction*, uint32_t>> uses;
262     get_def_use_mgr()->ForEachUse(inst,
263                                   [&uses](Instruction* use, uint32_t idx) {
264                                     uses.push_back({use, idx});
265                                   });
266 
267     for (auto& use : uses) {
268       PropagateType(use.first, new_type_id, use.second, seen);
269     }
270 
271     if (inst->opcode() == SpvOpPhi) {
272       seen->erase(inst->result_id());
273     }
274   }
275   return modified;
276 }
277 
WalkAccessChainType(Instruction * inst,uint32_t id)278 uint32_t FixStorageClass::WalkAccessChainType(Instruction* inst, uint32_t id) {
279   uint32_t start_idx = 0;
280   switch (inst->opcode()) {
281     case SpvOpAccessChain:
282     case SpvOpInBoundsAccessChain:
283       start_idx = 1;
284       break;
285     case SpvOpPtrAccessChain:
286     case SpvOpInBoundsPtrAccessChain:
287       start_idx = 2;
288       break;
289     default:
290       assert(false);
291       break;
292   }
293 
294   Instruction* orig_type_inst = get_def_use_mgr()->GetDef(id);
295   assert(orig_type_inst->opcode() == SpvOpTypePointer);
296   id = orig_type_inst->GetSingleWordInOperand(1);
297 
298   for (uint32_t i = start_idx; i < inst->NumInOperands(); ++i) {
299     Instruction* type_inst = get_def_use_mgr()->GetDef(id);
300     switch (type_inst->opcode()) {
301       case SpvOpTypeArray:
302       case SpvOpTypeRuntimeArray:
303       case SpvOpTypeMatrix:
304       case SpvOpTypeVector:
305         id = type_inst->GetSingleWordInOperand(0);
306         break;
307       case SpvOpTypeStruct: {
308         const analysis::Constant* index_const =
309             context()->get_constant_mgr()->FindDeclaredConstant(
310                 inst->GetSingleWordInOperand(i));
311         uint32_t index = index_const->GetU32();
312         id = type_inst->GetSingleWordInOperand(index);
313         break;
314       }
315       default:
316         break;
317     }
318     assert(id != 0 &&
319            "Tried to extract from an object where it cannot be done.");
320   }
321 
322   return context()->get_type_mgr()->FindPointerToType(
323       id,
324       static_cast<SpvStorageClass>(orig_type_inst->GetSingleWordInOperand(0)));
325 }
326 
327 // namespace opt
328 
329 }  // namespace opt
330 }  // namespace spvtools
331