1 /*
2 * Copyright 2018 WebAssembly Community Group participants
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 //
18 // Optimizes call arguments in a whole-program manner, removing ones
19 // that are not used (dead).
20 //
21 // Specifically, this does these things:
22 //
//  * Find functions for which an argument is always passed the same
//    constant. If so, we can just set that local to that constant
//    in the function.
26 // * Find functions that don't use the value passed to an argument.
27 // If so, we can avoid even sending and receiving it. (Note how if
28 // the previous point was true for an argument, then the second
29 // must as well.)
30 // * Find return values ("return arguments" ;) that are never used.
31 //
// This pass does not depend on flattening, but it may be more effective
// after flattening, as then call arguments never have side effects (which
// we need to watch for here).
35 //
36
37 #include <unordered_map>
38 #include <unordered_set>
39
40 #include "cfg/cfg-traversal.h"
41 #include "ir/effects.h"
42 #include "ir/module-utils.h"
43 #include "pass.h"
44 #include "passes/opt-utils.h"
45 #include "support/sorted_vector.h"
46 #include "wasm-builder.h"
47 #include "wasm.h"
48
49 namespace wasm {
50
51 // Information for a function
52 struct DAEFunctionInfo {
53 // The unused parameters, if any.
54 SortedVector unusedParams;
55 // Maps a function name to the calls going to it.
56 std::unordered_map<Name, std::vector<Call*>> calls;
57 // Map of all calls that are dropped, to their drops' locations (so that
58 // if we can optimize out the drop, we can replace the drop there).
59 std::unordered_map<Call*, Expression**> droppedCalls;
60 // Whether this function contains any tail calls (including indirect tail
61 // calls) and the set of functions this function tail calls. Tail-callers and
62 // tail-callees cannot have their dropped returns removed because of the
63 // constraint that tail-callees must have the same return type as
64 // tail-callers. Indirectly tail called functions are already not optimized
65 // because being in a table inhibits DAE. TODO: Allow the removal of dropped
66 // returns from tail-callers if their tail-callees can have their returns
67 // removed as well.
68 bool hasTailCalls = false;
69 std::unordered_set<Name> tailCallees;
70 // Whether the function can be called from places that
71 // affect what we can do. For now, any call we don't
72 // see inhibits our optimizations, but TODO: an export
73 // could be worked around by exporting a thunk that
74 // adds the parameter.
75 bool hasUnseenCalls = false;
76 };
77
78 typedef std::unordered_map<Name, DAEFunctionInfo> DAEFunctionInfoMap;
79
80 // Information in a basic block
81 struct DAEBlockInfo {
82 // A local may be read, written, or not accessed in this block.
83 // If it is both read and written, we just care about the first
84 // action (if it is read first, that's all the info we are
85 // looking for; if it is written first, it can't be read later).
86 enum LocalUse { Read, Written };
87 std::unordered_map<Index, LocalUse> localUses;
88 };
89
90 struct DAEScanner
91 : public WalkerPass<
92 CFGWalker<DAEScanner, Visitor<DAEScanner>, DAEBlockInfo>> {
isFunctionParallelwasm::DAEScanner93 bool isFunctionParallel() override { return true; }
94
createwasm::DAEScanner95 Pass* create() override { return new DAEScanner(infoMap); }
96
DAEScannerwasm::DAEScanner97 DAEScanner(DAEFunctionInfoMap* infoMap) : infoMap(infoMap) {}
98
99 DAEFunctionInfoMap* infoMap;
100 DAEFunctionInfo* info;
101
102 Index numParams;
103
104 // cfg traversal work
105
visitLocalGetwasm::DAEScanner106 void visitLocalGet(LocalGet* curr) {
107 if (currBasicBlock) {
108 auto& localUses = currBasicBlock->contents.localUses;
109 auto index = curr->index;
110 if (localUses.count(index) == 0) {
111 localUses[index] = DAEBlockInfo::Read;
112 }
113 }
114 }
115
visitLocalSetwasm::DAEScanner116 void visitLocalSet(LocalSet* curr) {
117 if (currBasicBlock) {
118 auto& localUses = currBasicBlock->contents.localUses;
119 auto index = curr->index;
120 if (localUses.count(index) == 0) {
121 localUses[index] = DAEBlockInfo::Written;
122 }
123 }
124 }
125
visitCallwasm::DAEScanner126 void visitCall(Call* curr) {
127 if (!getModule()->getFunction(curr->target)->imported()) {
128 info->calls[curr->target].push_back(curr);
129 }
130 if (curr->isReturn) {
131 info->hasTailCalls = true;
132 info->tailCallees.insert(curr->target);
133 }
134 }
135
visitCallIndirectwasm::DAEScanner136 void visitCallIndirect(CallIndirect* curr) {
137 if (curr->isReturn) {
138 info->hasTailCalls = true;
139 }
140 }
141
visitDropwasm::DAEScanner142 void visitDrop(Drop* curr) {
143 if (auto* call = curr->value->dynCast<Call>()) {
144 info->droppedCalls[call] = getCurrentPointer();
145 }
146 }
147
148 // main entry point
149
doWalkFunctionwasm::DAEScanner150 void doWalkFunction(Function* func) {
151 numParams = func->getNumParams();
152 info = &((*infoMap)[func->name]);
153 CFGWalker<DAEScanner, Visitor<DAEScanner>, DAEBlockInfo>::doWalkFunction(
154 func);
155 // If there are relevant params, check if they are used. (If
156 // we can't optimize the function anyhow, there's no point.)
157 if (numParams > 0 && !info->hasUnseenCalls) {
158 findUnusedParams(func);
159 }
160 }
161
findUnusedParamswasm::DAEScanner162 void findUnusedParams(Function* func) {
163 // Flow the incoming parameter values, see if they reach a read.
164 // Once we've seen a parameter at a block, we need never consider it there
165 // again.
166 std::unordered_map<BasicBlock*, SortedVector> seenBlockIndexes;
167 // Start with all the incoming parameters.
168 SortedVector initial;
169 for (Index i = 0; i < numParams; i++) {
170 initial.push_back(i);
171 }
172 // The used params, which we now compute.
173 std::unordered_set<Index> usedParams;
174 // An item of work is a block plus the values arriving there.
175 typedef std::pair<BasicBlock*, SortedVector> Item;
176 std::vector<Item> work;
177 work.emplace_back(entry, initial);
178 while (!work.empty()) {
179 auto item = std::move(work.back());
180 work.pop_back();
181 auto* block = item.first;
182 auto& indexes = item.second;
183 // Ignore things we've already seen, or we've already seen to be used.
184 auto& seenIndexes = seenBlockIndexes[block];
185 indexes.filter([&](const Index i) {
186 if (seenIndexes.has(i) || usedParams.count(i)) {
187 return false;
188 } else {
189 seenIndexes.insert(i);
190 return true;
191 }
192 });
193 if (indexes.empty()) {
194 continue; // nothing more to flow
195 }
196 auto& localUses = block->contents.localUses;
197 SortedVector remainingIndexes;
198 for (auto i : indexes) {
199 auto iter = localUses.find(i);
200 if (iter != localUses.end()) {
201 auto use = iter->second;
202 if (use == DAEBlockInfo::Read) {
203 usedParams.insert(i);
204 }
205 // Whether it was a read or a write, we can stop looking at that local
206 // here.
207 } else {
208 remainingIndexes.insert(i);
209 }
210 }
211 // If there are remaining indexes, flow them forward.
212 if (!remainingIndexes.empty()) {
213 for (auto* next : block->out) {
214 work.emplace_back(next, remainingIndexes);
215 }
216 }
217 }
218 // We can now compute the unused params.
219 for (Index i = 0; i < numParams; i++) {
220 if (usedParams.count(i) == 0) {
221 info->unusedParams.insert(i);
222 }
223 }
224 }
225 };
226
227 struct DAE : public Pass {
228 bool optimize = false;
229
runwasm::DAE230 void run(PassRunner* runner, Module* module) override {
231 // Iterate to convergence.
232 while (1) {
233 if (!iteration(runner, module)) {
234 break;
235 }
236 }
237 }
238
iterationwasm::DAE239 bool iteration(PassRunner* runner, Module* module) {
240 allDroppedCalls.clear();
241
242 DAEFunctionInfoMap infoMap;
243 // Ensure they all exist so the parallel threads don't modify the data
244 // structure.
245 ModuleUtils::iterDefinedFunctions(
246 *module, [&](Function* func) { infoMap[func->name]; });
247 // Check the influence of the table and exports.
248 for (auto& curr : module->exports) {
249 if (curr->kind == ExternalKind::Function) {
250 infoMap[curr->value].hasUnseenCalls = true;
251 }
252 }
253 for (auto& segment : module->table.segments) {
254 for (auto name : segment.data) {
255 infoMap[name].hasUnseenCalls = true;
256 }
257 }
258 // Scan all the functions.
259 DAEScanner(&infoMap).run(runner, module);
260 // Combine all the info.
261 std::unordered_map<Name, std::vector<Call*>> allCalls;
262 std::unordered_set<Name> tailCallees;
263 for (auto& pair : infoMap) {
264 auto& info = pair.second;
265 for (auto& pair : info.calls) {
266 auto name = pair.first;
267 auto& calls = pair.second;
268 auto& allCallsToName = allCalls[name];
269 allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end());
270 }
271 for (auto& callee : info.tailCallees) {
272 tailCallees.insert(callee);
273 }
274 for (auto& pair : info.droppedCalls) {
275 allDroppedCalls[pair.first] = pair.second;
276 }
277 }
278 // We now have a mapping of all call sites for each function. Check which
279 // are always passed the same constant for a particular argument.
280 for (auto& pair : allCalls) {
281 auto name = pair.first;
282 // We can only optimize if we see all the calls and can modify
283 // them.
284 if (infoMap[name].hasUnseenCalls) {
285 continue;
286 }
287 auto& calls = pair.second;
288 auto* func = module->getFunction(name);
289 auto numParams = func->getNumParams();
290 for (Index i = 0; i < numParams; i++) {
291 Literal value;
292 for (auto* call : calls) {
293 assert(call->target == name);
294 assert(call->operands.size() == numParams);
295 auto* operand = call->operands[i];
296 if (auto* c = operand->dynCast<Const>()) {
297 if (value.type == Type::none) {
298 // This is the first value seen.
299 value = c->value;
300 } else if (value != c->value) {
301 // Not identical, give up
302 value = Literal(Type::none);
303 break;
304 }
305 } else {
306 // Not a constant, give up
307 value = Literal(Type::none);
308 break;
309 }
310 }
311 if (value.type != Type::none) {
312 // Success! We can just apply the constant in the function, which
313 // makes the parameter value unused, which lets us remove it later.
314 Builder builder(*module);
315 func->body = builder.makeSequence(
316 builder.makeLocalSet(i, builder.makeConst(value)), func->body);
317 // Mark it as unused, which we know it now is (no point to
318 // re-scan just for that).
319 infoMap[name].unusedParams.insert(i);
320 }
321 }
322 }
323 // Track which functions we changed, and optimize them later if necessary.
324 std::unordered_set<Function*> changed;
325 // We now know which parameters are unused, and can potentially remove them.
326 for (auto& pair : allCalls) {
327 auto name = pair.first;
328 auto& calls = pair.second;
329 auto* func = module->getFunction(name);
330 auto numParams = func->getNumParams();
331 if (numParams == 0) {
332 continue;
333 }
334 // Iterate downwards, as we may remove more than one.
335 Index i = numParams - 1;
336 while (1) {
337 if (infoMap[name].unusedParams.has(i)) {
338 // Great, it's not used. Check if none of the calls has a param with
339 // side effects, as that would prevent us removing them (flattening
340 // should have been done earlier).
341 bool canRemove =
342 std::none_of(calls.begin(), calls.end(), [&](Call* call) {
343 auto* operand = call->operands[i];
344 return EffectAnalyzer(runner->options, module->features, operand)
345 .hasSideEffects();
346 });
347 if (canRemove) {
348 // Wonderful, nothing stands in our way! Do it.
349 // TODO: parallelize this?
350 removeParameter(func, i, calls);
351 changed.insert(func);
352 }
353 }
354 if (i == 0) {
355 break;
356 }
357 i--;
358 }
359 }
360 // We can also tell which calls have all their return values dropped. Note
361 // that we can't do this if we changed anything so far, as we may have
362 // modified allCalls (we can't modify a call site twice in one iteration,
363 // once to remove a param, once to drop the return value).
364 if (changed.empty()) {
365 for (auto& func : module->functions) {
366 if (func->sig.results == Type::none) {
367 continue;
368 }
369 auto name = func->name;
370 if (infoMap[name].hasUnseenCalls) {
371 continue;
372 }
373 if (infoMap[name].hasTailCalls) {
374 continue;
375 }
376 if (tailCallees.find(name) != tailCallees.end()) {
377 continue;
378 }
379 auto iter = allCalls.find(name);
380 if (iter == allCalls.end()) {
381 continue;
382 }
383 auto& calls = iter->second;
384 bool allDropped =
385 std::all_of(calls.begin(), calls.end(), [&](Call* call) {
386 return allDroppedCalls.count(call);
387 });
388 if (!allDropped) {
389 continue;
390 }
391 removeReturnValue(func.get(), calls, module);
392 // TODO Removing a drop may also open optimization opportunities in the
393 // callers.
394 changed.insert(func.get());
395 }
396 }
397 if (optimize && !changed.empty()) {
398 OptUtils::optimizeAfterInlining(changed, module, runner);
399 }
400 return !changed.empty();
401 }
402
403 private:
404 std::unordered_map<Call*, Expression**> allDroppedCalls;
405
removeParameterwasm::DAE406 void removeParameter(Function* func, Index i, std::vector<Call*>& calls) {
407 // It's cumbersome to adjust local names - TODO don't clear them?
408 Builder::clearLocalNames(func);
409 // Remove the parameter from the function. We must add a new local
410 // for uses of the parameter, but cannot make it use the same index
411 // (in general).
412 std::vector<Type> params(func->sig.params.begin(), func->sig.params.end());
413 auto type = params[i];
414 params.erase(params.begin() + i);
415 func->sig.params = Type(params);
416 Index newIndex = Builder::addVar(func, type);
417 // Update local operations.
418 struct LocalUpdater : public PostWalker<LocalUpdater> {
419 Index removedIndex;
420 Index newIndex;
421 LocalUpdater(Function* func, Index removedIndex, Index newIndex)
422 : removedIndex(removedIndex), newIndex(newIndex) {
423 walk(func->body);
424 }
425 void visitLocalGet(LocalGet* curr) { updateIndex(curr->index); }
426 void visitLocalSet(LocalSet* curr) { updateIndex(curr->index); }
427 void updateIndex(Index& index) {
428 if (index == removedIndex) {
429 index = newIndex;
430 } else if (index > removedIndex) {
431 index--;
432 }
433 }
434 } localUpdater(func, i, newIndex);
435 // Remove the arguments from the calls.
436 for (auto* call : calls) {
437 call->operands.erase(call->operands.begin() + i);
438 }
439 }
440
441 void
removeReturnValuewasm::DAE442 removeReturnValue(Function* func, std::vector<Call*>& calls, Module* module) {
443 func->sig.results = Type::none;
444 Builder builder(*module);
445 // Remove any return values.
446 struct ReturnUpdater : public PostWalker<ReturnUpdater> {
447 Module* module;
448 ReturnUpdater(Function* func, Module* module) : module(module) {
449 walk(func->body);
450 }
451 void visitReturn(Return* curr) {
452 auto* value = curr->value;
453 assert(value);
454 curr->value = nullptr;
455 Builder builder(*module);
456 replaceCurrent(builder.makeSequence(builder.makeDrop(value), curr));
457 }
458 } returnUpdater(func, module);
459 // Remove any value flowing out.
460 if (func->body->type.isConcrete()) {
461 func->body = builder.makeDrop(func->body);
462 }
463 // Remove the drops on the calls.
464 for (auto* call : calls) {
465 auto iter = allDroppedCalls.find(call);
466 assert(iter != allDroppedCalls.end());
467 Expression** location = iter->second;
468 *location = call;
469 // Update the call's type.
470 if (call->type != Type::unreachable) {
471 call->type = Type::none;
472 }
473 }
474 }
475 };
476
createDAEPass()477 Pass* createDAEPass() { return new DAE(); }
478
createDAEOptimizingPass()479 Pass* createDAEOptimizingPass() {
480 auto* ret = new DAE();
481 ret->optimize = true;
482 return ret;
483 }
484
485 } // namespace wasm
486