1 //===---------------- BPFAdjustOpt.cpp - Adjust Optimization --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Adjust optimization to make the code more kernel verifier friendly.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "BPF.h"
14 #include "BPFCORE.h"
15 #include "BPFTargetMachine.h"
16 #include "llvm/IR/Instruction.h"
17 #include "llvm/IR/Instructions.h"
18 #include "llvm/IR/IntrinsicsBPF.h"
19 #include "llvm/IR/Module.h"
20 #include "llvm/IR/PatternMatch.h"
21 #include "llvm/IR/Type.h"
22 #include "llvm/IR/User.h"
23 #include "llvm/IR/Value.h"
24 #include "llvm/Pass.h"
25 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
26 
27 #define DEBUG_TYPE "bpf-adjust-opt"
28 
29 using namespace llvm;
30 using namespace llvm::PatternMatch;
31 
// Hidden command-line switch, off by default. When set, adjustBasicBlock()
// and adjustInst() skip the ICMP serialization adjustments
// (serializeICMPCrossBB / serializeICMPInBB).
static cl::opt<bool>
    DisableBPFserializeICMP("bpf-disable-serialize-icmp", cl::Hidden,
                            cl::desc("BPF: Disable Serializing ICMP insns."),
                            cl::init(false));

// Hidden command-line switch, off by default. When set, adjustInst() skips
// the avoidSpeculation() adjustment.
static cl::opt<bool> DisableBPFavoidSpeculation(
    "bpf-disable-avoid-speculation", cl::Hidden,
    cl::desc("BPF: Disable Avoiding Speculative Code Motion."),
    cl::init(false));
41 
42 namespace {
43 
44 class BPFAdjustOpt final : public ModulePass {
45 public:
46   static char ID;
47 
48   BPFAdjustOpt() : ModulePass(ID) {}
49   bool runOnModule(Module &M) override;
50 };
51 
52 class BPFAdjustOptImpl {
53   struct PassThroughInfo {
54     Instruction *Input;
55     Instruction *UsedInst;
56     uint32_t OpIdx;
57     PassThroughInfo(Instruction *I, Instruction *U, uint32_t Idx)
58         : Input(I), UsedInst(U), OpIdx(Idx) {}
59   };
60 
61 public:
62   BPFAdjustOptImpl(Module *M) : M(M) {}
63 
64   bool run();
65 
66 private:
67   Module *M;
68   SmallVector<PassThroughInfo, 16> PassThroughs;
69 
70   bool adjustICmpToBuiltin();
71   void adjustBasicBlock(BasicBlock &BB);
72   bool serializeICMPCrossBB(BasicBlock &BB);
73   void adjustInst(Instruction &I);
74   bool serializeICMPInBB(Instruction &I);
75   bool avoidSpeculation(Instruction &I);
76   bool insertPassThrough();
77 };
78 
79 } // End anonymous namespace
80 
// Definition of the legacy pass ID declared in BPFAdjustOpt; the BPFAdjustOpt
// constructor hands it to the ModulePass base class.
char BPFAdjustOpt::ID = 0;
// Registers the legacy pass under the name "bpf-adjust-opt".
// NOTE(review): the two trailing `false` arguments are presumably the
// "CFG only" and "is analysis" flags of INITIALIZE_PASS -- confirm against
// the macro definition.
INITIALIZE_PASS(BPFAdjustOpt, "bpf-adjust-opt", "BPF Adjust Optimization",
                false, false)
