1 //===-- NVPTXLowerUnreachable.cpp - Lower unreachables to exit =====--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // PTX does not have a notion of `unreachable`, which results in emitted basic
10 // blocks having an edge to the next block:
11 //
12 //   block1:
13 //     call @does_not_return();
14 //     // unreachable
15 //   block2:
16 //     // ptxas will create a CFG edge from block1 to block2
17 //
18 // This may result in significant changes to the control flow graph, e.g., when
19 // LLVM moves unreachable blocks to the end of the function. That's a problem
20 // in the context of divergent control flow, as `ptxas` uses the CFG to
// determine divergent regions, and some instructions may not be executed
22 // divergently.
23 //
24 // For example, `bar.sync` is not allowed to be executed divergently on Pascal
25 // or earlier. If we start with the following:
26 //
27 //   entry:
28 //     // start of divergent region
29 //     @%p0 bra cont;
30 //     @%p1 bra unlikely;
31 //     ...
32 //     bra.uni cont;
33 //   unlikely:
34 //     ...
35 //     // unreachable
36 //   cont:
37 //     // end of divergent region
38 //     bar.sync 0;
39 //     bra.uni exit;
40 //   exit:
41 //     ret;
42 //
43 // it is transformed by the branch-folder and block-placement passes to:
44 //
45 //   entry:
46 //     // start of divergent region
47 //     @%p0 bra cont;
48 //     @%p1 bra unlikely;
49 //     ...
50 //     bra.uni cont;
51 //   cont:
52 //     bar.sync 0;
53 //     bra.uni exit;
54 //   unlikely:
55 //     ...
56 //     // unreachable
57 //   exit:
58 //     // end of divergent region
59 //     ret;
60 //
61 // After moving the `unlikely` block to the end of the function, it has an edge
62 // to the `exit` block, which widens the divergent region and makes the
63 // `bar.sync` instruction happen divergently.
64 //
65 // To work around this, we add an `exit` instruction before every `unreachable`,
66 // as `ptxas` understands that exit terminates the CFG. Note that `trap` is not
67 // equivalent, and only future versions of `ptxas` will model it like `exit`.
68 //
69 //===----------------------------------------------------------------------===//
70 
71 #include "NVPTX.h"
72 #include "llvm/IR/Function.h"
73 #include "llvm/IR/InlineAsm.h"
74 #include "llvm/IR/Instructions.h"
75 #include "llvm/IR/Type.h"
76 #include "llvm/Pass.h"
77 
78 using namespace llvm;
79 
80 namespace llvm {
81 void initializeNVPTXLowerUnreachablePass(PassRegistry &);
82 }
83 
84 namespace {
85 class NVPTXLowerUnreachable : public FunctionPass {
86   bool runOnFunction(Function &F) override;
87 
88 public:
89   static char ID; // Pass identification, replacement for typeid
90   NVPTXLowerUnreachable() : FunctionPass(ID) {}
91   StringRef getPassName() const override {
92     return "add an exit instruction before every unreachable";
93   }
94 };
95 } // namespace
96 
97 char NVPTXLowerUnreachable::ID = 1;
98 
99 INITIALIZE_PASS(NVPTXLowerUnreachable, "nvptx-lower-unreachable",
100                 "Lower Unreachable", false, false)
101 
102 // =============================================================================
103 // Main function for this pass.
104 // =============================================================================
105 bool NVPTXLowerUnreachable::runOnFunction(Function &F) {
106   if (skipFunction(F))
107     return false;
108 
109   LLVMContext &C = F.getContext();
110   FunctionType *ExitFTy = FunctionType::get(Type::getVoidTy(C), false);
111   InlineAsm *Exit = InlineAsm::get(ExitFTy, "exit;", "", true);
112 
113   bool Changed = false;
114   for (auto &BB : F)
115     for (auto &I : BB) {
116       if (auto unreachableInst = dyn_cast<UnreachableInst>(&I)) {
117         Changed = true;
118         CallInst::Create(ExitFTy, Exit, "", unreachableInst);
119       }
120     }
121   return Changed;
122 }
123 
124 FunctionPass *llvm::createNVPTXLowerUnreachablePass() {
125   return new NVPTXLowerUnreachable();
126 }
127