1 /*
2  * Copyright 2016 WebAssembly Community Group participants
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 //
18 // Removes dead, i.e. unreachable, code.
19 //
20 // We keep a record of when control flow is reachable. When it isn't, we
21 // kill (turn into unreachable). We then fold away entire unreachable
22 // expressions.
23 //
24 // When dead code causes an operation to not happen, like a store, a call
25 // or an add, we replace with a block with a list of what does happen.
26 // That isn't necessarily smaller, but blocks are friendlier to other
27 // optimizations: blocks can be merged and eliminated, and they clearly
28 // have no side effects.
29 //
30 
31 #include <ir/block-utils.h>
32 #include <ir/branch-utils.h>
33 #include <ir/type-updating.h>
34 #include <pass.h>
35 #include <vector>
36 #include <wasm-builder.h>
37 #include <wasm.h>
38 
39 namespace wasm {
40 
41 struct DeadCodeElimination
42   : public WalkerPass<PostWalker<DeadCodeElimination>> {
isFunctionParallelwasm::DeadCodeElimination43   bool isFunctionParallel() override { return true; }
44 
createwasm::DeadCodeElimination45   Pass* create() override { return new DeadCodeElimination; }
46 
47   // as we remove code, we must keep the types of other nodes valid
48   TypeUpdater typeUpdater;
49 
replaceCurrentwasm::DeadCodeElimination50   Expression* replaceCurrent(Expression* expression) {
51     auto* old = getCurrent();
52     if (old == expression) {
53       return expression;
54     }
55     super::replaceCurrent(expression);
56     // also update the type updater
57     typeUpdater.noteReplacement(old, expression);
58     return expression;
59   }
60 
61   // whether the current code is actually reachable
62   bool reachable;
63 
doWalkFunctionwasm::DeadCodeElimination64   void doWalkFunction(Function* func) {
65     reachable = true;
66     typeUpdater.walk(func->body);
67     walk(func->body);
68   }
69 
70   std::set<Name> reachableBreaks;
71 
addBreakwasm::DeadCodeElimination72   void addBreak(Name name) {
73     // we normally have already reduced unreachable code into (unreachable)
74     // nodes, so we would not get to this place at all anyhow, the breaking
75     // instruction itself would be removed. However, an exception are things
76     // like  (block (result i32) (call $x) (unreachable)) , which has type i32
77     // despite not being exited.
78     // TODO: optimize such cases
79     if (reachable) {
80       reachableBreaks.insert(name);
81     }
82   }
83 
84   // if a child exists and is unreachable, we can replace ourselves with it
isDeadwasm::DeadCodeElimination85   bool isDead(Expression* child) {
86     return child && child->type == Type::unreachable;
87   }
88 
89   // a similar check, assumes the child exists
isUnreachablewasm::DeadCodeElimination90   bool isUnreachable(Expression* child) {
91     return child->type == Type::unreachable;
92   }
93 
94   // things that stop control flow
95 
visitBreakwasm::DeadCodeElimination96   void visitBreak(Break* curr) {
97     if (isDead(curr->value)) {
98       // the condition is evaluated last, so if the value was unreachable, the
99       // whole thing is
100       replaceCurrent(curr->value);
101       return;
102     }
103     if (isDead(curr->condition)) {
104       if (curr->value) {
105         auto* block = getModule()->allocator.alloc<Block>();
106         block->list.resize(2);
107         block->list[0] = drop(curr->value);
108         block->list[1] = curr->condition;
109         // if we previously returned a value, then this block
110         // must have the same type, so it fits in the ast
111         // properly. it ends in an unreachable
112         // anyhow, so that is ok.
113         block->finalize(curr->type);
114         replaceCurrent(block);
115       } else {
116         replaceCurrent(curr->condition);
117       }
118       return;
119     }
120     addBreak(curr->name);
121     if (!curr->condition) {
122       reachable = false;
123     }
124   }
125 
visitSwitchwasm::DeadCodeElimination126   void visitSwitch(Switch* curr) {
127     if (isDead(curr->value)) {
128       replaceCurrent(curr->value);
129       return;
130     }
131     if (isUnreachable(curr->condition)) {
132       if (curr->value) {
133         auto* block = getModule()->allocator.alloc<Block>();
134         block->list.resize(2);
135         block->list[0] = drop(curr->value);
136         block->list[1] = curr->condition;
137         block->finalize(curr->type);
138         replaceCurrent(block);
139       } else {
140         replaceCurrent(curr->condition);
141       }
142       return;
143     }
144     for (auto target : curr->targets) {
145       addBreak(target);
146     }
147     addBreak(curr->default_);
148     reachable = false;
149   }
150 
visitReturnwasm::DeadCodeElimination151   void visitReturn(Return* curr) {
152     if (isDead(curr->value)) {
153       replaceCurrent(curr->value);
154       return;
155     }
156     reachable = false;
157   }
158 
visitUnreachablewasm::DeadCodeElimination159   void visitUnreachable(Unreachable* curr) { reachable = false; }
160 
visitBlockwasm::DeadCodeElimination161   void visitBlock(Block* curr) {
162     auto& list = curr->list;
163     // if we are currently unreachable (before we take into account
164     // breaks to the block) then a child may be unreachable, and we
165     // can shorten
166     if (!reachable && list.size() > 1) {
167       // to do here: nothing to remove after it)
168       for (Index i = 0; i < list.size() - 1; i++) {
169         if (list[i]->type == Type::unreachable) {
170           list.resize(i + 1);
171           break;
172         }
173       }
174     }
175     if (curr->name.is()) {
176       reachable = reachable || reachableBreaks.count(curr->name);
177       reachableBreaks.erase(curr->name);
178     }
179     if (list.size() == 1 && isUnreachable(list[0])) {
180       replaceCurrent(
181         BlockUtils::simplifyToContentsWithPossibleTypeChange(curr, this));
182     } else {
183       // the block may have had a type, but can now be unreachable, which allows
184       // more reduction outside
185       typeUpdater.maybeUpdateTypeToUnreachable(curr);
186     }
187   }
188 
visitLoopwasm::DeadCodeElimination189   void visitLoop(Loop* curr) {
190     if (curr->name.is()) {
191       reachableBreaks.erase(curr->name);
192     }
193     if (isUnreachable(curr->body) &&
194         !BranchUtils::BranchSeeker::has(curr->body, curr->name)) {
195       replaceCurrent(curr->body);
196       return;
197     }
198   }
199 
200   // ifs and trys need special handling: only one of (if body and else body /
201   // try body and catch body) should be reachable to make the whole of (if /
202   // try) to be reachable.
203 
204   // stack of reachable state, for forking and joining
205   std::vector<bool> ifStack;
206   std::vector<bool> tryStack;
207 
doAfterIfConditionwasm::DeadCodeElimination208   static void doAfterIfCondition(DeadCodeElimination* self,
209                                  Expression** currp) {
210     self->ifStack.push_back(self->reachable);
211   }
212 
doAfterIfElseTruewasm::DeadCodeElimination213   static void doAfterIfElseTrue(DeadCodeElimination* self, Expression** currp) {
214     assert((*currp)->cast<If>()->ifFalse);
215     bool reachableBefore = self->ifStack.back();
216     self->ifStack.pop_back();
217     self->ifStack.push_back(self->reachable);
218     self->reachable = reachableBefore;
219   }
220 
visitIfwasm::DeadCodeElimination221   void visitIf(If* curr) {
222     // the ifStack has the branch that joins us, either from before if just an
223     // if, or the ifTrue if an if-else
224     reachable = reachable || ifStack.back();
225     ifStack.pop_back();
226     if (isUnreachable(curr->condition)) {
227       replaceCurrent(curr->condition);
228     }
229     // the if may have had a type, but can now be unreachable, which allows more
230     // reduction outside
231     typeUpdater.maybeUpdateTypeToUnreachable(curr);
232   }
233 
doBeforeTryBodywasm::DeadCodeElimination234   static void doBeforeTryBody(DeadCodeElimination* self, Expression** currp) {
235     self->tryStack.push_back(self->reachable);
236   }
237 
doAfterTryBodywasm::DeadCodeElimination238   static void doAfterTryBody(DeadCodeElimination* self, Expression** currp) {
239     bool reachableBefore = self->tryStack.back();
240     self->tryStack.pop_back();
241     self->tryStack.push_back(self->reachable);
242     self->reachable = reachableBefore;
243   }
244 
visitTrywasm::DeadCodeElimination245   void visitTry(Try* curr) {
246     // the tryStack has the branch that joins us
247     reachable = reachable || tryStack.back();
248     tryStack.pop_back();
249     // the try may have had a type, but can now be unreachable, which allows
250     // more reduction outside
251     typeUpdater.maybeUpdateTypeToUnreachable(curr);
252   }
253 
visitThrowwasm::DeadCodeElimination254   void visitThrow(Throw* curr) { reachable = false; }
255 
visitRethrowwasm::DeadCodeElimination256   void visitRethrow(Rethrow* curr) { reachable = false; }
257 
visitBrOnExnwasm::DeadCodeElimination258   void visitBrOnExn(BrOnExn* curr) {
259     if (isDead(curr->exnref)) {
260       replaceCurrent(curr->exnref);
261       return;
262     }
263     addBreak(curr->name);
264   }
265 
scanwasm::DeadCodeElimination266   static void scan(DeadCodeElimination* self, Expression** currp) {
267     auto* curr = *currp;
268     if (!self->reachable) {
269 // convert to an unreachable safely
270 #define DELEGATE(CLASS_TO_VISIT)                                               \
271   {                                                                            \
272     auto* parent = self->typeUpdater.parents[curr];                            \
273     self->typeUpdater.noteRecursiveRemoval(curr);                              \
274     ExpressionManipulator::convert<CLASS_TO_VISIT, Unreachable>(               \
275       static_cast<CLASS_TO_VISIT*>(curr));                                     \
276     self->typeUpdater.noteAddition(curr, parent);                              \
277     break;                                                                     \
278   }
279       switch (curr->_id) {
280         case Expression::Id::BlockId:
281           DELEGATE(Block);
282         case Expression::Id::IfId:
283           DELEGATE(If);
284         case Expression::Id::LoopId:
285           DELEGATE(Loop);
286         case Expression::Id::BreakId:
287           DELEGATE(Break);
288         case Expression::Id::SwitchId:
289           DELEGATE(Switch);
290         case Expression::Id::CallId:
291           DELEGATE(Call);
292         case Expression::Id::CallIndirectId:
293           DELEGATE(CallIndirect);
294         case Expression::Id::LocalGetId:
295           DELEGATE(LocalGet);
296         case Expression::Id::LocalSetId:
297           DELEGATE(LocalSet);
298         case Expression::Id::GlobalGetId:
299           DELEGATE(GlobalGet);
300         case Expression::Id::GlobalSetId:
301           DELEGATE(GlobalSet);
302         case Expression::Id::LoadId:
303           DELEGATE(Load);
304         case Expression::Id::StoreId:
305           DELEGATE(Store);
306         case Expression::Id::ConstId:
307           DELEGATE(Const);
308         case Expression::Id::UnaryId:
309           DELEGATE(Unary);
310         case Expression::Id::BinaryId:
311           DELEGATE(Binary);
312         case Expression::Id::SelectId:
313           DELEGATE(Select);
314         case Expression::Id::DropId:
315           DELEGATE(Drop);
316         case Expression::Id::ReturnId:
317           DELEGATE(Return);
318         case Expression::Id::MemorySizeId:
319           DELEGATE(MemorySize);
320         case Expression::Id::MemoryGrowId:
321           DELEGATE(MemoryGrow);
322         case Expression::Id::NopId:
323           DELEGATE(Nop);
324         case Expression::Id::UnreachableId:
325           break;
326         case Expression::Id::AtomicCmpxchgId:
327           DELEGATE(AtomicCmpxchg);
328         case Expression::Id::AtomicRMWId:
329           DELEGATE(AtomicRMW);
330         case Expression::Id::AtomicWaitId:
331           DELEGATE(AtomicWait);
332         case Expression::Id::AtomicNotifyId:
333           DELEGATE(AtomicNotify);
334         case Expression::Id::AtomicFenceId:
335           DELEGATE(AtomicFence);
336         case Expression::Id::SIMDExtractId:
337           DELEGATE(SIMDExtract);
338         case Expression::Id::SIMDReplaceId:
339           DELEGATE(SIMDReplace);
340         case Expression::Id::SIMDShuffleId:
341           DELEGATE(SIMDShuffle);
342         case Expression::Id::SIMDTernaryId:
343           DELEGATE(SIMDTernary);
344         case Expression::Id::SIMDShiftId:
345           DELEGATE(SIMDShift);
346         case Expression::Id::SIMDLoadId:
347           DELEGATE(SIMDLoad);
348         case Expression::Id::MemoryInitId:
349           DELEGATE(MemoryInit);
350         case Expression::Id::DataDropId:
351           DELEGATE(DataDrop);
352         case Expression::Id::MemoryCopyId:
353           DELEGATE(MemoryCopy);
354         case Expression::Id::MemoryFillId:
355           DELEGATE(MemoryFill);
356         case Expression::Id::PopId:
357           DELEGATE(Pop);
358         case Expression::Id::RefNullId:
359           DELEGATE(RefNull);
360         case Expression::Id::RefIsNullId:
361           DELEGATE(RefIsNull);
362         case Expression::Id::RefFuncId:
363           DELEGATE(RefFunc);
364         case Expression::Id::RefEqId:
365           DELEGATE(RefEq);
366         case Expression::Id::TryId:
367           DELEGATE(Try);
368         case Expression::Id::ThrowId:
369           DELEGATE(Throw);
370         case Expression::Id::RethrowId:
371           DELEGATE(Rethrow);
372         case Expression::Id::BrOnExnId:
373           DELEGATE(BrOnExn);
374         case Expression::Id::TupleMakeId:
375           DELEGATE(TupleMake);
376         case Expression::Id::TupleExtractId:
377           DELEGATE(TupleExtract);
378         case Expression::Id::I31NewId:
379           DELEGATE(I31New);
380         case Expression::Id::I31GetId:
381           DELEGATE(I31Get);
382         case Expression::Id::RefTestId:
383           DELEGATE(RefTest);
384         case Expression::Id::RefCastId:
385           DELEGATE(RefCast);
386         case Expression::Id::BrOnCastId:
387           DELEGATE(BrOnCast);
388         case Expression::Id::RttCanonId:
389           DELEGATE(RttCanon);
390         case Expression::Id::RttSubId:
391           DELEGATE(RttSub);
392         case Expression::Id::StructNewId:
393           DELEGATE(StructNew);
394         case Expression::Id::StructGetId:
395           DELEGATE(StructGet);
396         case Expression::Id::StructSetId:
397           DELEGATE(StructSet);
398         case Expression::Id::ArrayNewId:
399           DELEGATE(ArrayNew);
400         case Expression::Id::ArrayGetId:
401           DELEGATE(ArrayGet);
402         case Expression::Id::ArraySetId:
403           DELEGATE(ArraySet);
404         case Expression::Id::ArrayLenId:
405           DELEGATE(ArrayLen);
406         case Expression::Id::InvalidId:
407           WASM_UNREACHABLE("unimp");
408         case Expression::Id::NumExpressionIds:
409           WASM_UNREACHABLE("unimp");
410       }
411 #undef DELEGATE
412       return;
413     }
414     if (curr->is<If>()) {
415       self->pushTask(DeadCodeElimination::doVisitIf, currp);
416       if (curr->cast<If>()->ifFalse) {
417         self->pushTask(DeadCodeElimination::scan, &curr->cast<If>()->ifFalse);
418         self->pushTask(DeadCodeElimination::doAfterIfElseTrue, currp);
419       }
420       self->pushTask(DeadCodeElimination::scan, &curr->cast<If>()->ifTrue);
421       self->pushTask(DeadCodeElimination::doAfterIfCondition, currp);
422       self->pushTask(DeadCodeElimination::scan, &curr->cast<If>()->condition);
423     } else if (curr->is<Try>()) {
424       self->pushTask(DeadCodeElimination::doVisitTry, currp);
425       self->pushTask(DeadCodeElimination::scan, &curr->cast<Try>()->catchBody);
426       self->pushTask(DeadCodeElimination::doAfterTryBody, currp);
427       self->pushTask(DeadCodeElimination::scan, &curr->cast<Try>()->body);
428       self->pushTask(DeadCodeElimination::doBeforeTryBody, currp);
429     } else {
430       super::scan(self, currp);
431     }
432   }
433 
434   // other things
435 
436   // we don't need to drop unreachable nodes
dropwasm::DeadCodeElimination437   Expression* drop(Expression* toDrop) {
438     if (toDrop->type == Type::unreachable) {
439       return toDrop;
440     }
441     return Builder(*getModule()).makeDrop(toDrop);
442   }
443 
handleCallwasm::DeadCodeElimination444   template<typename T> Expression* handleCall(T* curr) {
445     for (Index i = 0; i < curr->operands.size(); i++) {
446       if (isUnreachable(curr->operands[i])) {
447         if (i > 0) {
448           auto* block = getModule()->allocator.alloc<Block>();
449           Index newSize = i + 1;
450           block->list.resize(newSize);
451           Index j = 0;
452           for (; j < newSize; j++) {
453             block->list[j] = drop(curr->operands[j]);
454           }
455           block->finalize(curr->type);
456           return replaceCurrent(block);
457         } else {
458           return replaceCurrent(curr->operands[i]);
459         }
460       }
461     }
462     return curr;
463   }
464 
visitCallwasm::DeadCodeElimination465   void visitCall(Call* curr) {
466     handleCall(curr);
467     if (curr->isReturn) {
468       reachable = false;
469     }
470   }
471 
visitCallIndirectwasm::DeadCodeElimination472   void visitCallIndirect(CallIndirect* curr) {
473     if (handleCall(curr) != curr) {
474       return;
475     }
476     if (isUnreachable(curr->target)) {
477       auto* block = getModule()->allocator.alloc<Block>();
478       for (auto* operand : curr->operands) {
479         block->list.push_back(drop(operand));
480       }
481       block->list.push_back(curr->target);
482       block->finalize(curr->type);
483       replaceCurrent(block);
484     }
485     if (curr->isReturn) {
486       reachable = false;
487     }
488   }
489 
490   // Append the reachable operands of the current node to a block, and replace
491   // it with the block
blockifyReachableOperandswasm::DeadCodeElimination492   void blockifyReachableOperands(std::vector<Expression*>&& list, Type type) {
493     for (size_t i = 0; i < list.size(); ++i) {
494       auto* elem = list[i];
495       if (isUnreachable(elem)) {
496         auto* replacement = elem;
497         if (i > 0) {
498           auto* block = getModule()->allocator.alloc<Block>();
499           for (size_t j = 0; j < i; ++j) {
500             block->list.push_back(drop(list[j]));
501           }
502           block->list.push_back(list[i]);
503           block->finalize(type);
504           replacement = block;
505         }
506         replaceCurrent(replacement);
507         return;
508       }
509     }
510   }
511 
visitLocalSetwasm::DeadCodeElimination512   void visitLocalSet(LocalSet* curr) {
513     blockifyReachableOperands({curr->value}, curr->type);
514   }
515 
visitGlobalSetwasm::DeadCodeElimination516   void visitGlobalSet(GlobalSet* curr) {
517     blockifyReachableOperands({curr->value}, curr->type);
518   }
519 
visitLoadwasm::DeadCodeElimination520   void visitLoad(Load* curr) {
521     blockifyReachableOperands({curr->ptr}, curr->type);
522   }
523 
visitStorewasm::DeadCodeElimination524   void visitStore(Store* curr) {
525     blockifyReachableOperands({curr->ptr, curr->value}, curr->type);
526   }
527 
visitAtomicRMWwasm::DeadCodeElimination528   void visitAtomicRMW(AtomicRMW* curr) {
529     blockifyReachableOperands({curr->ptr, curr->value}, curr->type);
530   }
531 
visitAtomicCmpxchgwasm::DeadCodeElimination532   void visitAtomicCmpxchg(AtomicCmpxchg* curr) {
533     blockifyReachableOperands({curr->ptr, curr->expected, curr->replacement},
534                               curr->type);
535   }
536 
visitUnarywasm::DeadCodeElimination537   void visitUnary(Unary* curr) {
538     blockifyReachableOperands({curr->value}, curr->type);
539   }
540 
visitBinarywasm::DeadCodeElimination541   void visitBinary(Binary* curr) {
542     blockifyReachableOperands({curr->left, curr->right}, curr->type);
543   }
544 
visitSelectwasm::DeadCodeElimination545   void visitSelect(Select* curr) {
546     blockifyReachableOperands({curr->ifTrue, curr->ifFalse, curr->condition},
547                               curr->type);
548   }
549 
visitDropwasm::DeadCodeElimination550   void visitDrop(Drop* curr) {
551     blockifyReachableOperands({curr->value}, curr->type);
552   }
553 
visitMemorySizewasm::DeadCodeElimination554   void visitMemorySize(MemorySize* curr) {}
555 
visitMemoryGrowwasm::DeadCodeElimination556   void visitMemoryGrow(MemoryGrow* curr) {
557     blockifyReachableOperands({curr->delta}, curr->type);
558   }
559 
visitRefIsNullwasm::DeadCodeElimination560   void visitRefIsNull(RefIsNull* curr) {
561     blockifyReachableOperands({curr->value}, curr->type);
562   }
563 
visitRefEqwasm::DeadCodeElimination564   void visitRefEq(RefEq* curr) {
565     blockifyReachableOperands({curr->left, curr->right}, curr->type);
566   }
567 
visitFunctionwasm::DeadCodeElimination568   void visitFunction(Function* curr) { assert(reachableBreaks.size() == 0); }
569 };
570 
createDeadCodeEliminationPass()571 Pass* createDeadCodeEliminationPass() { return new DeadCodeElimination(); }
572 
573 } // namespace wasm
574