84 
85 ModulePass *llvm::createBPFAdjustOpt() { return new BPFAdjustOpt(); }
86 
87 bool BPFAdjustOpt::runOnModule(Module &M) { return BPFAdjustOptImpl(&M).run(); }
88 
89 bool BPFAdjustOptImpl::run() {
90   bool Changed = adjustICmpToBuiltin();
91 
92   for (Function &F : *M)
93     for (auto &BB : F) {
94       adjustBasicBlock(BB);
95       for (auto &I : BB)
96         adjustInst(I);
97     }
98   return insertPassThrough() || Changed;
99 }
100 
101 // Commit acabad9ff6bf ("[InstCombine] try to canonicalize icmp with
102 // trunc op into mask and cmp") added a transformation to
103 // convert "(conv)a < power_2_const" to "a & <const>" in certain
104 // cases and bpf kernel verifier has to handle the resulted code
105 // conservatively and this may reject otherwise legitimate program.
106 // Here, we change related icmp code to a builtin which will
107 // be restored to original icmp code later to prevent that
108 // InstCombine transformatin.
109 bool BPFAdjustOptImpl::adjustICmpToBuiltin() {
110   bool Changed = false;
111   ICmpInst *ToBeDeleted = nullptr;
112   for (Function &F : *M)
113     for (auto &BB : F)
114       for (auto &I : BB) {
115         if (ToBeDeleted) {
116           ToBeDeleted->eraseFromParent();
117           ToBeDeleted = nullptr;
118         }
119 
120         auto *Icmp = dyn_cast<ICmpInst>(&I);
121         if (!Icmp)
122           continue;
123 
124         Value *Op0 = Icmp->getOperand(0);
125         if (!isa<TruncInst>(Op0))
126           continue;
127 
128         auto ConstOp1 = dyn_cast<ConstantInt>(Icmp->getOperand(1));
129         if (!ConstOp1)
130           continue;
131 
132         auto ConstOp1Val = ConstOp1->getValue().getZExtValue();
133         auto Op = Icmp->getPredicate();
134         if (Op == ICmpInst::ICMP_ULT || Op == ICmpInst::ICMP_UGE) {
135           if ((ConstOp1Val - 1) & ConstOp1Val)
136             continue;
137         } else if (Op == ICmpInst::ICMP_ULE || Op == ICmpInst::ICMP_UGT) {
138           if (ConstOp1Val & (ConstOp1Val + 1))
139             continue;
140         } else {
141           continue;
142         }
143 
144         Constant *Opcode =
145             ConstantInt::get(Type::getInt32Ty(BB.getContext()), Op);
146         Function *Fn = Intrinsic::getDeclaration(
147             M, Intrinsic::bpf_compare, {Op0->getType(), ConstOp1->getType()});
148         auto *NewInst = CallInst::Create(Fn, {Opcode, Op0, ConstOp1});
149         BB.getInstList().insert(I.getIterator(), NewInst);
150         Icmp->replaceAllUsesWith(NewInst);
151         Changed = true;
152         ToBeDeleted = Icmp;
153       }
154 
155   return Changed;
156 }
157 
158 bool BPFAdjustOptImpl::insertPassThrough() {
159   for (auto &Info : PassThroughs) {
160     auto *CI = BPFCoreSharedInfo::insertPassThrough(
161         M, Info.UsedInst->getParent(), Info.Input, Info.UsedInst);
162     Info.UsedInst->setOperand(Info.OpIdx, CI);
163   }
164 
165   return !PassThroughs.empty();
166 }
167 
168 // To avoid combining conditionals in the same basic block by
169 // instrcombine optimization.
170 bool BPFAdjustOptImpl::serializeICMPInBB(Instruction &I) {
171   // For:
172   //   comp1 = icmp <opcode> ...;
173   //   comp2 = icmp <opcode> ...;
174   //   ... or comp1 comp2 ...
175   // changed to:
176   //   comp1 = icmp <opcode> ...;
177   //   comp2 = icmp <opcode> ...;
178   //   new_comp1 = __builtin_bpf_passthrough(seq_num, comp1)
179   //   ... or new_comp1 comp2 ...
180   Value *Op0, *Op1;
181   // Use LogicalOr (accept `or i1` as well as `select i1 Op0, true, Op1`)
182   if (!match(&I, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
183     return false;
184   auto *Icmp1 = dyn_cast<ICmpInst>(Op0);
185   if (!Icmp1)
186     return false;
187   auto *Icmp2 = dyn_cast<ICmpInst>(Op1);
188   if (!Icmp2)
189     return false;
190 
191   Value *Icmp1Op0 = Icmp1->getOperand(0);
192   Value *Icmp2Op0 = Icmp2->getOperand(0);
193   if (Icmp1Op0 != Icmp2Op0)
194     return false;
195 
196   // Now we got two icmp instructions which feed into
197   // an "or" instruction.
198   PassThroughInfo Info(Icmp1, &I, 0);
199   PassThroughs.push_back(Info);
200   return true;
201 }
202 
203 // To avoid combining conditionals in the same basic block by
204 // instrcombine optimization.
205 bool BPFAdjustOptImpl::serializeICMPCrossBB(BasicBlock &BB) {
206   // For:
207   //   B1:
208   //     comp1 = icmp <opcode> ...;
209   //     if (comp1) goto B2 else B3;
210   //   B2:
211   //     comp2 = icmp <opcode> ...;
212   //     if (comp2) goto B4 else B5;
213   //   B4:
214   //     ...
215   // changed to:
216   //   B1:
217   //     comp1 = icmp <opcode> ...;
218   //     comp1 = __builtin_bpf_passthrough(seq_num, comp1);
219   //     if (comp1) goto B2 else B3;
220   //   B2:
221   //     comp2 = icmp <opcode> ...;
222   //     if (comp2) goto B4 else B5;
223   //   B4:
224   //     ...
225 
226   // Check basic predecessors, if two of them (say B1, B2) are using
227   // icmp instructions to generate conditions and one is the predesessor
228   // of another (e.g., B1 is the predecessor of B2). Add a passthrough
229   // barrier after icmp inst of block B1.
230   BasicBlock *B2 = BB.getSinglePredecessor();
231   if (!B2)
232     return false;
233 
234   BasicBlock *B1 = B2->getSinglePredecessor();
235   if (!B1)
236     return false;
237 
238   Instruction *TI = B2->getTerminator();
239   auto *BI = dyn_cast<BranchInst>(TI);
240   if (!BI || !BI->isConditional())
241     return false;
242   auto *Cond = dyn_cast<ICmpInst>(BI->getCondition());
243   if (!Cond || B2->getFirstNonPHI() != Cond)
244     return false;
245   Value *B2Op0 = Cond->getOperand(0);
246   auto Cond2Op = Cond->getPredicate();
247 
248   TI = B1->getTerminator();
249   BI = dyn_cast<BranchInst>(TI);
250   if (!BI || !BI->isConditional())
251     return false;
252   Cond = dyn_cast<ICmpInst>(BI->getCondition());
253   if (!Cond)
254     return false;
255   Value *B1Op0 = Cond->getOperand(0);
256   auto Cond1Op = Cond->getPredicate();
257 
258   if (B1Op0 != B2Op0)
259     return false;
260 
261   if (Cond1Op == ICmpInst::ICMP_SGT || Cond1Op == ICmpInst::ICMP_SGE) {
262     if (Cond2Op != ICmpInst::ICMP_SLT && Cond2Op != ICmpInst::ICMP_SLE)
263       return false;
264   } else if (Cond1Op == ICmpInst::ICMP_SLT || Cond1Op == ICmpInst::ICMP_SLE) {
265     if (Cond2Op != ICmpInst::ICMP_SGT && Cond2Op != ICmpInst::ICMP_SGE)
266       return false;
267   } else if (Cond1Op == ICmpInst::ICMP_ULT || Cond1Op == ICmpInst::ICMP_ULE) {
268     if (Cond2Op != ICmpInst::ICMP_UGT && Cond2Op != ICmpInst::ICMP_UGE)
269       return false;
270   } else if (Cond1Op == ICmpInst::ICMP_UGT || Cond1Op == ICmpInst::ICMP_UGE) {
271     if (Cond2Op != ICmpInst::ICMP_ULT && Cond2Op != ICmpInst::ICMP_ULE)
272       return false;
273   } else {
274     return false;
275   }
276 
277   PassThroughInfo Info(Cond, BI, 0);
278   PassThroughs.push_back(Info);
279 
280   return true;
281 }
282 
283 // To avoid speculative hoisting certain computations out of
284 // a basic block.
285 bool BPFAdjustOptImpl::avoidSpeculation(Instruction &I) {
286   if (auto *LdInst = dyn_cast<LoadInst>(&I)) {
287     if (auto *GV = dyn_cast<GlobalVariable>(LdInst->getOperand(0))) {
288       if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
289           GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
290         return false;
291     }
292   }
293 
294   if (!isa<LoadInst>(&I) && !isa<CallInst>(&I))
295     return false;
296 
297   // For:
298   //   B1:
299   //     var = ...
300   //     ...
301   //     /* icmp may not be in the same block as var = ... */
302   //     comp1 = icmp <opcode> var, <const>;
303   //     if (comp1) goto B2 else B3;
304   //   B2:
305   //     ... var ...
306   // change to:
307   //   B1:
308   //     var = ...
309   //     ...
310   //     /* icmp may not be in the same block as var = ... */
311   //     comp1 = icmp <opcode> var, <const>;
312   //     if (comp1) goto B2 else B3;
313   //   B2:
314   //     var = __builtin_bpf_passthrough(seq_num, var);
315   //     ... var ...
316   bool isCandidate = false;
317   SmallVector<PassThroughInfo, 4> Candidates;
318   for (User *U : I.users()) {
319     Instruction *Inst = dyn_cast<Instruction>(U);
320     if (!Inst)
321       continue;
322 
323     // May cover a little bit more than the
324     // above pattern.
325     if (auto *Icmp1 = dyn_cast<ICmpInst>(Inst)) {
326       Value *Icmp1Op1 = Icmp1->getOperand(1);
327       if (!isa<Constant>(Icmp1Op1))
328         return false;
329       isCandidate = true;
330       continue;
331     }
332 
333     // Ignore the use in the same basic block as the definition.
334     if (Inst->getParent() == I.getParent())
335       continue;
336 
337     // use in a different basic block, If there is a call or
338     // load/store insn before this instruction in this basic
339     // block. Most likely it cannot be hoisted out. Skip it.
340     for (auto &I2 : *Inst->getParent()) {
341       if (isa<CallInst>(&I2))
342         return false;
343       if (isa<LoadInst>(&I2) || isa<StoreInst>(&I2))
344         return false;
345       if (&I2 == Inst)
346         break;
347     }
348 
349     // It should be used in a GEP or a simple arithmetic like
350     // ZEXT/SEXT which is used for GEP.
351     if (Inst->getOpcode() == Instruction::ZExt ||
352         Inst->getOpcode() == Instruction::SExt) {
353       PassThroughInfo Info(&I, Inst, 0);
354       Candidates.push_back(Info);
355     } else if (auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
356       // traverse GEP inst to find Use operand index
357       unsigned i, e;
358       for (i = 1, e = GI->getNumOperands(); i != e; ++i) {
359         Value *V = GI->getOperand(i);
360         if (V == &I)
361           break;
362       }
363       if (i == e)
364         continue;
365 
366       PassThroughInfo Info(&I, GI, i);
367       Candidates.push_back(Info);
368     }
369   }
370 
371   if (!isCandidate || Candidates.empty())
372     return false;
373 
374   llvm::append_range(PassThroughs, Candidates);
375   return true;
376 }
377 
378 void BPFAdjustOptImpl::adjustBasicBlock(BasicBlock &BB) {
379   if (!DisableBPFserializeICMP && serializeICMPCrossBB(BB))
380     return;
381 }
382 
383 void BPFAdjustOptImpl::adjustInst(Instruction &I) {
384   if (!DisableBPFserializeICMP && serializeICMPInBB(I))
385     return;
386   if (!DisableBPFavoidSpeculation && avoidSpeculation(I))
387     return;
388 }
389 
390 PreservedAnalyses BPFAdjustOptPass::run(Module &M, ModuleAnalysisManager &AM) {
391   return BPFAdjustOptImpl(&M).run() ? PreservedAnalyses::none()
392                                     : PreservedAnalyses::all();
393 }
394