1 //===-------------- BPFMIPeephole.cpp - MI Peephole Cleanups  -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass performs peephole optimizations to cleanup ugly code sequences at
10 // MachineInstruction layer.
11 //
// Currently, there are three optimizations implemented:
//  - One pre-RA MachineSSA pass to eliminate type promotion sequences that
//    zero extend 32-bit subregisters to 64-bit registers, when the compiler
//    can prove the subregister is defined by 32-bit operations, in which
//    case the upper half of the underlying 64-bit register was zeroed
//    implicitly.
//
//  - One post-RA PreEmit pass to do final cleanup on some redundant
//    instructions generated due to bad RA on subregisters.
//
//  - One MachineSSA pass to eliminate truncations (AND masks or SLL/SRL
//    pairs) of values produced by BPF loads, which already zero extend.
21 //===----------------------------------------------------------------------===//
22 
23 #include "BPF.h"
24 #include "BPFInstrInfo.h"
25 #include "BPFTargetMachine.h"
26 #include "llvm/ADT/Statistic.h"
27 #include "llvm/CodeGen/MachineInstrBuilder.h"
28 #include "llvm/CodeGen/MachineRegisterInfo.h"
29 #include "llvm/Support/Debug.h"
30 #include <set>
31 
32 using namespace llvm;
33 
// Debug-output / statistics group for every pass in this file (it is never
// redefined below). It is also passed as the pass argument of
// INITIALIZE_PASS for BPFMIPeephole.
#define DEBUG_TYPE "bpf-mi-zext-elim"

