1 /*
2  * Copyright 2018 WebAssembly Community Group participants
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 //
18 // Optimizes call arguments in a whole-program manner, removing ones
19 // that are not used (dead).
20 //
21 // Specifically, this does these things:
22 //
23 //  * Find functions for whom an argument is always passed the same
24 //    constant. If so, we can just set that local to that constant
25 //    in the function.
26 //  * Find functions that don't use the value passed to an argument.
27 //    If so, we can avoid even sending and receiving it. (Note how if
28 //    the previous point was true for an argument, then the second
29 //    must as well.)
30 //  * Find return values ("return arguments" ;) that are never used.
31 //
32 // This pass does not depend on flattening, but it may be more effective,
33 // as then call arguments never have side effects (which we need to
34 // watch for here).
35 //
36 
37 #include <unordered_map>
38 #include <unordered_set>
39 
40 #include "cfg/cfg-traversal.h"
41 #include "ir/effects.h"
42 #include "ir/module-utils.h"
43 #include "pass.h"
44 #include "passes/opt-utils.h"
45 #include "support/sorted_vector.h"
46 #include "wasm-builder.h"
47 #include "wasm.h"
48 
49 namespace wasm {
50 
51 // Information for a function
52 struct DAEFunctionInfo {
53   // The unused parameters, if any.
54   SortedVector unusedParams;
55   // Maps a function name to the calls going to it.
56   std::unordered_map<Name, std::vector<Call*>> calls;
57   // Map of all calls that are dropped, to their drops' locations (so that
58   // if we can optimize out the drop, we can replace the drop there).
59   std::unordered_map<Call*, Expression**> droppedCalls;
60   // Whether this function contains any tail calls (including indirect tail
61   // calls) and the set of functions this function tail calls. Tail-callers and
62   // tail-callees cannot have their dropped returns removed because of the
63   // constraint that tail-callees must have the same return type as
64   // tail-callers. Indirectly tail called functions are already not optimized
65   // because being in a table inhibits DAE. TODO: Allow the removal of dropped
66   // returns from tail-callers if their tail-callees can have their returns
67   // removed as well.
68   bool hasTailCalls = false;
69   std::unordered_set<Name> tailCallees;
70   // Whether the function can be called from places that
71   // affect what we can do. For now, any call we don't
72   // see inhibits our optimizations, but TODO: an export
73   // could be worked around by exporting a thunk that
74   // adds the parameter.
75   bool hasUnseenCalls = false;
76 };
77 
78 typedef std::unordered_map<Name, DAEFunctionInfo> DAEFunctionInfoMap;
79 
80 // Information in a basic block
81 struct DAEBlockInfo {
82   // A local may be read, written, or not accessed in this block.
83   // If it is both read and written, we just care about the first
84   // action (if it is read first, that's all the info we are
85   // looking for; if it is written first, it can't be read later).
86   enum LocalUse { Read, Written };
87   std::unordered_map<Index, LocalUse> localUses;
88 };
89 
90 struct DAEScanner
91   : public WalkerPass<
92       CFGWalker<DAEScanner, Visitor<DAEScanner>, DAEBlockInfo>> {
isFunctionParallelwasm::DAEScanner93   bool isFunctionParallel() override { return true; }
94 
createwasm::DAEScanner95   Pass* create() override { return new DAEScanner(infoMap); }
96 
DAEScannerwasm::DAEScanner97   DAEScanner(DAEFunctionInfoMap* infoMap) : infoMap(infoMap) {}
98 
99   DAEFunctionInfoMap* infoMap;
100   DAEFunctionInfo* info;
101 
102   Index numParams;
103 
104   // cfg traversal work
105 
visitLocalGetwasm::DAEScanner106   void visitLocalGet(LocalGet* curr) {
107     if (currBasicBlock) {
108       auto& localUses = currBasicBlock->contents.localUses;
109       auto index = curr->index;
110       if (localUses.count(index) == 0) {
111         localUses[index] = DAEBlockInfo::Read;
112       }
113     }
114   }
115 
visitLocalSetwasm::DAEScanner116   void visitLocalSet(LocalSet* curr) {
117     if (currBasicBlock) {
118       auto& localUses = currBasicBlock->contents.localUses;
119       auto index = curr->index;
120       if (localUses.count(index) == 0) {
121         localUses[index] = DAEBlockInfo::Written;
122       }
123     }
124   }
125 
visitCallwasm::DAEScanner126   void visitCall(Call* curr) {
127     if (!getModule()->getFunction(curr->target)->imported()) {
128       info->calls[curr->target].push_back(curr);
129     }
130     if (curr->isReturn) {
131       info->hasTailCalls = true;
132       info->tailCallees.insert(curr->target);
133     }
134   }
135 
visitCallIndirectwasm::DAEScanner136   void visitCallIndirect(CallIndirect* curr) {
137     if (curr->isReturn) {
138       info->hasTailCalls = true;
139     }
140   }
141 
visitDropwasm::DAEScanner142   void visitDrop(Drop* curr) {
143     if (auto* call = curr->value->dynCast<Call>()) {
144       info->droppedCalls[call] = getCurrentPointer();
145     }
146   }
147 
148   // main entry point
149 
doWalkFunctionwasm::DAEScanner150   void doWalkFunction(Function* func) {
151     numParams = func->getNumParams();
152     info = &((*infoMap)[func->name]);
153     CFGWalker<DAEScanner, Visitor<DAEScanner>, DAEBlockInfo>::doWalkFunction(
154       func);
155     // If there are relevant params, check if they are used. (If
156     // we can't optimize the function anyhow, there's no point.)
157     if (numParams > 0 && !info->hasUnseenCalls) {
158       findUnusedParams(func);
159     }
160   }
161 
findUnusedParamswasm::DAEScanner162   void findUnusedParams(Function* func) {
163     // Flow the incoming parameter values, see if they reach a read.
164     // Once we've seen a parameter at a block, we need never consider it there
165     // again.
166     std::unordered_map<BasicBlock*, SortedVector> seenBlockIndexes;
167     // Start with all the incoming parameters.
168     SortedVector initial;
169     for (Index i = 0; i < numParams; i++) {
170       initial.push_back(i);
171     }
172     // The used params, which we now compute.
173     std::unordered_set<Index> usedParams;
174     // An item of work is a block plus the values arriving there.
175     typedef std::pair<BasicBlock*, SortedVector> Item;
176     std::vector<Item> work;
177     work.emplace_back(entry, initial);
178     while (!work.empty()) {
179       auto item = std::move(work.back());
180       work.pop_back();
181       auto* block = item.first;
182       auto& indexes = item.second;
183       // Ignore things we've already seen, or we've already seen to be used.
184       auto& seenIndexes = seenBlockIndexes[block];
185       indexes.filter([&](const Index i) {
186         if (seenIndexes.has(i) || usedParams.count(i)) {
187           return false;
188         } else {
189           seenIndexes.insert(i);
190           return true;
191         }
192       });
193       if (indexes.empty()) {
194         continue; // nothing more to flow
195       }
196       auto& localUses = block->contents.localUses;
197       SortedVector remainingIndexes;
198       for (auto i : indexes) {
199         auto iter = localUses.find(i);
200         if (iter != localUses.end()) {
201           auto use = iter->second;
202           if (use == DAEBlockInfo::Read) {
203             usedParams.insert(i);
204           }
205           // Whether it was a read or a write, we can stop looking at that local
206           // here.
207         } else {
208           remainingIndexes.insert(i);
209         }
210       }
211       // If there are remaining indexes, flow them forward.
212       if (!remainingIndexes.empty()) {
213         for (auto* next : block->out) {
214           work.emplace_back(next, remainingIndexes);
215         }
216       }
217     }
218     // We can now compute the unused params.
219     for (Index i = 0; i < numParams; i++) {
220       if (usedParams.count(i) == 0) {
221         info->unusedParams.insert(i);
222       }
223     }
224   }
225 };
226 
227 struct DAE : public Pass {
228   bool optimize = false;
229 
runwasm::DAE230   void run(PassRunner* runner, Module* module) override {
231     // Iterate to convergence.
232     while (1) {
233       if (!iteration(runner, module)) {
234         break;
235       }
236     }
237   }
238 
iterationwasm::DAE239   bool iteration(PassRunner* runner, Module* module) {
240     allDroppedCalls.clear();
241 
242     DAEFunctionInfoMap infoMap;
243     // Ensure they all exist so the parallel threads don't modify the data
244     // structure.
245     ModuleUtils::iterDefinedFunctions(
246       *module, [&](Function* func) { infoMap[func->name]; });
247     // Check the influence of the table and exports.
248     for (auto& curr : module->exports) {
249       if (curr->kind == ExternalKind::Function) {
250         infoMap[curr->value].hasUnseenCalls = true;
251       }
252     }
253     for (auto& segment : module->table.segments) {
254       for (auto name : segment.data) {
255         infoMap[name].hasUnseenCalls = true;
256       }
257     }
258     // Scan all the functions.
259     DAEScanner(&infoMap).run(runner, module);
260     // Combine all the info.
261     std::unordered_map<Name, std::vector<Call*>> allCalls;
262     std::unordered_set<Name> tailCallees;
263     for (auto& pair : infoMap) {
264       auto& info = pair.second;
265       for (auto& pair : info.calls) {
266         auto name = pair.first;
267         auto& calls = pair.second;
268         auto& allCallsToName = allCalls[name];
269         allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end());
270       }
271       for (auto& callee : info.tailCallees) {
272         tailCallees.insert(callee);
273       }
274       for (auto& pair : info.droppedCalls) {
275         allDroppedCalls[pair.first] = pair.second;
276       }
277     }
278     // We now have a mapping of all call sites for each function. Check which
279     // are always passed the same constant for a particular argument.
280     for (auto& pair : allCalls) {
281       auto name = pair.first;
282       // We can only optimize if we see all the calls and can modify
283       // them.
284       if (infoMap[name].hasUnseenCalls) {
285         continue;
286       }
287       auto& calls = pair.second;
288       auto* func = module->getFunction(name);
289       auto numParams = func->getNumParams();
290       for (Index i = 0; i < numParams; i++) {
291         Literal value;
292         for (auto* call : calls) {
293           assert(call->target == name);
294           assert(call->operands.size() == numParams);
295           auto* operand = call->operands[i];
296           if (auto* c = operand->dynCast<Const>()) {
297             if (value.type == Type::none) {
298               // This is the first value seen.
299               value = c->value;
300             } else if (value != c->value) {
301               // Not identical, give up
302               value = Literal(Type::none);
303               break;
304             }
305           } else {
306             // Not a constant, give up
307             value = Literal(Type::none);
308             break;
309           }
310         }
311         if (value.type != Type::none) {
312           // Success! We can just apply the constant in the function, which
313           // makes the parameter value unused, which lets us remove it later.
314           Builder builder(*module);
315           func->body = builder.makeSequence(
316             builder.makeLocalSet(i, builder.makeConst(value)), func->body);
317           // Mark it as unused, which we know it now is (no point to
318           // re-scan just for that).
319           infoMap[name].unusedParams.insert(i);
320         }
321       }
322     }
323     // Track which functions we changed, and optimize them later if necessary.
324     std::unordered_set<Function*> changed;
325     // We now know which parameters are unused, and can potentially remove them.
326     for (auto& pair : allCalls) {
327       auto name = pair.first;
328       auto& calls = pair.second;
329       auto* func = module->getFunction(name);
330       auto numParams = func->getNumParams();
331       if (numParams == 0) {
332         continue;
333       }
334       // Iterate downwards, as we may remove more than one.
335       Index i = numParams - 1;
336       while (1) {
337         if (infoMap[name].unusedParams.has(i)) {
338           // Great, it's not used. Check if none of the calls has a param with
339           // side effects, as that would prevent us removing them (flattening
340           // should have been done earlier).
341           bool canRemove =
342             std::none_of(calls.begin(), calls.end(), [&](Call* call) {
343               auto* operand = call->operands[i];
344               return EffectAnalyzer(runner->options, module->features, operand)
345                 .hasSideEffects();
346             });
347           if (canRemove) {
348             // Wonderful, nothing stands in our way! Do it.
349             // TODO: parallelize this?
350             removeParameter(func, i, calls);
351             changed.insert(func);
352           }
353         }
354         if (i == 0) {
355           break;
356         }
357         i--;
358       }
359     }
360     // We can also tell which calls have all their return values dropped. Note
361     // that we can't do this if we changed anything so far, as we may have
362     // modified allCalls (we can't modify a call site twice in one iteration,
363     // once to remove a param, once to drop the return value).
364     if (changed.empty()) {
365       for (auto& func : module->functions) {
366         if (func->sig.results == Type::none) {
367           continue;
368         }
369         auto name = func->name;
370         if (infoMap[name].hasUnseenCalls) {
371           continue;
372         }
373         if (infoMap[name].hasTailCalls) {
374           continue;
375         }
376         if (tailCallees.find(name) != tailCallees.end()) {
377           continue;
378         }
379         auto iter = allCalls.find(name);
380         if (iter == allCalls.end()) {
381           continue;
382         }
383         auto& calls = iter->second;
384         bool allDropped =
385           std::all_of(calls.begin(), calls.end(), [&](Call* call) {
386             return allDroppedCalls.count(call);
387           });
388         if (!allDropped) {
389           continue;
390         }
391         removeReturnValue(func.get(), calls, module);
392         // TODO Removing a drop may also open optimization opportunities in the
393         // callers.
394         changed.insert(func.get());
395       }
396     }
397     if (optimize && !changed.empty()) {
398       OptUtils::optimizeAfterInlining(changed, module, runner);
399     }
400     return !changed.empty();
401   }
402 
403 private:
404   std::unordered_map<Call*, Expression**> allDroppedCalls;
405 
removeParameterwasm::DAE406   void removeParameter(Function* func, Index i, std::vector<Call*>& calls) {
407     // It's cumbersome to adjust local names - TODO don't clear them?
408     Builder::clearLocalNames(func);
409     // Remove the parameter from the function. We must add a new local
410     // for uses of the parameter, but cannot make it use the same index
411     // (in general).
412     std::vector<Type> params(func->sig.params.begin(), func->sig.params.end());
413     auto type = params[i];
414     params.erase(params.begin() + i);
415     func->sig.params = Type(params);
416     Index newIndex = Builder::addVar(func, type);
417     // Update local operations.
418     struct LocalUpdater : public PostWalker<LocalUpdater> {
419       Index removedIndex;
420       Index newIndex;
421       LocalUpdater(Function* func, Index removedIndex, Index newIndex)
422         : removedIndex(removedIndex), newIndex(newIndex) {
423         walk(func->body);
424       }
425       void visitLocalGet(LocalGet* curr) { updateIndex(curr->index); }
426       void visitLocalSet(LocalSet* curr) { updateIndex(curr->index); }
427       void updateIndex(Index& index) {
428         if (index == removedIndex) {
429           index = newIndex;
430         } else if (index > removedIndex) {
431           index--;
432         }
433       }
434     } localUpdater(func, i, newIndex);
435     // Remove the arguments from the calls.
436     for (auto* call : calls) {
437       call->operands.erase(call->operands.begin() + i);
438     }
439   }
440 
441   void
removeReturnValuewasm::DAE442   removeReturnValue(Function* func, std::vector<Call*>& calls, Module* module) {
443     func->sig.results = Type::none;
444     Builder builder(*module);
445     // Remove any return values.
446     struct ReturnUpdater : public PostWalker<ReturnUpdater> {
447       Module* module;
448       ReturnUpdater(Function* func, Module* module) : module(module) {
449         walk(func->body);
450       }
451       void visitReturn(Return* curr) {
452         auto* value = curr->value;
453         assert(value);
454         curr->value = nullptr;
455         Builder builder(*module);
456         replaceCurrent(builder.makeSequence(builder.makeDrop(value), curr));
457       }
458     } returnUpdater(func, module);
459     // Remove any value flowing out.
460     if (func->body->type.isConcrete()) {
461       func->body = builder.makeDrop(func->body);
462     }
463     // Remove the drops on the calls.
464     for (auto* call : calls) {
465       auto iter = allDroppedCalls.find(call);
466       assert(iter != allDroppedCalls.end());
467       Expression** location = iter->second;
468       *location = call;
469       // Update the call's type.
470       if (call->type != Type::unreachable) {
471         call->type = Type::none;
472       }
473     }
474   }
475 };
476 
createDAEPass()477 Pass* createDAEPass() { return new DAE(); }
478 
createDAEOptimizingPass()479 Pass* createDAEOptimizingPass() {
480   auto* ret = new DAE();
481   ret->optimize = true;
482   return ret;
483 }
484 
485 } // namespace wasm
486