1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/desc_sroa.h"
16 
17 #include "source/util/string_utils.h"
18 
19 namespace spvtools {
20 namespace opt {
21 
Process()22 Pass::Status DescriptorScalarReplacement::Process() {
23   bool modified = false;
24 
25   std::vector<Instruction*> vars_to_kill;
26 
27   for (Instruction& var : context()->types_values()) {
28     if (IsCandidate(&var)) {
29       modified = true;
30       if (!ReplaceCandidate(&var)) {
31         return Status::Failure;
32       }
33       vars_to_kill.push_back(&var);
34     }
35   }
36 
37   for (Instruction* var : vars_to_kill) {
38     context()->KillInst(var);
39   }
40 
41   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
42 }
43 
IsCandidate(Instruction * var)44 bool DescriptorScalarReplacement::IsCandidate(Instruction* var) {
45   if (var->opcode() != SpvOpVariable) {
46     return false;
47   }
48 
49   uint32_t ptr_type_id = var->type_id();
50   Instruction* ptr_type_inst =
51       context()->get_def_use_mgr()->GetDef(ptr_type_id);
52   if (ptr_type_inst->opcode() != SpvOpTypePointer) {
53     return false;
54   }
55 
56   uint32_t var_type_id = ptr_type_inst->GetSingleWordInOperand(1);
57   Instruction* var_type_inst =
58       context()->get_def_use_mgr()->GetDef(var_type_id);
59   if (var_type_inst->opcode() != SpvOpTypeArray) {
60     return false;
61   }
62 
63   bool has_desc_set_decoration = false;
64   context()->get_decoration_mgr()->ForEachDecoration(
65       var->result_id(), SpvDecorationDescriptorSet,
66       [&has_desc_set_decoration](const Instruction&) {
67         has_desc_set_decoration = true;
68       });
69   if (!has_desc_set_decoration) {
70     return false;
71   }
72 
73   bool has_binding_decoration = false;
74   context()->get_decoration_mgr()->ForEachDecoration(
75       var->result_id(), SpvDecorationBinding,
76       [&has_binding_decoration](const Instruction&) {
77         has_binding_decoration = true;
78       });
79   if (!has_binding_decoration) {
80     return false;
81   }
82 
83   return true;
84 }
85 
ReplaceCandidate(Instruction * var)86 bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) {
87   std::vector<Instruction*> work_list;
88   bool failed = !get_def_use_mgr()->WhileEachUser(
89       var->result_id(), [this, &work_list](Instruction* use) {
90         if (use->opcode() == SpvOpName) {
91           return true;
92         }
93 
94         if (use->IsDecoration()) {
95           return true;
96         }
97 
98         switch (use->opcode()) {
99           case SpvOpAccessChain:
100           case SpvOpInBoundsAccessChain:
101             work_list.push_back(use);
102             return true;
103           default:
104             context()->EmitErrorMessage(
105                 "Variable cannot be replaced: invalid instruction", use);
106             return false;
107         }
108         return true;
109       });
110 
111   if (failed) {
112     return false;
113   }
114 
115   for (Instruction* use : work_list) {
116     if (!ReplaceAccessChain(var, use)) {
117       return false;
118     }
119   }
120   return true;
121 }
122 
ReplaceAccessChain(Instruction * var,Instruction * use)123 bool DescriptorScalarReplacement::ReplaceAccessChain(Instruction* var,
124                                                      Instruction* use) {
125   if (use->NumInOperands() <= 1) {
126     context()->EmitErrorMessage(
127         "Variable cannot be replaced: invalid instruction", use);
128     return false;
129   }
130 
131   uint32_t idx_id = use->GetSingleWordInOperand(1);
132   const analysis::Constant* idx_const =
133       context()->get_constant_mgr()->FindDeclaredConstant(idx_id);
134   if (idx_const == nullptr) {
135     context()->EmitErrorMessage("Variable cannot be replaced: invalid index",
136                                 use);
137     return false;
138   }
139 
140   uint32_t idx = idx_const->GetU32();
141   uint32_t replacement_var = GetReplacementVariable(var, idx);
142 
143   if (use->NumInOperands() == 2) {
144     // We are not indexing into the replacement variable.  We can replaces the
145     // access chain with the replacement varibale itself.
146     context()->ReplaceAllUsesWith(use->result_id(), replacement_var);
147     context()->KillInst(use);
148     return true;
149   }
150 
151   // We need to build a new access chain with the replacement variable as the
152   // base address.
153   Instruction::OperandList new_operands;
154 
155   // Same result id and result type.
156   new_operands.emplace_back(use->GetOperand(0));
157   new_operands.emplace_back(use->GetOperand(1));
158 
159   // Use the replacement variable as the base address.
160   new_operands.push_back({SPV_OPERAND_TYPE_ID, {replacement_var}});
161 
162   // Drop the first index because it is consumed by the replacment, and copy the
163   // rest.
164   for (uint32_t i = 4; i < use->NumOperands(); i++) {
165     new_operands.emplace_back(use->GetOperand(i));
166   }
167 
168   use->ReplaceOperands(new_operands);
169   context()->UpdateDefUse(use);
170   return true;
171 }
172 
GetReplacementVariable(Instruction * var,uint32_t idx)173 uint32_t DescriptorScalarReplacement::GetReplacementVariable(Instruction* var,
174                                                              uint32_t idx) {
175   auto replacement_vars = replacement_variables_.find(var);
176   if (replacement_vars == replacement_variables_.end()) {
177     uint32_t ptr_type_id = var->type_id();
178     Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id);
179     assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
180            "Variable should be a pointer to an array.");
181     uint32_t arr_type_id = ptr_type_inst->GetSingleWordInOperand(1);
182     Instruction* arr_type_inst = get_def_use_mgr()->GetDef(arr_type_id);
183     assert(arr_type_inst->opcode() == SpvOpTypeArray &&
184            "Variable should be a pointer to an array.");
185 
186     uint32_t array_len_id = arr_type_inst->GetSingleWordInOperand(1);
187     const analysis::Constant* array_len_const =
188         context()->get_constant_mgr()->FindDeclaredConstant(array_len_id);
189     assert(array_len_const != nullptr && "Array length must be a constant.");
190     uint32_t array_len = array_len_const->GetU32();
191 
192     replacement_vars = replacement_variables_
193                            .insert({var, std::vector<uint32_t>(array_len, 0)})
194                            .first;
195   }
196 
197   if (replacement_vars->second[idx] == 0) {
198     replacement_vars->second[idx] = CreateReplacementVariable(var, idx);
199   }
200 
201   return replacement_vars->second[idx];
202 }
203 
CreateReplacementVariable(Instruction * var,uint32_t idx)204 uint32_t DescriptorScalarReplacement::CreateReplacementVariable(
205     Instruction* var, uint32_t idx) {
206   // The storage class for the new variable is the same as the original.
207   SpvStorageClass storage_class =
208       static_cast<SpvStorageClass>(var->GetSingleWordInOperand(0));
209 
210   // The type for the new variable will be a pointer to type of the elements of
211   // the array.
212   uint32_t ptr_type_id = var->type_id();
213   Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id);
214   assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
215          "Variable should be a pointer to an array.");
216   uint32_t arr_type_id = ptr_type_inst->GetSingleWordInOperand(1);
217   Instruction* arr_type_inst = get_def_use_mgr()->GetDef(arr_type_id);
218   assert(arr_type_inst->opcode() == SpvOpTypeArray &&
219          "Variable should be a pointer to an array.");
220   uint32_t element_type_id = arr_type_inst->GetSingleWordInOperand(0);
221 
222   uint32_t ptr_element_type_id = context()->get_type_mgr()->FindPointerToType(
223       element_type_id, storage_class);
224 
225   // Create the variable.
226   uint32_t id = TakeNextId();
227   std::unique_ptr<Instruction> variable(
228       new Instruction(context(), SpvOpVariable, ptr_element_type_id, id,
229                       std::initializer_list<Operand>{
230                           {SPV_OPERAND_TYPE_STORAGE_CLASS,
231                            {static_cast<uint32_t>(storage_class)}}}));
232   context()->AddGlobalValue(std::move(variable));
233 
234   // Copy all of the decorations to the new variable.  The only difference is
235   // the Binding decoration needs to be adjusted.
236   for (auto old_decoration :
237        get_decoration_mgr()->GetDecorationsFor(var->result_id(), true)) {
238     assert(old_decoration->opcode() == SpvOpDecorate);
239     std::unique_ptr<Instruction> new_decoration(
240         old_decoration->Clone(context()));
241     new_decoration->SetInOperand(0, {id});
242 
243     uint32_t decoration = new_decoration->GetSingleWordInOperand(1u);
244     if (decoration == SpvDecorationBinding) {
245       uint32_t new_binding = new_decoration->GetSingleWordInOperand(2) + idx;
246       new_decoration->SetInOperand(2, {new_binding});
247     }
248     context()->AddAnnotationInst(std::move(new_decoration));
249   }
250 
251   // Create a new OpName for the replacement variable.
252   for (auto p : context()->GetNames(var->result_id())) {
253     Instruction* name_inst = p.second;
254     std::string name_str = utils::MakeString(name_inst->GetOperand(1).words);
255     name_str += "[";
256     name_str += utils::ToString(idx);
257     name_str += "]";
258 
259     std::unique_ptr<Instruction> new_name(new Instruction(
260         context(), SpvOpName, 0, 0,
261         std::initializer_list<Operand>{
262             {SPV_OPERAND_TYPE_ID, {id}},
263             {SPV_OPERAND_TYPE_LITERAL_STRING, utils::MakeVector(name_str)}}));
264     Instruction* new_name_inst = new_name.get();
265     context()->AddDebug2Inst(std::move(new_name));
266     get_def_use_mgr()->AnalyzeInstDefUse(new_name_inst);
267   }
268 
269   return id;
270 }
271 
272 }  // namespace opt
273 }  // namespace spvtools
274