1 //=== WebAssemblyLateEHPrepare.cpp - WebAssembly Exception Preparation -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// \brief Does various transformations for exception handling.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
15 #include "WebAssembly.h"
16 #include "WebAssemblySubtarget.h"
17 #include "WebAssemblyUtilities.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/CodeGen/MachineInstrBuilder.h"
20 #include "llvm/CodeGen/WasmEHFuncInfo.h"
21 #include "llvm/MC/MCAsmInfo.h"
22 #include "llvm/Support/Debug.h"
23 using namespace llvm;
24 
25 #define DEBUG_TYPE "wasm-late-eh-prepare"
26 
27 namespace {
28 class WebAssemblyLateEHPrepare final : public MachineFunctionPass {
getPassName() const29   StringRef getPassName() const override {
30     return "WebAssembly Late Prepare Exception";
31   }
32 
33   bool runOnMachineFunction(MachineFunction &MF) override;
34   bool addCatches(MachineFunction &MF);
35   bool replaceFuncletReturns(MachineFunction &MF);
36   bool removeUnnecessaryUnreachables(MachineFunction &MF);
37   bool addExceptionExtraction(MachineFunction &MF);
38   bool restoreStackPointer(MachineFunction &MF);
39 
40 public:
41   static char ID; // Pass identification, replacement for typeid
WebAssemblyLateEHPrepare()42   WebAssemblyLateEHPrepare() : MachineFunctionPass(ID) {}
43 };
44 } // end anonymous namespace
45 
// Out-of-class definition of the pass identifier declared in the class.
char WebAssemblyLateEHPrepare::ID = 0;
// Register the pass, using DEBUG_TYPE ("wasm-late-eh-prepare") as its argument
// string together with a human-readable description.
// NOTE(review): the two trailing 'false' flags are presumably the macro's
// cfg-only / is-analysis parameters -- confirm against INITIALIZE_PASS.
INITIALIZE_PASS(WebAssemblyLateEHPrepare, DEBUG_TYPE,
                "WebAssembly Late Exception Preparation", false, false)