// Incremented once per MOV_32_64/SLL/SRL sequence rewritten by
// BPFMIPeephole::eliminateZExtSeq.
STATISTIC(ZExtElemNum, "Number of zero extension shifts eliminated");
37 
38 namespace {
39 
40 struct BPFMIPeephole : public MachineFunctionPass {
41 
42   static char ID;
43   const BPFInstrInfo *TII;
44   MachineFunction *MF;
45   MachineRegisterInfo *MRI;
46 
47   BPFMIPeephole() : MachineFunctionPass(ID) {
48     initializeBPFMIPeepholePass(*PassRegistry::getPassRegistry());
49   }
50 
51 private:
52   // Initialize class variables.
53   void initialize(MachineFunction &MFParm);
54 
55   bool isCopyFrom32Def(MachineInstr *CopyMI);
56   bool isInsnFrom32Def(MachineInstr *DefInsn);
57   bool isPhiFrom32Def(MachineInstr *MovMI);
58   bool isMovFrom32Def(MachineInstr *MovMI);
59   bool eliminateZExtSeq();
60   bool eliminateZExt();
61 
62   std::set<MachineInstr *> PhiInsns;
63 
64 public:
65 
66   // Main entry point for this pass.
67   bool runOnMachineFunction(MachineFunction &MF) override {
68     if (skipFunction(MF.getFunction()))
69       return false;
70 
71     initialize(MF);
72 
73     // First try to eliminate (zext, lshift, rshift) and then
74     // try to eliminate zext.
75     bool ZExtSeqExist, ZExtExist;
76     ZExtSeqExist = eliminateZExtSeq();
77     ZExtExist = eliminateZExt();
78     return ZExtSeqExist || ZExtExist;
79   }
80 };
81 
82 // Initialize class variables.
83 void BPFMIPeephole::initialize(MachineFunction &MFParm) {
84   MF = &MFParm;
85   MRI = &MF->getRegInfo();
86   TII = MF->getSubtarget<BPFSubtarget>().getInstrInfo();
87   LLVM_DEBUG(dbgs() << "*** BPF MachineSSA ZEXT Elim peephole pass ***\n\n");
88 }
89 
90 bool BPFMIPeephole::isCopyFrom32Def(MachineInstr *CopyMI)
91 {
92   MachineOperand &opnd = CopyMI->getOperand(1);
93 
94   if (!opnd.isReg())
95     return false;
96 
97   // Return false if getting value from a 32bit physical register.
98   // Most likely, this physical register is aliased to
99   // function call return value or current function parameters.
100   Register Reg = opnd.getReg();
101   if (!Register::isVirtualRegister(Reg))
102     return false;
103 
104   if (MRI->getRegClass(Reg) == &BPF::GPRRegClass)
105     return false;
106 
107   MachineInstr *DefInsn = MRI->getVRegDef(Reg);
108   if (!isInsnFrom32Def(DefInsn))
109     return false;
110 
111   return true;
112 }
113 
114 bool BPFMIPeephole::isPhiFrom32Def(MachineInstr *PhiMI)
115 {
116   for (unsigned i = 1, e = PhiMI->getNumOperands(); i < e; i += 2) {
117     MachineOperand &opnd = PhiMI->getOperand(i);
118 
119     if (!opnd.isReg())
120       return false;
121 
122     MachineInstr *PhiDef = MRI->getVRegDef(opnd.getReg());
123     if (!PhiDef)
124       return false;
125     if (PhiDef->isPHI()) {
126       if (PhiInsns.find(PhiDef) != PhiInsns.end())
127         return false;
128       PhiInsns.insert(PhiDef);
129       if (!isPhiFrom32Def(PhiDef))
130         return false;
131     }
132     if (PhiDef->getOpcode() == BPF::COPY && !isCopyFrom32Def(PhiDef))
133       return false;
134   }
135 
136   return true;
137 }
138 
139 // The \p DefInsn instruction defines a virtual register.
140 bool BPFMIPeephole::isInsnFrom32Def(MachineInstr *DefInsn)
141 {
142   if (!DefInsn)
143     return false;
144 
145   if (DefInsn->isPHI()) {
146     if (PhiInsns.find(DefInsn) != PhiInsns.end())
147       return false;
148     PhiInsns.insert(DefInsn);
149     if (!isPhiFrom32Def(DefInsn))
150       return false;
151   } else if (DefInsn->getOpcode() == BPF::COPY) {
152     if (!isCopyFrom32Def(DefInsn))
153       return false;
154   }
155 
156   return true;
157 }
158 
159 bool BPFMIPeephole::isMovFrom32Def(MachineInstr *MovMI)
160 {
161   MachineInstr *DefInsn = MRI->getVRegDef(MovMI->getOperand(1).getReg());
162 
163   LLVM_DEBUG(dbgs() << "  Def of Mov Src:");
164   LLVM_DEBUG(DefInsn->dump());
165 
166   PhiInsns.clear();
167   if (!isInsnFrom32Def(DefInsn))
168     return false;
169 
170   LLVM_DEBUG(dbgs() << "  One ZExt elim sequence identified.\n");
171 
172   return true;
173 }
174 
175 bool BPFMIPeephole::eliminateZExtSeq() {
176   MachineInstr* ToErase = nullptr;
177   bool Eliminated = false;
178 
179   for (MachineBasicBlock &MBB : *MF) {
180     for (MachineInstr &MI : MBB) {
181       // If the previous instruction was marked for elimination, remove it now.
182       if (ToErase) {
183         ToErase->eraseFromParent();
184         ToErase = nullptr;
185       }
186 
187       // Eliminate the 32-bit to 64-bit zero extension sequence when possible.
188       //
189       //   MOV_32_64 rB, wA
190       //   SLL_ri    rB, rB, 32
191       //   SRL_ri    rB, rB, 32
192       if (MI.getOpcode() == BPF::SRL_ri &&
193           MI.getOperand(2).getImm() == 32) {
194         Register DstReg = MI.getOperand(0).getReg();
195         Register ShfReg = MI.getOperand(1).getReg();
196         MachineInstr *SllMI = MRI->getVRegDef(ShfReg);
197 
198         LLVM_DEBUG(dbgs() << "Starting SRL found:");
199         LLVM_DEBUG(MI.dump());
200 
201         if (!SllMI ||
202             SllMI->isPHI() ||
203             SllMI->getOpcode() != BPF::SLL_ri ||
204             SllMI->getOperand(2).getImm() != 32)
205           continue;
206 
207         LLVM_DEBUG(dbgs() << "  SLL found:");
208         LLVM_DEBUG(SllMI->dump());
209 
210         MachineInstr *MovMI = MRI->getVRegDef(SllMI->getOperand(1).getReg());
211         if (!MovMI ||
212             MovMI->isPHI() ||
213             MovMI->getOpcode() != BPF::MOV_32_64)
214           continue;
215 
216         LLVM_DEBUG(dbgs() << "  Type cast Mov found:");
217         LLVM_DEBUG(MovMI->dump());
218 
219         Register SubReg = MovMI->getOperand(1).getReg();
220         if (!isMovFrom32Def(MovMI)) {
221           LLVM_DEBUG(dbgs()
222                      << "  One ZExt elim sequence failed qualifying elim.\n");
223           continue;
224         }
225 
226         BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(BPF::SUBREG_TO_REG), DstReg)
227           .addImm(0).addReg(SubReg).addImm(BPF::sub_32);
228 
229         SllMI->eraseFromParent();
230         MovMI->eraseFromParent();
231         // MI is the right shift, we can't erase it in it's own iteration.
232         // Mark it to ToErase, and erase in the next iteration.
233         ToErase = &MI;
234         ZExtElemNum++;
235         Eliminated = true;
236       }
237     }
238   }
239 
240   return Eliminated;
241 }
242 
243 bool BPFMIPeephole::eliminateZExt() {
244   MachineInstr* ToErase = nullptr;
245   bool Eliminated = false;
246 
247   for (MachineBasicBlock &MBB : *MF) {
248     for (MachineInstr &MI : MBB) {
249       // If the previous instruction was marked for elimination, remove it now.
250       if (ToErase) {
251         ToErase->eraseFromParent();
252         ToErase = nullptr;
253       }
254 
255       if (MI.getOpcode() != BPF::MOV_32_64)
256         continue;
257 
258       // Eliminate MOV_32_64 if possible.
259       //   MOV_32_64 rA, wB
260       //
261       // If wB has been zero extended, replace it with a SUBREG_TO_REG.
262       // This is to workaround BPF programs where pkt->{data, data_end}
263       // is encoded as u32, but actually the verifier populates them
264       // as 64bit pointer. The MOV_32_64 will zero out the top 32 bits.
265       LLVM_DEBUG(dbgs() << "Candidate MOV_32_64 instruction:");
266       LLVM_DEBUG(MI.dump());
267 
268       if (!isMovFrom32Def(&MI))
269         continue;
270 
271       LLVM_DEBUG(dbgs() << "Removing the MOV_32_64 instruction\n");
272 
273       Register dst = MI.getOperand(0).getReg();
274       Register src = MI.getOperand(1).getReg();
275 
276       // Build a SUBREG_TO_REG instruction.
277       BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(BPF::SUBREG_TO_REG), dst)
278         .addImm(0).addReg(src).addImm(BPF::sub_32);
279 
280       ToErase = &MI;
281       Eliminated = true;
282     }
283   }
284 
285   return Eliminated;
286 }
287 
288 } // end default namespace
289 
290 INITIALIZE_PASS(BPFMIPeephole, DEBUG_TYPE,
291                 "BPF MachineSSA Peephole Optimization For ZEXT Eliminate",
292                 false, false)
293 
294 char BPFMIPeephole::ID = 0;
295 FunctionPass* llvm::createBPFMIPeepholePass() { return new BPFMIPeephole(); }
296 
297 STATISTIC(RedundantMovElemNum, "Number of redundant moves eliminated");
298 
299 namespace {
300 
301 struct BPFMIPreEmitPeephole : public MachineFunctionPass {
302 
303   static char ID;
304   MachineFunction *MF;
305   const TargetRegisterInfo *TRI;
306 
307   BPFMIPreEmitPeephole() : MachineFunctionPass(ID) {
308     initializeBPFMIPreEmitPeepholePass(*PassRegistry::getPassRegistry());
309   }
310 
311 private:
312   // Initialize class variables.
313   void initialize(MachineFunction &MFParm);
314 
315   bool eliminateRedundantMov();
316 
317 public:
318 
319   // Main entry point for this pass.
320   bool runOnMachineFunction(MachineFunction &MF) override {
321     if (skipFunction(MF.getFunction()))
322       return false;
323 
324     initialize(MF);
325 
326     return eliminateRedundantMov();
327   }
328 };
329 
330 // Initialize class variables.
331 void BPFMIPreEmitPeephole::initialize(MachineFunction &MFParm) {
332   MF = &MFParm;
333   TRI = MF->getSubtarget<BPFSubtarget>().getRegisterInfo();
334   LLVM_DEBUG(dbgs() << "*** BPF PreEmit peephole pass ***\n\n");
335 }
336 
337 bool BPFMIPreEmitPeephole::eliminateRedundantMov() {
338   MachineInstr* ToErase = nullptr;
339   bool Eliminated = false;
340 
341   for (MachineBasicBlock &MBB : *MF) {
342     for (MachineInstr &MI : MBB) {
343       // If the previous instruction was marked for elimination, remove it now.
344       if (ToErase) {
345         LLVM_DEBUG(dbgs() << "  Redundant Mov Eliminated:");
346         LLVM_DEBUG(ToErase->dump());
347         ToErase->eraseFromParent();
348         ToErase = nullptr;
349       }
350 
351       // Eliminate identical move:
352       //
353       //   MOV rA, rA
354       //
355       // Note that we cannot remove
356       //   MOV_32_64  rA, wA
357       //   MOV_rr_32  wA, wA
358       // as these two instructions having side effects, zeroing out
359       // top 32 bits of rA.
360       unsigned Opcode = MI.getOpcode();
361       if (Opcode == BPF::MOV_rr) {
362         Register dst = MI.getOperand(0).getReg();
363         Register src = MI.getOperand(1).getReg();
364 
365         if (dst != src)
366           continue;
367 
368         ToErase = &MI;
369         RedundantMovElemNum++;
370         Eliminated = true;
371       }
372     }
373   }
374 
375   return Eliminated;
376 }
377 
378 } // end default namespace
379 
380 INITIALIZE_PASS(BPFMIPreEmitPeephole, "bpf-mi-pemit-peephole",
381                 "BPF PreEmit Peephole Optimization", false, false)
382 
383 char BPFMIPreEmitPeephole::ID = 0;
384 FunctionPass* llvm::createBPFMIPreEmitPeepholePass()
385 {
386   return new BPFMIPreEmitPeephole();
387 }
388 
389 STATISTIC(TruncElemNum, "Number of truncation eliminated");
390 
391 namespace {
392 
393 struct BPFMIPeepholeTruncElim : public MachineFunctionPass {
394 
395   static char ID;
396   const BPFInstrInfo *TII;
397   MachineFunction *MF;
398   MachineRegisterInfo *MRI;
399 
400   BPFMIPeepholeTruncElim() : MachineFunctionPass(ID) {
401     initializeBPFMIPeepholeTruncElimPass(*PassRegistry::getPassRegistry());
402   }
403 
404 private:
405   // Initialize class variables.
406   void initialize(MachineFunction &MFParm);
407 
408   bool eliminateTruncSeq();
409 
410 public:
411 
412   // Main entry point for this pass.
413   bool runOnMachineFunction(MachineFunction &MF) override {
414     if (skipFunction(MF.getFunction()))
415       return false;
416 
417     initialize(MF);
418 
419     return eliminateTruncSeq();
420   }
421 };
422 
423 static bool TruncSizeCompatible(int TruncSize, unsigned opcode)
424 {
425   if (TruncSize == 1)
426     return opcode == BPF::LDB || opcode == BPF::LDB32;
427 
428   if (TruncSize == 2)
429     return opcode == BPF::LDH || opcode == BPF::LDH32;
430 
431   if (TruncSize == 4)
432     return opcode == BPF::LDW || opcode == BPF::LDW32;
433 
434   return false;
435 }
436 
437 // Initialize class variables.
438 void BPFMIPeepholeTruncElim::initialize(MachineFunction &MFParm) {
439   MF = &MFParm;
440   MRI = &MF->getRegInfo();
441   TII = MF->getSubtarget<BPFSubtarget>().getInstrInfo();
442   LLVM_DEBUG(dbgs() << "*** BPF MachineSSA TRUNC Elim peephole pass ***\n\n");
443 }
444 
445 // Reg truncating is often the result of 8/16/32bit->64bit or
446 // 8/16bit->32bit conversion. If the reg value is loaded with
447 // masked byte width, the AND operation can be removed since
448 // BPF LOAD already has zero extension.
449 //
450 // This also solved a correctness issue.
451 // In BPF socket-related program, e.g., __sk_buff->{data, data_end}
452 // are 32-bit registers, but later on, kernel verifier will rewrite
453 // it with 64-bit value. Therefore, truncating the value after the
454 // load will result in incorrect code.
455 bool BPFMIPeepholeTruncElim::eliminateTruncSeq() {
456   MachineInstr* ToErase = nullptr;
457   bool Eliminated = false;
458 
459   for (MachineBasicBlock &MBB : *MF) {
460     for (MachineInstr &MI : MBB) {
461       // The second insn to remove if the eliminate candidate is a pair.
462       MachineInstr *MI2 = nullptr;
463       Register DstReg, SrcReg;
464       MachineInstr *DefMI;
465       int TruncSize = -1;
466 
467       // If the previous instruction was marked for elimination, remove it now.
468       if (ToErase) {
469         ToErase->eraseFromParent();
470         ToErase = nullptr;
471       }
472 
473       // AND A, 0xFFFFFFFF will be turned into SLL/SRL pair due to immediate
474       // for BPF ANDI is i32, and this case only happens on ALU64.
475       if (MI.getOpcode() == BPF::SRL_ri &&
476           MI.getOperand(2).getImm() == 32) {
477         SrcReg = MI.getOperand(1).getReg();
478         if (!MRI->hasOneNonDBGUse(SrcReg))
479           continue;
480 
481         MI2 = MRI->getVRegDef(SrcReg);
482         DstReg = MI.getOperand(0).getReg();
483 
484         if (!MI2 ||
485             MI2->getOpcode() != BPF::SLL_ri ||
486             MI2->getOperand(2).getImm() != 32)
487           continue;
488 
489         // Update SrcReg.
490         SrcReg = MI2->getOperand(1).getReg();
491         DefMI = MRI->getVRegDef(SrcReg);
492         if (DefMI)
493           TruncSize = 4;
494       } else if (MI.getOpcode() == BPF::AND_ri ||
495                  MI.getOpcode() == BPF::AND_ri_32) {
496         SrcReg = MI.getOperand(1).getReg();
497         DstReg = MI.getOperand(0).getReg();
498         DefMI = MRI->getVRegDef(SrcReg);
499 
500         if (!DefMI)
501           continue;
502 
503         int64_t imm = MI.getOperand(2).getImm();
504         if (imm == 0xff)
505           TruncSize = 1;
506         else if (imm == 0xffff)
507           TruncSize = 2;
508       }
509 
510       if (TruncSize == -1)
511         continue;
512 
513       // The definition is PHI node, check all inputs.
514       if (DefMI->isPHI()) {
515         bool CheckFail = false;
516 
517         for (unsigned i = 1, e = DefMI->getNumOperands(); i < e; i += 2) {
518           MachineOperand &opnd = DefMI->getOperand(i);
519           if (!opnd.isReg()) {
520             CheckFail = true;
521             break;
522           }
523 
524           MachineInstr *PhiDef = MRI->getVRegDef(opnd.getReg());
525           if (!PhiDef || PhiDef->isPHI() ||
526               !TruncSizeCompatible(TruncSize, PhiDef->getOpcode())) {
527             CheckFail = true;
528             break;
529           }
530         }
531 
532         if (CheckFail)
533           continue;
534       } else if (!TruncSizeCompatible(TruncSize, DefMI->getOpcode())) {
535         continue;
536       }
537 
538       BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(BPF::MOV_rr), DstReg)
539               .addReg(SrcReg);
540 
541       if (MI2)
542         MI2->eraseFromParent();
543 
544       // Mark it to ToErase, and erase in the next iteration.
545       ToErase = &MI;
546       TruncElemNum++;
547       Eliminated = true;
548     }
549   }
550 
551   return Eliminated;
552 }
553 
554 } // end default namespace
555 
556 INITIALIZE_PASS(BPFMIPeepholeTruncElim, "bpf-mi-trunc-elim",
557                 "BPF MachineSSA Peephole Optimization For TRUNC Eliminate",
558                 false, false)
559 
560 char BPFMIPeepholeTruncElim::ID = 0;
561 FunctionPass* llvm::createBPFMIPeepholeTruncElimPass()
562 {
563   return new BPFMIPeepholeTruncElim();
564 }
565