1 //===---------------- BPFAdjustOpt.cpp - Adjust Optimization --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Adjust optimization to make the code more kernel verifier friendly.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "BPF.h"
14 #include "BPFCORE.h"
15 #include "BPFTargetMachine.h"
16 #include "llvm/IR/Instruction.h"
17 #include "llvm/IR/Instructions.h"
18 #include "llvm/IR/Module.h"
19 #include "llvm/IR/PatternMatch.h"
20 #include "llvm/IR/Type.h"
21 #include "llvm/IR/User.h"
22 #include "llvm/IR/Value.h"
23 #include "llvm/Pass.h"
24 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
25 
26 #define DEBUG_TYPE "bpf-adjust-opt"
27 
28 using namespace llvm;
29 using namespace llvm::PatternMatch;
30 
31 static cl::opt<bool>
32     DisableBPFserializeICMP("bpf-disable-serialize-icmp", cl::Hidden,
33                             cl::desc("BPF: Disable Serializing ICMP insns."),
34                             cl::init(false));
35 
36 static cl::opt<bool> DisableBPFavoidSpeculation(
37     "bpf-disable-avoid-speculation", cl::Hidden,
38     cl::desc("BPF: Disable Avoiding Speculative Code Motion."),
39     cl::init(false));
40 
41 namespace {
42 
43 class BPFAdjustOpt final : public ModulePass {
44 public:
45   static char ID;
46 
BPFAdjustOpt()47   BPFAdjustOpt() : ModulePass(ID) {}
48   bool runOnModule(Module &M) override;
49 };
50 
51 class BPFAdjustOptImpl {
52   struct PassThroughInfo {
53     Instruction *Input;
54     Instruction *UsedInst;
55     uint32_t OpIdx;
PassThroughInfo__anon39def7c40111::BPFAdjustOptImpl::PassThroughInfo56     PassThroughInfo(Instruction *I, Instruction *U, uint32_t Idx)
57         : Input(I), UsedInst(U), OpIdx(Idx) {}
58   };
59 
60 public:
BPFAdjustOptImpl(Module * M)61   BPFAdjustOptImpl(Module *M) : M(M) {}
62 
63   bool run();
64 
65 private:
66   Module *M;
67   SmallVector<PassThroughInfo, 16> PassThroughs;
68 
69   void adjustBasicBlock(BasicBlock &BB);
70   bool serializeICMPCrossBB(BasicBlock &BB);
71   void adjustInst(Instruction &I);
72   bool serializeICMPInBB(Instruction &I);
73   bool avoidSpeculation(Instruction &I);
74   bool insertPassThrough();
75 };
76 
77 } // End anonymous namespace
78 
79 char BPFAdjustOpt::ID = 0;
80 INITIALIZE_PASS(BPFAdjustOpt, "bpf-adjust-opt", "BPF Adjust Optimization",
81                 false, false)
82 
createBPFAdjustOpt()83 ModulePass *llvm::createBPFAdjustOpt() { return new BPFAdjustOpt(); }
84 
runOnModule(Module & M)85 bool BPFAdjustOpt::runOnModule(Module &M) { return BPFAdjustOptImpl(&M).run(); }
86 
run()87 bool BPFAdjustOptImpl::run() {
88   for (Function &F : *M)
89     for (auto &BB : F) {
90       adjustBasicBlock(BB);
91       for (auto &I : BB)
92         adjustInst(I);
93     }
94 
95   return insertPassThrough();
96 }
97 
insertPassThrough()98 bool BPFAdjustOptImpl::insertPassThrough() {
99   for (auto &Info : PassThroughs) {
100     auto *CI = BPFCoreSharedInfo::insertPassThrough(
101         M, Info.UsedInst->getParent(), Info.Input, Info.UsedInst);
102     Info.UsedInst->setOperand(Info.OpIdx, CI);
103   }
104 
105   return !PassThroughs.empty();
106 }
107 
108 // To avoid combining conditionals in the same basic block by
109 // instrcombine optimization.
serializeICMPInBB(Instruction & I)110 bool BPFAdjustOptImpl::serializeICMPInBB(Instruction &I) {
111   // For:
112   //   comp1 = icmp <opcode> ...;
113   //   comp2 = icmp <opcode> ...;
114   //   ... or comp1 comp2 ...
115   // changed to:
116   //   comp1 = icmp <opcode> ...;
117   //   comp2 = icmp <opcode> ...;
118   //   new_comp1 = __builtin_bpf_passthrough(seq_num, comp1)
119   //   ... or new_comp1 comp2 ...
120   Value *Op0, *Op1;
121   // Use LogicalOr (accept `or i1` as well as `select i1 Op0, true, Op1`)
122   if (!match(&I, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
123     return false;
124   auto *Icmp1 = dyn_cast<ICmpInst>(Op0);
125   if (!Icmp1)
126     return false;
127   auto *Icmp2 = dyn_cast<ICmpInst>(Op1);
128   if (!Icmp2)
129     return false;
130 
131   Value *Icmp1Op0 = Icmp1->getOperand(0);
132   Value *Icmp2Op0 = Icmp2->getOperand(0);
133   if (Icmp1Op0 != Icmp2Op0)
134     return false;
135 
136   // Now we got two icmp instructions which feed into
137   // an "or" instruction.
138   PassThroughInfo Info(Icmp1, &I, 0);
139   PassThroughs.push_back(Info);
140   return true;
141 }
142 
143 // To avoid combining conditionals in the same basic block by
144 // instrcombine optimization.
serializeICMPCrossBB(BasicBlock & BB)145 bool BPFAdjustOptImpl::serializeICMPCrossBB(BasicBlock &BB) {
146   // For:
147   //   B1:
148   //     comp1 = icmp <opcode> ...;
149   //     if (comp1) goto B2 else B3;
150   //   B2:
151   //     comp2 = icmp <opcode> ...;
152   //     if (comp2) goto B4 else B5;
153   //   B4:
154   //     ...
155   // changed to:
156   //   B1:
157   //     comp1 = icmp <opcode> ...;
158   //     comp1 = __builtin_bpf_passthrough(seq_num, comp1);
159   //     if (comp1) goto B2 else B3;
160   //   B2:
161   //     comp2 = icmp <opcode> ...;
162   //     if (comp2) goto B4 else B5;
163   //   B4:
164   //     ...
165 
166   // Check basic predecessors, if two of them (say B1, B2) are using
167   // icmp instructions to generate conditions and one is the predesessor
168   // of another (e.g., B1 is the predecessor of B2). Add a passthrough
169   // barrier after icmp inst of block B1.
170   BasicBlock *B2 = BB.getSinglePredecessor();
171   if (!B2)
172     return false;
173 
174   BasicBlock *B1 = B2->getSinglePredecessor();
175   if (!B1)
176     return false;
177 
178   Instruction *TI = B2->getTerminator();
179   auto *BI = dyn_cast<BranchInst>(TI);
180   if (!BI || !BI->isConditional())
181     return false;
182   auto *Cond = dyn_cast<ICmpInst>(BI->getCondition());
183   if (!Cond || B2->getFirstNonPHI() != Cond)
184     return false;
185   Value *B2Op0 = Cond->getOperand(0);
186   auto Cond2Op = Cond->getPredicate();
187 
188   TI = B1->getTerminator();
189   BI = dyn_cast<BranchInst>(TI);
190   if (!BI || !BI->isConditional())
191     return false;
192   Cond = dyn_cast<ICmpInst>(BI->getCondition());
193   if (!Cond)
194     return false;
195   Value *B1Op0 = Cond->getOperand(0);
196   auto Cond1Op = Cond->getPredicate();
197 
198   if (B1Op0 != B2Op0)
199     return false;
200 
201   if (Cond1Op == ICmpInst::ICMP_SGT || Cond1Op == ICmpInst::ICMP_SGE) {
202     if (Cond2Op != ICmpInst::ICMP_SLT && Cond1Op != ICmpInst::ICMP_SLE)
203       return false;
204   } else if (Cond1Op == ICmpInst::ICMP_SLT || Cond1Op == ICmpInst::ICMP_SLE) {
205     if (Cond2Op != ICmpInst::ICMP_SGT && Cond1Op != ICmpInst::ICMP_SGE)
206       return false;
207   } else {
208     return false;
209   }
210 
211   PassThroughInfo Info(Cond, BI, 0);
212   PassThroughs.push_back(Info);
213 
214   return true;
215 }
216 
217 // To avoid speculative hoisting certain computations out of
218 // a basic block.
avoidSpeculation(Instruction & I)219 bool BPFAdjustOptImpl::avoidSpeculation(Instruction &I) {
220   if (auto *LdInst = dyn_cast<LoadInst>(&I)) {
221     if (auto *GV = dyn_cast<GlobalVariable>(LdInst->getOperand(0))) {
222       if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
223           GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
224         return false;
225     }
226   }
227 
228   if (!isa<LoadInst>(&I) && !isa<CallInst>(&I))
229     return false;
230 
231   // For:
232   //   B1:
233   //     var = ...
234   //     ...
235   //     /* icmp may not be in the same block as var = ... */
236   //     comp1 = icmp <opcode> var, <const>;
237   //     if (comp1) goto B2 else B3;
238   //   B2:
239   //     ... var ...
240   // change to:
241   //   B1:
242   //     var = ...
243   //     ...
244   //     /* icmp may not be in the same block as var = ... */
245   //     comp1 = icmp <opcode> var, <const>;
246   //     if (comp1) goto B2 else B3;
247   //   B2:
248   //     var = __builtin_bpf_passthrough(seq_num, var);
249   //     ... var ...
250   bool isCandidate = false;
251   SmallVector<PassThroughInfo, 4> Candidates;
252   for (User *U : I.users()) {
253     Instruction *Inst = dyn_cast<Instruction>(U);
254     if (!Inst)
255       continue;
256 
257     // May cover a little bit more than the
258     // above pattern.
259     if (auto *Icmp1 = dyn_cast<ICmpInst>(Inst)) {
260       Value *Icmp1Op1 = Icmp1->getOperand(1);
261       if (!isa<Constant>(Icmp1Op1))
262         return false;
263       isCandidate = true;
264       continue;
265     }
266 
267     // Ignore the use in the same basic block as the definition.
268     if (Inst->getParent() == I.getParent())
269       continue;
270 
271     // use in a different basic block, If there is a call or
272     // load/store insn before this instruction in this basic
273     // block. Most likely it cannot be hoisted out. Skip it.
274     for (auto &I2 : *Inst->getParent()) {
275       if (isa<CallInst>(&I2))
276         return false;
277       if (isa<LoadInst>(&I2) || isa<StoreInst>(&I2))
278         return false;
279       if (&I2 == Inst)
280         break;
281     }
282 
283     // It should be used in a GEP or a simple arithmetic like
284     // ZEXT/SEXT which is used for GEP.
285     if (Inst->getOpcode() == Instruction::ZExt ||
286         Inst->getOpcode() == Instruction::SExt) {
287       PassThroughInfo Info(&I, Inst, 0);
288       Candidates.push_back(Info);
289     } else if (auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
290       // traverse GEP inst to find Use operand index
291       unsigned i, e;
292       for (i = 1, e = GI->getNumOperands(); i != e; ++i) {
293         Value *V = GI->getOperand(i);
294         if (V == &I)
295           break;
296       }
297       if (i == e)
298         continue;
299 
300       PassThroughInfo Info(&I, GI, i);
301       Candidates.push_back(Info);
302     }
303   }
304 
305   if (!isCandidate || Candidates.empty())
306     return false;
307 
308   llvm::append_range(PassThroughs, Candidates);
309   return true;
310 }
311 
adjustBasicBlock(BasicBlock & BB)312 void BPFAdjustOptImpl::adjustBasicBlock(BasicBlock &BB) {
313   if (!DisableBPFserializeICMP && serializeICMPCrossBB(BB))
314     return;
315 }
316 
adjustInst(Instruction & I)317 void BPFAdjustOptImpl::adjustInst(Instruction &I) {
318   if (!DisableBPFserializeICMP && serializeICMPInBB(I))
319     return;
320   if (!DisableBPFavoidSpeculation && avoidSpeculation(I))
321     return;
322 }
323 
run(Module & M,ModuleAnalysisManager & AM)324 PreservedAnalyses BPFAdjustOptPass::run(Module &M, ModuleAnalysisManager &AM) {
325   return BPFAdjustOptImpl(&M).run() ? PreservedAnalyses::none()
326                                     : PreservedAnalyses::all();
327 }
328