49 
createWebAssemblyLateEHPrepare()50 FunctionPass *llvm::createWebAssemblyLateEHPrepare() {
51   return new WebAssemblyLateEHPrepare();
52 }
53 
54 // Returns the nearest EH pad that dominates this instruction. This does not use
55 // dominator analysis; it just does BFS on its predecessors until arriving at an
56 // EH pad. This assumes valid EH scopes so the first EH pad it arrives in all
57 // possible search paths should be the same.
58 // Returns nullptr in case it does not find any EH pad in the search, or finds
59 // multiple different EH pads.
getMatchingEHPad(MachineInstr * MI)60 static MachineBasicBlock *getMatchingEHPad(MachineInstr *MI) {
61   MachineFunction *MF = MI->getParent()->getParent();
62   SmallVector<MachineBasicBlock *, 2> WL;
63   SmallPtrSet<MachineBasicBlock *, 2> Visited;
64   WL.push_back(MI->getParent());
65   MachineBasicBlock *EHPad = nullptr;
66   while (!WL.empty()) {
67     MachineBasicBlock *MBB = WL.pop_back_val();
68     if (Visited.count(MBB))
69       continue;
70     Visited.insert(MBB);
71     if (MBB->isEHPad()) {
72       if (EHPad && EHPad != MBB)
73         return nullptr;
74       EHPad = MBB;
75       continue;
76     }
77     if (MBB == &MF->front())
78       return nullptr;
79     WL.append(MBB->pred_begin(), MBB->pred_end());
80   }
81   return EHPad;
82 }
83 
84 // Erase the specified BBs if the BB does not have any remaining predecessors,
85 // and also all its dead children.
86 template <typename Container>
eraseDeadBBsAndChildren(const Container & MBBs)87 static void eraseDeadBBsAndChildren(const Container &MBBs) {
88   SmallVector<MachineBasicBlock *, 8> WL(MBBs.begin(), MBBs.end());
89   while (!WL.empty()) {
90     MachineBasicBlock *MBB = WL.pop_back_val();
91     if (!MBB->pred_empty())
92       continue;
93     SmallVector<MachineBasicBlock *, 4> Succs(MBB->succ_begin(),
94                                               MBB->succ_end());
95     WL.append(MBB->succ_begin(), MBB->succ_end());
96     for (auto *Succ : Succs)
97       MBB->removeSuccessor(Succ);
98     MBB->eraseFromParent();
99   }
100 }
101 
runOnMachineFunction(MachineFunction & MF)102 bool WebAssemblyLateEHPrepare::runOnMachineFunction(MachineFunction &MF) {
103   LLVM_DEBUG(dbgs() << "********** Late EH Prepare **********\n"
104                        "********** Function: "
105                     << MF.getName() << '\n');
106 
107   if (MF.getTarget().getMCAsmInfo()->getExceptionHandlingType() !=
108       ExceptionHandling::Wasm)
109     return false;
110 
111   bool Changed = false;
112   if (MF.getFunction().hasPersonalityFn()) {
113     Changed |= addCatches(MF);
114     Changed |= replaceFuncletReturns(MF);
115   }
116   Changed |= removeUnnecessaryUnreachables(MF);
117   if (MF.getFunction().hasPersonalityFn()) {
118     Changed |= addExceptionExtraction(MF);
119     Changed |= restoreStackPointer(MF);
120   }
121   return Changed;
122 }
123 
124 // Add catch instruction to beginning of catchpads and cleanuppads.
addCatches(MachineFunction & MF)125 bool WebAssemblyLateEHPrepare::addCatches(MachineFunction &MF) {
126   bool Changed = false;
127   const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
128   MachineRegisterInfo &MRI = MF.getRegInfo();
129   for (auto &MBB : MF) {
130     if (MBB.isEHPad()) {
131       Changed = true;
132       auto InsertPos = MBB.begin();
133       if (InsertPos->isEHLabel()) // EH pad starts with an EH label
134         ++InsertPos;
135       Register DstReg = MRI.createVirtualRegister(&WebAssembly::EXNREFRegClass);
136       BuildMI(MBB, InsertPos, MBB.begin()->getDebugLoc(),
137               TII.get(WebAssembly::CATCH), DstReg);
138     }
139   }
140   return Changed;
141 }
142 
replaceFuncletReturns(MachineFunction & MF)143 bool WebAssemblyLateEHPrepare::replaceFuncletReturns(MachineFunction &MF) {
144   bool Changed = false;
145   const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
146 
147   for (auto &MBB : MF) {
148     auto Pos = MBB.getFirstTerminator();
149     if (Pos == MBB.end())
150       continue;
151     MachineInstr *TI = &*Pos;
152 
153     switch (TI->getOpcode()) {
154     case WebAssembly::CATCHRET: {
155       // Replace a catchret with a branch
156       MachineBasicBlock *TBB = TI->getOperand(0).getMBB();
157       if (!MBB.isLayoutSuccessor(TBB))
158         BuildMI(MBB, TI, TI->getDebugLoc(), TII.get(WebAssembly::BR))
159             .addMBB(TBB);
160       TI->eraseFromParent();
161       Changed = true;
162       break;
163     }
164     case WebAssembly::CLEANUPRET:
165     case WebAssembly::RETHROW_IN_CATCH: {
166       // Replace a cleanupret/rethrow_in_catch with a rethrow
167       auto *EHPad = getMatchingEHPad(TI);
168       auto CatchPos = EHPad->begin();
169       if (CatchPos->isEHLabel()) // EH pad starts with an EH label
170         ++CatchPos;
171       MachineInstr *Catch = &*CatchPos;
172       Register ExnReg = Catch->getOperand(0).getReg();
173       BuildMI(MBB, TI, TI->getDebugLoc(), TII.get(WebAssembly::RETHROW))
174           .addReg(ExnReg);
175       TI->eraseFromParent();
176       Changed = true;
177       break;
178     }
179     }
180   }
181   return Changed;
182 }
183 
removeUnnecessaryUnreachables(MachineFunction & MF)184 bool WebAssemblyLateEHPrepare::removeUnnecessaryUnreachables(
185     MachineFunction &MF) {
186   bool Changed = false;
187   for (auto &MBB : MF) {
188     for (auto &MI : MBB) {
189       if (MI.getOpcode() != WebAssembly::THROW &&
190           MI.getOpcode() != WebAssembly::RETHROW)
191         continue;
192       Changed = true;
193 
194       // The instruction after the throw should be an unreachable or a branch to
195       // another BB that should eventually lead to an unreachable. Delete it
196       // because throw itself is a terminator, and also delete successors if
197       // any.
198       MBB.erase(std::next(MI.getIterator()), MBB.end());
199       SmallVector<MachineBasicBlock *, 8> Succs(MBB.succ_begin(),
200                                                 MBB.succ_end());
201       for (auto *Succ : Succs)
202         if (!Succ->isEHPad())
203           MBB.removeSuccessor(Succ);
204       eraseDeadBBsAndChildren(Succs);
205     }
206   }
207 
208   return Changed;
209 }
210 
211 // Wasm uses 'br_on_exn' instruction to check the tag of an exception. It takes
212 // exnref type object returned by 'catch', and branches to the destination if it
213 // matches a given tag. We currently use __cpp_exception symbol to represent the
214 // tag for all C++ exceptions.
215 //
216 // block $l (result i32)
217 //   ...
218 //   ;; exnref $e is on the stack at this point
219 //   br_on_exn $l $e ;; branch to $l with $e's arguments
220 //   ...
221 // end
222 // ;; Here we expect the extracted values are on top of the wasm value stack
223 // ... Handle exception using values ...
224 //
225 // br_on_exn takes an exnref object and branches if it matches the given tag.
226 // There can be multiple br_on_exn instructions if we want to match for another
227 // tag, but for now we only test for __cpp_exception tag, and if it does not
228 // match, i.e., it is a foreign exception, we rethrow it.
229 //
230 // In the destination BB that's the target of br_on_exn, extracted exception
231 // values (in C++'s case a single i32, which represents an exception pointer)
232 // are placed on top of the wasm stack. Because we can't model wasm stack in
233 // LLVM instruction, we use 'extract_exception' pseudo instruction to retrieve
234 // it. The pseudo instruction will be deleted later.
addExceptionExtraction(MachineFunction & MF)235 bool WebAssemblyLateEHPrepare::addExceptionExtraction(MachineFunction &MF) {
236   const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
237   MachineRegisterInfo &MRI = MF.getRegInfo();
238   auto *EHInfo = MF.getWasmEHFuncInfo();
239   SmallVector<MachineInstr *, 16> ExtractInstrs;
240   SmallVector<MachineInstr *, 8> ToDelete;
241   for (auto &MBB : MF) {
242     for (auto &MI : MBB) {
243       if (MI.getOpcode() == WebAssembly::EXTRACT_EXCEPTION_I32) {
244         if (MI.getOperand(0).isDead())
245           ToDelete.push_back(&MI);
246         else
247           ExtractInstrs.push_back(&MI);
248       }
249     }
250   }
251   bool Changed = !ToDelete.empty() || !ExtractInstrs.empty();
252   for (auto *MI : ToDelete)
253     MI->eraseFromParent();
254   if (ExtractInstrs.empty())
255     return Changed;
256 
257   // Find terminate pads.
258   SmallSet<MachineBasicBlock *, 8> TerminatePads;
259   for (auto &MBB : MF) {
260     for (auto &MI : MBB) {
261       if (MI.isCall()) {
262         const MachineOperand &CalleeOp = MI.getOperand(0);
263         if (CalleeOp.isGlobal() && CalleeOp.getGlobal()->getName() ==
264                                        WebAssembly::ClangCallTerminateFn)
265           TerminatePads.insert(getMatchingEHPad(&MI));
266       }
267     }
268   }
269 
270   for (auto *Extract : ExtractInstrs) {
271     MachineBasicBlock *EHPad = getMatchingEHPad(Extract);
272     assert(EHPad && "No matching EH pad for extract_exception");
273     auto CatchPos = EHPad->begin();
274     if (CatchPos->isEHLabel()) // EH pad starts with an EH label
275       ++CatchPos;
276     MachineInstr *Catch = &*CatchPos;
277 
278     if (Catch->getNextNode() != Extract)
279       EHPad->insert(Catch->getNextNode(), Extract->removeFromParent());
280 
281     // - Before:
282     // ehpad:
283     //   %exnref:exnref = catch
284     //   %exn:i32 = extract_exception
285     //   ... use exn ...
286     //
287     // - After:
288     // ehpad:
289     //   %exnref:exnref = catch
290     //   br_on_exn %thenbb, $__cpp_exception, %exnref
291     //   br %elsebb
292     // elsebb:
293     //   rethrow
294     // thenbb:
295     //   %exn:i32 = extract_exception
296     //   ... use exn ...
297     Register ExnReg = Catch->getOperand(0).getReg();
298     auto *ThenMBB = MF.CreateMachineBasicBlock();
299     auto *ElseMBB = MF.CreateMachineBasicBlock();
300     MF.insert(std::next(MachineFunction::iterator(EHPad)), ElseMBB);
301     MF.insert(std::next(MachineFunction::iterator(ElseMBB)), ThenMBB);
302     ThenMBB->splice(ThenMBB->end(), EHPad, Extract, EHPad->end());
303     ThenMBB->transferSuccessors(EHPad);
304     EHPad->addSuccessor(ThenMBB);
305     EHPad->addSuccessor(ElseMBB);
306 
307     DebugLoc DL = Extract->getDebugLoc();
308     const char *CPPExnSymbol = MF.createExternalSymbolName("__cpp_exception");
309     BuildMI(EHPad, DL, TII.get(WebAssembly::BR_ON_EXN))
310         .addMBB(ThenMBB)
311         .addExternalSymbol(CPPExnSymbol)
312         .addReg(ExnReg);
313     BuildMI(EHPad, DL, TII.get(WebAssembly::BR)).addMBB(ElseMBB);
314 
315     // When this is a terminate pad with __clang_call_terminate() call, we don't
316     // rethrow it anymore and call __clang_call_terminate() with a nullptr
317     // argument, which will call std::terminate().
318     //
319     // - Before:
320     // ehpad:
321     //   %exnref:exnref = catch
322     //   %exn:i32 = extract_exception
323     //   call @__clang_call_terminate(%exn)
324     //   unreachable
325     //
326     // - After:
327     // ehpad:
328     //   %exnref:exnref = catch
329     //   br_on_exn %thenbb, $__cpp_exception, %exnref
330     //   br %elsebb
331     // elsebb:
332     //   call @__clang_call_terminate(0)
333     //   unreachable
334     // thenbb:
335     //   %exn:i32 = extract_exception
336     //   call @__clang_call_terminate(%exn)
337     //   unreachable
338     if (TerminatePads.count(EHPad)) {
339       Function *ClangCallTerminateFn =
340           MF.getFunction().getParent()->getFunction(
341               WebAssembly::ClangCallTerminateFn);
342       assert(ClangCallTerminateFn &&
343              "There is no __clang_call_terminate() function");
344       Register Reg = MRI.createVirtualRegister(&WebAssembly::I32RegClass);
345       BuildMI(ElseMBB, DL, TII.get(WebAssembly::CONST_I32), Reg).addImm(0);
346       BuildMI(ElseMBB, DL, TII.get(WebAssembly::CALL_VOID))
347           .addGlobalAddress(ClangCallTerminateFn)
348           .addReg(Reg);
349       BuildMI(ElseMBB, DL, TII.get(WebAssembly::UNREACHABLE));
350 
351     } else {
352       BuildMI(ElseMBB, DL, TII.get(WebAssembly::RETHROW)).addReg(ExnReg);
353       if (EHInfo->hasEHPadUnwindDest(EHPad))
354         ElseMBB->addSuccessor(EHInfo->getEHPadUnwindDest(EHPad));
355     }
356   }
357 
358   return true;
359 }
360 
361 // After the stack is unwound due to a thrown exception, the __stack_pointer
362 // global can point to an invalid address. This inserts instructions that
363 // restore __stack_pointer global.
restoreStackPointer(MachineFunction & MF)364 bool WebAssemblyLateEHPrepare::restoreStackPointer(MachineFunction &MF) {
365   const auto *FrameLowering = static_cast<const WebAssemblyFrameLowering *>(
366       MF.getSubtarget().getFrameLowering());
367   if (!FrameLowering->needsPrologForEH(MF))
368     return false;
369   bool Changed = false;
370 
371   for (auto &MBB : MF) {
372     if (!MBB.isEHPad())
373       continue;
374     Changed = true;
375 
376     // Insert __stack_pointer restoring instructions at the beginning of each EH
377     // pad, after the catch instruction. Here it is safe to assume that SP32
378     // holds the latest value of __stack_pointer, because the only exception for
379     // this case is when a function uses the red zone, but that only happens
380     // with leaf functions, and we don't restore __stack_pointer in leaf
381     // functions anyway.
382     auto InsertPos = MBB.begin();
383     if (InsertPos->isEHLabel()) // EH pad starts with an EH label
384       ++InsertPos;
385     if (InsertPos->getOpcode() == WebAssembly::CATCH)
386       ++InsertPos;
387     FrameLowering->writeSPToGlobal(WebAssembly::SP32, MF, MBB, InsertPos,
388                                    MBB.begin()->getDebugLoc());
389   }
390   return Changed;
391 }
392