1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
/// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
///   form should be in the preheader, whereas the while form should be in the
///   preheader's only predecessor.
/// - t2LoopDec - placed within the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
17 ///
/// In addition to this, we also look for the presence of the VCTP instruction,
/// which determines whether we can generate the tail-predicated low-overhead
/// loop form.
21 ///
22 /// Assumptions and Dependencies:
23 /// Low-overhead loops are constructed and executed using a setup instruction:
24 /// DLS, WLS, DLSTP or WLSTP and an instruction that loops back: LE or LETP.
25 /// WLS(TP) and LE(TP) are branching instructions with a (large) limited range
26 /// but fixed polarity: WLS can only branch forwards and LE can only branch
27 /// backwards. These restrictions mean that this pass is dependent upon block
28 /// layout and block sizes, which is why it's the last pass to run. The same is
29 /// true for ConstantIslands, but this pass does not increase the size of the
30 /// basic blocks, nor does it change the CFG. Instructions are mainly removed
31 /// during the transform and pseudo instructions are replaced by real ones. In
32 /// some cases, when we have to revert to a 'normal' loop, we have to introduce
33 /// multiple instructions for a single pseudo (see RevertWhile and
34 /// RevertLoopEnd). To handle this situation, t2WhileLoopStart and t2LoopEnd
35 /// are defined to be as large as this maximum sequence of replacement
36 /// instructions.
37 ///
38 /// A note on VPR.P0 (the lane mask):
39 /// VPT, VCMP, VPNOT and VCTP won't overwrite VPR.P0 when they update it in a
40 /// "VPT Active" context (which includes low-overhead loops and vpt blocks).
41 /// They will simply "and" the result of their calculation with the current
42 /// value of VPR.P0. You can think of it like this:
43 /// \verbatim
44 /// if VPT active:    ; Between a DLSTP/LETP, or for predicated instrs
45 ///   VPR.P0 &= Value
46 /// else
47 ///   VPR.P0 = Value
48 /// \endverbatim
/// When we're inside the low-overhead loop (between DLSTP and LETP), we always
/// fall in the "VPT active" case, so we can consider that all VPR writes by
/// one of those instructions are actually an "and".
52 //===----------------------------------------------------------------------===//
53 
54 #include "ARM.h"
55 #include "ARMBaseInstrInfo.h"
56 #include "ARMBaseRegisterInfo.h"
57 #include "ARMBasicBlockInfo.h"
58 #include "ARMSubtarget.h"
59 #include "Thumb2InstrInfo.h"
60 #include "llvm/ADT/SetOperations.h"
61 #include "llvm/ADT/SmallSet.h"
62 #include "llvm/CodeGen/LivePhysRegs.h"
63 #include "llvm/CodeGen/MachineFunctionPass.h"
64 #include "llvm/CodeGen/MachineLoopInfo.h"
65 #include "llvm/CodeGen/MachineLoopUtils.h"
66 #include "llvm/CodeGen/MachineRegisterInfo.h"
67 #include "llvm/CodeGen/Passes.h"
68 #include "llvm/CodeGen/ReachingDefAnalysis.h"
69 #include "llvm/MC/MCInstrDesc.h"
70 
71 using namespace llvm;
72 
73 #define DEBUG_TYPE "arm-low-overhead-loops"
74 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
75 
76 namespace {
77 
78   using InstSet = SmallPtrSetImpl<MachineInstr *>;
79 
80   class PostOrderLoopTraversal {
81     MachineLoop &ML;
82     MachineLoopInfo &MLI;
83     SmallPtrSet<MachineBasicBlock*, 4> Visited;
84     SmallVector<MachineBasicBlock*, 4> Order;
85 
86   public:
PostOrderLoopTraversal(MachineLoop & ML,MachineLoopInfo & MLI)87     PostOrderLoopTraversal(MachineLoop &ML, MachineLoopInfo &MLI)
88       : ML(ML), MLI(MLI) { }
89 
getOrder() const90     const SmallVectorImpl<MachineBasicBlock*> &getOrder() const {
91       return Order;
92     }
93 
94     // Visit all the blocks within the loop, as well as exit blocks and any
95     // blocks properly dominating the header.
ProcessLoop()96     void ProcessLoop() {
97       std::function<void(MachineBasicBlock*)> Search = [this, &Search]
98         (MachineBasicBlock *MBB) -> void {
99         if (Visited.count(MBB))
100           return;
101 
102         Visited.insert(MBB);
103         for (auto *Succ : MBB->successors()) {
104           if (!ML.contains(Succ))
105             continue;
106           Search(Succ);
107         }
108         Order.push_back(MBB);
109       };
110 
111       // Insert exit blocks.
112       SmallVector<MachineBasicBlock*, 2> ExitBlocks;
113       ML.getExitBlocks(ExitBlocks);
114       for (auto *MBB : ExitBlocks)
115         Order.push_back(MBB);
116 
117       // Then add the loop body.
118       Search(ML.getHeader());
119 
120       // Then try the preheader and its predecessors.
121       std::function<void(MachineBasicBlock*)> GetPredecessor =
122         [this, &GetPredecessor] (MachineBasicBlock *MBB) -> void {
123         Order.push_back(MBB);
124         if (MBB->pred_size() == 1)
125           GetPredecessor(*MBB->pred_begin());
126       };
127 
128       if (auto *Preheader = ML.getLoopPreheader())
129         GetPredecessor(Preheader);
130       else if (auto *Preheader = MLI.findLoopPreheader(&ML, true))
131         GetPredecessor(Preheader);
132     }
133   };
134 
135   struct PredicatedMI {
136     MachineInstr *MI = nullptr;
137     SetVector<MachineInstr*> Predicates;
138 
139   public:
PredicatedMI__anonf0e41e8a0111::PredicatedMI140     PredicatedMI(MachineInstr *I, SetVector<MachineInstr *> &Preds) : MI(I) {
141       assert(I && "Instruction must not be null!");
142       Predicates.insert(Preds.begin(), Preds.end());
143     }
144   };
145 
146   // Represent a VPT block, a list of instructions that begins with a VPT/VPST
147   // and has a maximum of four proceeding instructions. All instructions within
148   // the block are predicated upon the vpr and we allow instructions to define
149   // the vpr within in the block too.
150   class VPTBlock {
151     // The predicate then instruction, which is either a VPT, or a VPST
152     // instruction.
153     std::unique_ptr<PredicatedMI> PredicateThen;
154     PredicatedMI *Divergent = nullptr;
155     SmallVector<PredicatedMI, 4> Insts;
156 
157   public:
VPTBlock(MachineInstr * MI,SetVector<MachineInstr * > & Preds)158     VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
159       PredicateThen = std::make_unique<PredicatedMI>(MI, Preds);
160     }
161 
addInst(MachineInstr * MI,SetVector<MachineInstr * > & Preds)162     void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
163       LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI);
164       if (!Divergent && !set_difference(Preds, PredicateThen->Predicates).empty()) {
165         Divergent = &Insts.back();
166         LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI);
167       }
168       Insts.emplace_back(MI, Preds);
169       assert(Insts.size() <= 4 && "Too many instructions in VPT block!");
170     }
171 
172     // Have we found an instruction within the block which defines the vpr? If
173     // so, not all the instructions in the block will have the same predicate.
HasNonUniformPredicate() const174     bool HasNonUniformPredicate() const {
175       return Divergent != nullptr;
176     }
177 
178     // Is the given instruction part of the predicate set controlling the entry
179     // to the block.
IsPredicatedOn(MachineInstr * MI) const180     bool IsPredicatedOn(MachineInstr *MI) const {
181       return PredicateThen->Predicates.count(MI);
182     }
183 
184     // Returns true if this is a VPT instruction.
isVPT() const185     bool isVPT() const { return !isVPST(); }
186 
187     // Returns true if this is a VPST instruction.
isVPST() const188     bool isVPST() const {
189       return PredicateThen->MI->getOpcode() == ARM::MVE_VPST;
190     }
191 
192     // Is the given instruction the only predicate which controls the entry to
193     // the block.
IsOnlyPredicatedOn(MachineInstr * MI) const194     bool IsOnlyPredicatedOn(MachineInstr *MI) const {
195       return IsPredicatedOn(MI) && PredicateThen->Predicates.size() == 1;
196     }
197 
size() const198     unsigned size() const { return Insts.size(); }
getInsts()199     SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; }
getPredicateThen() const200     MachineInstr *getPredicateThen() const { return PredicateThen->MI; }
getDivergent() const201     PredicatedMI *getDivergent() const { return Divergent; }
202   };
203 
204   struct Reduction {
205     MachineInstr *Init;
206     MachineInstr &Copy;
207     MachineInstr &Reduce;
208     MachineInstr &VPSEL;
209 
Reduction__anonf0e41e8a0111::Reduction210     Reduction(MachineInstr *Init, MachineInstr *Mov, MachineInstr *Add,
211               MachineInstr *Sel)
212       : Init(Init), Copy(*Mov), Reduce(*Add), VPSEL(*Sel) { }
213   };
214 
215   struct LowOverheadLoop {
216 
217     MachineLoop &ML;
218     MachineBasicBlock *Preheader = nullptr;
219     MachineLoopInfo &MLI;
220     ReachingDefAnalysis &RDA;
221     const TargetRegisterInfo &TRI;
222     const ARMBaseInstrInfo &TII;
223     MachineFunction *MF = nullptr;
224     MachineInstr *InsertPt = nullptr;
225     MachineInstr *Start = nullptr;
226     MachineInstr *Dec = nullptr;
227     MachineInstr *End = nullptr;
228     MachineInstr *VCTP = nullptr;
229     SmallPtrSet<MachineInstr*, 4> SecondaryVCTPs;
230     VPTBlock *CurrentBlock = nullptr;
231     SetVector<MachineInstr*> CurrentPredicate;
232     SmallVector<VPTBlock, 4> VPTBlocks;
233     SmallPtrSet<MachineInstr*, 4> ToRemove;
234     SmallVector<std::unique_ptr<Reduction>, 1> Reductions;
235     SmallPtrSet<MachineInstr*, 4> BlockMasksToRecompute;
236     bool Revert = false;
237     bool CannotTailPredicate = false;
238 
LowOverheadLoop__anonf0e41e8a0111::LowOverheadLoop239     LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI,
240                     ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI,
241                     const ARMBaseInstrInfo &TII)
242       : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI), TII(TII) {
243       MF = ML.getHeader()->getParent();
244       if (auto *MBB = ML.getLoopPreheader())
245         Preheader = MBB;
246       else if (auto *MBB = MLI.findLoopPreheader(&ML, true))
247         Preheader = MBB;
248     }
249 
250     // If this is an MVE instruction, check that we know how to use tail
251     // predication with it. Record VPT blocks and return whether the
252     // instruction is valid for tail predication.
253     bool ValidateMVEInst(MachineInstr *MI);
254 
AnalyseMVEInst__anonf0e41e8a0111::LowOverheadLoop255     void AnalyseMVEInst(MachineInstr *MI) {
256       CannotTailPredicate = !ValidateMVEInst(MI);
257     }
258 
IsTailPredicationLegal__anonf0e41e8a0111::LowOverheadLoop259     bool IsTailPredicationLegal() const {
260       // For now, let's keep things really simple and only support a single
261       // block for tail predication.
262       return !Revert && FoundAllComponents() && VCTP &&
263              !CannotTailPredicate && ML.getNumBlocks() == 1;
264     }
265 
266     // Check that the predication in the loop will be equivalent once we
267     // perform the conversion. Also ensure that we can provide the number
268     // of elements to the loop start instruction.
269     bool ValidateTailPredicate(MachineInstr *StartInsertPt);
270 
271     // See whether the live-out instructions are a reduction that we can fixup
272     // later.
273     bool FindValidReduction(InstSet &LiveMIs, InstSet &LiveOutUsers);
274 
275     // Check that any values available outside of the loop will be the same
276     // after tail predication conversion.
277     bool ValidateLiveOuts();
278 
279     // Is it safe to define LR with DLS/WLS?
280     // LR can be defined if it is the operand to start, because it's the same
281     // value, or if it's going to be equivalent to the operand to Start.
282     MachineInstr *isSafeToDefineLR();
283 
284     // Check the branch targets are within range and we satisfy our
285     // restrictions.
286     void CheckLegality(ARMBasicBlockUtils *BBUtils);
287 
FoundAllComponents__anonf0e41e8a0111::LowOverheadLoop288     bool FoundAllComponents() const {
289       return Start && Dec && End;
290     }
291 
getVPTBlocks__anonf0e41e8a0111::LowOverheadLoop292     SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; }
293 
294     // Return the loop iteration count, or the number of elements if we're tail
295     // predicating.
getCount__anonf0e41e8a0111::LowOverheadLoop296     MachineOperand &getCount() {
297       return IsTailPredicationLegal() ?
298         VCTP->getOperand(1) : Start->getOperand(0);
299     }
300 
getStartOpcode__anonf0e41e8a0111::LowOverheadLoop301     unsigned getStartOpcode() const {
302       bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
303       if (!IsTailPredicationLegal())
304         return IsDo ? ARM::t2DLS : ARM::t2WLS;
305 
306       return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo);
307     }
308 
dump__anonf0e41e8a0111::LowOverheadLoop309     void dump() const {
310       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
311       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
312       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
313       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
314       if (!FoundAllComponents())
315         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
316       else if (!(Start && Dec && End))
317         dbgs() << "ARM Loops: Failed to find all loop components.\n";
318     }
319   };
320 
321   class ARMLowOverheadLoops : public MachineFunctionPass {
322     MachineFunction           *MF = nullptr;
323     MachineLoopInfo           *MLI = nullptr;
324     ReachingDefAnalysis       *RDA = nullptr;
325     const ARMBaseInstrInfo    *TII = nullptr;
326     MachineRegisterInfo       *MRI = nullptr;
327     const TargetRegisterInfo  *TRI = nullptr;
328     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
329 
330   public:
331     static char ID;
332 
ARMLowOverheadLoops()333     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
334 
getAnalysisUsage(AnalysisUsage & AU) const335     void getAnalysisUsage(AnalysisUsage &AU) const override {
336       AU.setPreservesCFG();
337       AU.addRequired<MachineLoopInfo>();
338       AU.addRequired<ReachingDefAnalysis>();
339       MachineFunctionPass::getAnalysisUsage(AU);
340     }
341 
342     bool runOnMachineFunction(MachineFunction &MF) override;
343 
getRequiredProperties() const344     MachineFunctionProperties getRequiredProperties() const override {
345       return MachineFunctionProperties().set(
346           MachineFunctionProperties::Property::NoVRegs).set(
347           MachineFunctionProperties::Property::TracksLiveness);
348     }
349 
getPassName() const350     StringRef getPassName() const override {
351       return ARM_LOW_OVERHEAD_LOOPS_NAME;
352     }
353 
354   private:
355     bool ProcessLoop(MachineLoop *ML);
356 
357     bool RevertNonLoops();
358 
359     void RevertWhile(MachineInstr *MI) const;
360 
361     bool RevertLoopDec(MachineInstr *MI) const;
362 
363     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
364 
365     void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
366 
367     void FixupReductions(LowOverheadLoop &LoLoop) const;
368 
369     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
370 
371     void Expand(LowOverheadLoop &LoLoop);
372 
373     void IterationCountDCE(LowOverheadLoop &LoLoop);
374   };
375 }
376 
// Pass identifier, handed to MachineFunctionPass by the constructor.
char ARMLowOverheadLoops::ID = 0;

// Register the pass under DEBUG_TYPE ("arm-low-overhead-loops") with
// ARM_LOW_OVERHEAD_LOOPS_NAME as its description. NOTE(review): the two
// trailing 'false' arguments are presumably INITIALIZE_PASS's "CFG only" and
// "is analysis" flags — confirm against the macro definition.
INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
                false, false)
381 
382 MachineInstr *LowOverheadLoop::isSafeToDefineLR() {
383   // We can define LR because LR already contains the same value.
384   if (Start->getOperand(0).getReg() == ARM::LR)
385     return Start;
386 
387   unsigned CountReg = Start->getOperand(0).getReg();
388   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
389     return MI->getOpcode() == ARM::tMOVr &&
390            MI->getOperand(0).getReg() == ARM::LR &&
391            MI->getOperand(1).getReg() == CountReg &&
392            MI->getOperand(2).getImm() == ARMCC::AL;
393    };
394 
395   MachineBasicBlock *MBB = Start->getParent();
396 
397   // Find an insertion point:
398   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
399   //   to Count before Start, we can insert at that mov.
400   if (auto *LRDef = RDA.getUniqueReachingMIDef(Start, ARM::LR))
401     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
402       return LRDef;
403 
404   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
405   //   to Count after Start, we can insert at that mov.
406   if (auto *LRDef = RDA.getLocalLiveOutMIDef(MBB, ARM::LR))
407     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
408       return LRDef;
409 
410   // We've found no suitable LR def and Start doesn't use LR directly. Can we
411   // just define LR anyway?
412   return RDA.isSafeToDefRegAt(Start, ARM::LR) ? Start : nullptr;
413 }
414 
ValidateTailPredicate(MachineInstr * StartInsertPt)415 bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) {
416   assert(VCTP && "VCTP instruction expected but is not set");
417   // All predication within the loop should be based on vctp. If the block
418   // isn't predicated on entry, check whether the vctp is within the block
419   // and that all other instructions are then predicated on it.
420   for (auto &Block : VPTBlocks) {
421     if (Block.IsPredicatedOn(VCTP))
422       continue;
423     if (Block.HasNonUniformPredicate() && !isVCTP(Block.getDivergent()->MI)) {
424       LLVM_DEBUG(dbgs() << "ARM Loops: Found unsupported diverging predicate: "
425                         << *Block.getDivergent()->MI);
426       return false;
427     }
428     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
429     for (auto &PredMI : Insts) {
430       // Check the instructions in the block and only allow:
431       //   - VCTPs
432       //   - Instructions predicated on the main VCTP
433       //   - Any VCMP
434       //      - VCMPs just "and" their result with VPR.P0. Whether they are
435       //      located before/after the VCTP is irrelevant - the end result will
436       //      be the same in both cases, so there's no point in requiring them
437       //      to be located after the VCTP!
438       if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI) ||
439           VCMPOpcodeToVPT(PredMI.MI->getOpcode()) != 0)
440         continue;
441       LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI
442                  << " - which is predicated on:\n";
443                  for (auto *MI : PredMI.Predicates)
444                    dbgs() << "   - " << *MI);
445       return false;
446     }
447   }
448 
449   if (!ValidateLiveOuts())
450     return false;
451 
452   // For tail predication, we need to provide the number of elements, instead
453   // of the iteration count, to the loop start instruction. The number of
454   // elements is provided to the vctp instruction, so we need to check that
455   // we can use this register at InsertPt.
456   Register NumElements = VCTP->getOperand(1).getReg();
457 
458   // If the register is defined within loop, then we can't perform TP.
459   // TODO: Check whether this is just a mov of a register that would be
460   // available.
461   if (RDA.hasLocalDefBefore(VCTP, NumElements)) {
462     LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
463     return false;
464   }
465 
466   // The element count register maybe defined after InsertPt, in which case we
467   // need to try to move either InsertPt or the def so that the [w|d]lstp can
468   // use the value.
469   // TODO: On failing to move an instruction, check if the count is provided by
470   // a mov and whether we can use the mov operand directly.
471   MachineBasicBlock *InsertBB = StartInsertPt->getParent();
472   if (!RDA.isReachingDefLiveOut(StartInsertPt, NumElements)) {
473     if (auto *ElemDef = RDA.getLocalLiveOutMIDef(InsertBB, NumElements)) {
474       if (RDA.isSafeToMoveForwards(ElemDef, StartInsertPt)) {
475         ElemDef->removeFromParent();
476         InsertBB->insert(MachineBasicBlock::iterator(StartInsertPt), ElemDef);
477         LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: "
478                    << *ElemDef);
479       } else if (RDA.isSafeToMoveBackwards(StartInsertPt, ElemDef)) {
480         StartInsertPt->removeFromParent();
481         InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef),
482                               StartInsertPt);
483         LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
484       } else {
485         LLVM_DEBUG(dbgs() << "ARM Loops: Unable to move element count to loop "
486                    << "start instruction.\n");
487         return false;
488       }
489     }
490   }
491 
492   // Especially in the case of while loops, InsertBB may not be the
493   // preheader, so we need to check that the register isn't redefined
494   // before entering the loop.
495   auto CannotProvideElements = [this](MachineBasicBlock *MBB,
496                                       Register NumElements) {
497     // NumElements is redefined in this block.
498     if (RDA.hasLocalDefBefore(&MBB->back(), NumElements))
499       return true;
500 
501     // Don't continue searching up through multiple predecessors.
502     if (MBB->pred_size() > 1)
503       return true;
504 
505     return false;
506   };
507 
508   // First, find the block that looks like the preheader.
509   MachineBasicBlock *MBB = Preheader;
510   if (!MBB) {
511     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find preheader.\n");
512     return false;
513   }
514 
515   // Then search backwards for a def, until we get to InsertBB.
516   while (MBB != InsertBB) {
517     if (CannotProvideElements(MBB, NumElements)) {
518       LLVM_DEBUG(dbgs() << "ARM Loops: Unable to provide element count.\n");
519       return false;
520     }
521     MBB = *MBB->pred_begin();
522   }
523 
524   // Check that the value change of the element count is what we expect and
525   // that the predication will be equivalent. For this we need:
526   // NumElements = NumElements - VectorWidth. The sub will be a sub immediate
527   // and we can also allow register copies within the chain too.
528   auto IsValidSub = [](MachineInstr *MI, int ExpectedVecWidth) {
529     return -getAddSubImmediate(*MI) == ExpectedVecWidth;
530   };
531 
532   MBB = VCTP->getParent();
533   if (auto *Def = RDA.getUniqueReachingMIDef(&MBB->back(), NumElements)) {
534     SmallPtrSet<MachineInstr*, 2> ElementChain;
535     SmallPtrSet<MachineInstr*, 2> Ignore = { VCTP };
536     unsigned ExpectedVectorWidth = getTailPredVectorWidth(VCTP->getOpcode());
537 
538     Ignore.insert(SecondaryVCTPs.begin(), SecondaryVCTPs.end());
539 
540     if (RDA.isSafeToRemove(Def, ElementChain, Ignore)) {
541       bool FoundSub = false;
542 
543       for (auto *MI : ElementChain) {
544         if (isMovRegOpcode(MI->getOpcode()))
545           continue;
546 
547         if (isSubImmOpcode(MI->getOpcode())) {
548           if (FoundSub || !IsValidSub(MI, ExpectedVectorWidth))
549             return false;
550           FoundSub = true;
551         } else
552           return false;
553       }
554 
555       LLVM_DEBUG(dbgs() << "ARM Loops: Will remove element count chain:\n";
556                  for (auto *MI : ElementChain)
557                    dbgs() << " - " << *MI);
558       ToRemove.insert(ElementChain.begin(), ElementChain.end());
559     }
560   }
561   return true;
562 }
563 
isVectorPredicated(MachineInstr * MI)564 static bool isVectorPredicated(MachineInstr *MI) {
565   int PIdx = llvm::findFirstVPTPredOperandIdx(*MI);
566   return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR;
567 }
568 
isRegInClass(const MachineOperand & MO,const TargetRegisterClass * Class)569 static bool isRegInClass(const MachineOperand &MO,
570                          const TargetRegisterClass *Class) {
571   return MO.isReg() && MO.getReg() && Class->contains(MO.getReg());
572 }
573 
574 // MVE 'narrowing' operate on half a lane, reading from half and writing
575 // to half, which are referred to has the top and bottom half. The other
576 // half retains its previous value.
retainsPreviousHalfElement(const MachineInstr & MI)577 static bool retainsPreviousHalfElement(const MachineInstr &MI) {
578   const MCInstrDesc &MCID = MI.getDesc();
579   uint64_t Flags = MCID.TSFlags;
580   return (Flags & ARMII::RetainsPreviousHalfElement) != 0;
581 }
582 
583 // Some MVE instructions read from the top/bottom halves of their operand(s)
584 // and generate a vector result with result elements that are double the
585 // width of the input.
producesDoubleWidthResult(const MachineInstr & MI)586 static bool producesDoubleWidthResult(const MachineInstr &MI) {
587   const MCInstrDesc &MCID = MI.getDesc();
588   uint64_t Flags = MCID.TSFlags;
589   return (Flags & ARMII::DoubleWidthResult) != 0;
590 }
591 
isHorizontalReduction(const MachineInstr & MI)592 static bool isHorizontalReduction(const MachineInstr &MI) {
593   const MCInstrDesc &MCID = MI.getDesc();
594   uint64_t Flags = MCID.TSFlags;
595   return (Flags & ARMII::HorizontalReduction) != 0;
596 }
597 
598 // Can this instruction generate a non-zero result when given only zeroed
599 // operands? This allows us to know that, given operands with false bytes
600 // zeroed by masked loads, that the result will also contain zeros in those
601 // bytes.
canGenerateNonZeros(const MachineInstr & MI)602 static bool canGenerateNonZeros(const MachineInstr &MI) {
603 
604   // Check for instructions which can write into a larger element size,
605   // possibly writing into a previous zero'd lane.
606   if (producesDoubleWidthResult(MI))
607     return true;
608 
609   switch (MI.getOpcode()) {
610   default:
611     break;
612   // FIXME: VNEG FP and -0? I think we'll need to handle this once we allow
613   // fp16 -> fp32 vector conversions.
614   // Instructions that perform a NOT will generate 1s from 0s.
615   case ARM::MVE_VMVN:
616   case ARM::MVE_VORN:
617   // Count leading zeros will do just that!
618   case ARM::MVE_VCLZs8:
619   case ARM::MVE_VCLZs16:
620   case ARM::MVE_VCLZs32:
621     return true;
622   }
623   return false;
624 }
625 
626 
627 // Look at its register uses to see if it only can only receive zeros
628 // into its false lanes which would then produce zeros. Also check that
629 // the output register is also defined by an FalseLanesZero instruction
630 // so that if tail-predication happens, the lanes that aren't updated will
631 // still be zeros.
producesFalseLanesZero(MachineInstr & MI,const TargetRegisterClass * QPRs,const ReachingDefAnalysis & RDA,InstSet & FalseLanesZero)632 static bool producesFalseLanesZero(MachineInstr &MI,
633                                    const TargetRegisterClass *QPRs,
634                                    const ReachingDefAnalysis &RDA,
635                                    InstSet &FalseLanesZero) {
636   if (canGenerateNonZeros(MI))
637     return false;
638 
639   bool AllowScalars = isHorizontalReduction(MI);
640   for (auto &MO : MI.operands()) {
641     if (!MO.isReg() || !MO.getReg())
642       continue;
643     if (!isRegInClass(MO, QPRs) && AllowScalars)
644       continue;
645     if (auto *OpDef = RDA.getMIOperand(&MI, MO))
646       if (FalseLanesZero.count(OpDef))
647        continue;
648     return false;
649   }
650   LLVM_DEBUG(dbgs() << "ARM Loops: Always False Zeros: " << MI);
651   return true;
652 }
653 
654 bool
FindValidReduction(InstSet & LiveMIs,InstSet & LiveOutUsers)655 LowOverheadLoop::FindValidReduction(InstSet &LiveMIs, InstSet &LiveOutUsers) {
656   // Also check for reductions where the operation needs to be merging values
657   // from the last and previous loop iterations. This means an instruction
658   // producing a value and a vmov storing the value calculated in the previous
659   // iteration. So we can have two live-out regs, one produced by a vmov and
660   // both being consumed by a vpsel.
661   LLVM_DEBUG(dbgs() << "ARM Loops: Looking for reduction live-outs:\n";
662              for (auto *MI : LiveMIs)
663                dbgs() << " - " << *MI);
664 
665   if (!Preheader)
666     return false;
667 
668   // Expect a vmov, a vadd and a single vpsel user.
669   // TODO: This means we can't currently support multiple reductions in the
670   // loop.
671   if (LiveMIs.size() != 2 || LiveOutUsers.size() != 1)
672     return false;
673 
674   MachineInstr *VPSEL = *LiveOutUsers.begin();
675   if (VPSEL->getOpcode() != ARM::MVE_VPSEL)
676     return false;
677 
678   unsigned VPRIdx = llvm::findFirstVPTPredOperandIdx(*VPSEL) + 1;
679   MachineInstr *Pred = RDA.getMIOperand(VPSEL, VPRIdx);
680   if (!Pred || Pred != VCTP) {
681     LLVM_DEBUG(dbgs() << "ARM Loops: Not using equivalent predicate.\n");
682     return false;
683   }
684 
685   MachineInstr *Reduce = RDA.getMIOperand(VPSEL, 1);
686   if (!Reduce)
687     return false;
688 
689   assert(LiveMIs.count(Reduce) && "Expected MI to be live-out");
690 
691   // TODO: Support more operations than VADD.
692   switch (VCTP->getOpcode()) {
693   default:
694     return false;
695   case ARM::MVE_VCTP8:
696     if (Reduce->getOpcode() != ARM::MVE_VADDi8)
697       return false;
698     break;
699   case ARM::MVE_VCTP16:
700     if (Reduce->getOpcode() != ARM::MVE_VADDi16)
701       return false;
702     break;
703   case ARM::MVE_VCTP32:
704     if (Reduce->getOpcode() != ARM::MVE_VADDi32)
705       return false;
706     break;
707   }
708 
709   // Test that the reduce op is overwriting ones of its operands.
710   if (Reduce->getOperand(0).getReg() != Reduce->getOperand(1).getReg() &&
711       Reduce->getOperand(0).getReg() != Reduce->getOperand(2).getReg()) {
712     LLVM_DEBUG(dbgs() << "ARM Loops: Reducing op isn't overwriting itself.\n");
713     return false;
714   }
715 
716   // Check that the VORR is actually a VMOV.
717   MachineInstr *Copy = RDA.getMIOperand(VPSEL, 2);
718   if (!Copy || Copy->getOpcode() != ARM::MVE_VORR ||
719       !Copy->getOperand(1).isReg() || !Copy->getOperand(2).isReg() ||
720       Copy->getOperand(1).getReg() != Copy->getOperand(2).getReg())
721     return false;
722 
723   assert(LiveMIs.count(Copy) && "Expected MI to be live-out");
724 
725   // Check that the vadd and vmov are only used by each other and the vpsel.
726   SmallPtrSet<MachineInstr*, 2> CopyUsers;
727   RDA.getGlobalUses(Copy, Copy->getOperand(0).getReg(), CopyUsers);
728   if (CopyUsers.size() > 2 || !CopyUsers.count(Reduce)) {
729     LLVM_DEBUG(dbgs() << "ARM Loops: Copy users unsupported.\n");
730     return false;
731   }
732 
733   SmallPtrSet<MachineInstr*, 2> ReduceUsers;
734   RDA.getGlobalUses(Reduce, Reduce->getOperand(0).getReg(), ReduceUsers);
735   if (ReduceUsers.size() > 2 || !ReduceUsers.count(Copy)) {
736     LLVM_DEBUG(dbgs() << "ARM Loops: Reduce users unsupported.\n");
737     return false;
738   }
739 
740   // Then find whether there's an instruction initialising the register that
741   // is storing the reduction.
742   SmallPtrSet<MachineInstr*, 2> Incoming;
743   RDA.getLiveOuts(Preheader, Copy->getOperand(1).getReg(), Incoming);
744   if (Incoming.size() > 1)
745     return false;
746 
747   MachineInstr *Init = Incoming.empty() ? nullptr : *Incoming.begin();
748   LLVM_DEBUG(dbgs() << "ARM Loops: Found a reduction:\n"
749              << " - " << *Copy
750              << " - " << *Reduce
751              << " - " << *VPSEL);
752   Reductions.push_back(std::make_unique<Reduction>(Init, Copy, Reduce, VPSEL));
753   return true;
754 }
755 
ValidateLiveOuts()756 bool LowOverheadLoop::ValidateLiveOuts() {
757   // We want to find out if the tail-predicated version of this loop will
758   // produce the same values as the loop in its original form. For this to
759   // be true, the newly inserted implicit predication must not change the
760   // the (observable) results.
761   // We're doing this because many instructions in the loop will not be
762   // predicated and so the conversion from VPT predication to tail-predication
763   // can result in different values being produced; due to the tail-predication
764   // preventing many instructions from updating their falsely predicated
765   // lanes. This analysis assumes that all the instructions perform lane-wise
766   // operations and don't perform any exchanges.
767   // A masked load, whether through VPT or tail predication, will write zeros
768   // to any of the falsely predicated bytes. So, from the loads, we know that
769   // the false lanes are zeroed and here we're trying to track that those false
770   // lanes remain zero, or where they change, the differences are masked away
771   // by their user(s).
772   // All MVE loads and stores have to be predicated, so we know that any load
773   // operands, or stored results are equivalent already. Other explicitly
774   // predicated instructions will perform the same operation in the original
775   // loop and the tail-predicated form too. Because of this, we can insert
776   // loads, stores and other predicated instructions into our Predicated
777   // set and build from there.
778   const TargetRegisterClass *QPRs = TRI.getRegClass(ARM::MQPRRegClassID);
779   SetVector<MachineInstr *> FalseLanesUnknown;
780   SmallPtrSet<MachineInstr *, 4> FalseLanesZero;
781   SmallPtrSet<MachineInstr *, 4> Predicated;
782   MachineBasicBlock *Header = ML.getHeader();
783 
784   for (auto &MI : *Header) {
785     const MCInstrDesc &MCID = MI.getDesc();
786     uint64_t Flags = MCID.TSFlags;
787     if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
788       continue;
789 
790     if (isVCTP(&MI) || isVPTOpcode(MI.getOpcode()))
791       continue;
792 
793     // Predicated loads will write zeros to the falsely predicated bytes of the
794     // destination register.
795     if (isVectorPredicated(&MI)) {
796       if (MI.mayLoad())
797         FalseLanesZero.insert(&MI);
798       Predicated.insert(&MI);
799       continue;
800     }
801 
802     if (MI.getNumDefs() == 0)
803       continue;
804 
805     if (!producesFalseLanesZero(MI, QPRs, RDA, FalseLanesZero)) {
806       // We require retaining and horizontal operations to operate upon zero'd
807       // false lanes to ensure the conversion doesn't change the output.
808       if (retainsPreviousHalfElement(MI) || isHorizontalReduction(MI))
809         return false;
810       // Otherwise we need to evaluate this instruction later to see whether
811       // unknown false lanes will get masked away by their user(s).
812       FalseLanesUnknown.insert(&MI);
813     } else if (!isHorizontalReduction(MI))
814       FalseLanesZero.insert(&MI);
815   }
816 
817   auto HasPredicatedUsers = [this](MachineInstr *MI, const MachineOperand &MO,
818                               SmallPtrSetImpl<MachineInstr *> &Predicated) {
819     SmallPtrSet<MachineInstr *, 2> Uses;
820     RDA.getGlobalUses(MI, MO.getReg(), Uses);
821     for (auto *Use : Uses) {
822       if (Use != MI && !Predicated.count(Use))
823         return false;
824     }
825     return true;
826   };
827 
828   // Visit the unknowns in reverse so that we can start at the values being
829   // stored and then we can work towards the leaves, hopefully adding more
830   // instructions to Predicated. Successfully terminating the loop means that
831   // all the unknown values have to found to be masked by predicated user(s).
832   // For any unpredicated values, we store them in NonPredicated so that we
833   // can later check whether these form a reduction.
834   SmallPtrSet<MachineInstr*, 2> NonPredicated;
835   for (auto *MI : reverse(FalseLanesUnknown)) {
836     for (auto &MO : MI->operands()) {
837       if (!isRegInClass(MO, QPRs) || !MO.isDef())
838         continue;
839       if (!HasPredicatedUsers(MI, MO, Predicated)) {
840         LLVM_DEBUG(dbgs() << "ARM Loops: Found an unknown def of : "
841                           << TRI.getRegAsmName(MO.getReg()) << " at " << *MI);
842         NonPredicated.insert(MI);
843         continue;
844       }
845     }
846     // Any unknown false lanes have been masked away by the user(s).
847     Predicated.insert(MI);
848   }
849 
850   SmallPtrSet<MachineInstr *, 2> LiveOutMIs;
851   SmallPtrSet<MachineInstr*, 2> LiveOutUsers;
852   SmallVector<MachineBasicBlock *, 2> ExitBlocks;
853   ML.getExitBlocks(ExitBlocks);
854   assert(ML.getNumBlocks() == 1 && "Expected single block loop!");
855   assert(ExitBlocks.size() == 1 && "Expected a single exit block");
856   MachineBasicBlock *ExitBB = ExitBlocks.front();
857   for (const MachineBasicBlock::RegisterMaskPair &RegMask : ExitBB->liveins()) {
858     // Check Q-regs that are live in the exit blocks. We don't collect scalars
859     // because they won't be affected by lane predication.
860     if (QPRs->contains(RegMask.PhysReg)) {
861       if (auto *MI = RDA.getLocalLiveOutMIDef(Header, RegMask.PhysReg))
862         LiveOutMIs.insert(MI);
863       RDA.getLiveInUses(ExitBB, RegMask.PhysReg, LiveOutUsers);
864     }
865   }
866 
867   // If we have any non-predicated live-outs, they need to be part of a
868   // reduction that we can fixup later. The reduction that the form of an
869   // operation that uses its previous values through a vmov and then a vpsel
870   // resides in the exit blocks to select the final bytes from n and n-1
871   // iterations.
872   if (!NonPredicated.empty() &&
873       !FindValidReduction(NonPredicated, LiveOutUsers))
874     return false;
875 
876   // We've already validated that any VPT predication within the loop will be
877   // equivalent when we perform the predication transformation; so we know that
878   // any VPT predicated instruction is predicated upon VCTP. Any live-out
879   // instruction needs to be predicated, so check this here. The instructions
880   // in NonPredicated have been found to be a reduction that we can ensure its
881   // legality.
882   for (auto *MI : LiveOutMIs)
883     if (!isVectorPredicated(MI) && !NonPredicated.count(MI))
884       return false;
885 
886   return true;
887 }
888 
CheckLegality(ARMBasicBlockUtils * BBUtils)889 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils) {
890   if (Revert)
891     return;
892 
893   if (!End->getOperand(1).isMBB())
894     report_fatal_error("Expected LoopEnd to target basic block");
895 
896   // TODO Maybe there's cases where the target doesn't have to be the header,
897   // but for now be safe and revert.
898   if (End->getOperand(1).getMBB() != ML.getHeader()) {
899     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
900     Revert = true;
901     return;
902   }
903 
904   // The WLS and LE instructions have 12-bits for the label offset. WLS
905   // requires a positive offset, while LE uses negative.
906   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML.getHeader()) ||
907       !BBUtils->isBBInRange(End, ML.getHeader(), 4094)) {
908     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
909     Revert = true;
910     return;
911   }
912 
913   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
914       (BBUtils->getOffsetOf(Start) >
915        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
916        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
917     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
918     Revert = true;
919     return;
920   }
921 
922   InsertPt = Revert ? nullptr : isSafeToDefineLR();
923   if (!InsertPt) {
924     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
925     Revert = true;
926     return;
927   } else
928     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
929 
930   if (!IsTailPredicationLegal()) {
931     LLVM_DEBUG(if (!VCTP)
932                  dbgs() << "ARM Loops: Didn't find a VCTP instruction.\n";
933                dbgs() << "ARM Loops: Tail-predication is not valid.\n");
934     return;
935   }
936 
937   assert(ML.getBlocks().size() == 1 &&
938          "Shouldn't be processing a loop with more than one block");
939   CannotTailPredicate = !ValidateTailPredicate(InsertPt);
940   LLVM_DEBUG(if (CannotTailPredicate)
941              dbgs() << "ARM Loops: Couldn't validate tail predicate.\n");
942 }
943 
ValidateMVEInst(MachineInstr * MI)944 bool LowOverheadLoop::ValidateMVEInst(MachineInstr* MI) {
945   if (CannotTailPredicate)
946     return false;
947 
948   if (isVCTP(MI)) {
949     // If we find another VCTP, check whether it uses the same value as the main VCTP.
950     // If it does, store it in the SecondaryVCTPs set, else refuse it.
951     if (VCTP) {
952       if (!VCTP->getOperand(1).isIdenticalTo(MI->getOperand(1)) ||
953           !RDA.hasSameReachingDef(VCTP, MI, MI->getOperand(1).getReg())) {
954         LLVM_DEBUG(dbgs() << "ARM Loops: Found VCTP with a different reaching "
955                              "definition from the main VCTP");
956         return false;
957       }
958       LLVM_DEBUG(dbgs() << "ARM Loops: Found secondary VCTP: " << *MI);
959       SecondaryVCTPs.insert(MI);
960     } else {
961       LLVM_DEBUG(dbgs() << "ARM Loops: Found 'main' VCTP: " << *MI);
962       VCTP = MI;
963     }
964   } else if (isVPTOpcode(MI->getOpcode())) {
965     if (MI->getOpcode() != ARM::MVE_VPST) {
966       assert(MI->findRegisterDefOperandIdx(ARM::VPR) != -1 &&
967              "VPT does not implicitly define VPR?!");
968       CurrentPredicate.insert(MI);
969     }
970 
971     VPTBlocks.emplace_back(MI, CurrentPredicate);
972     CurrentBlock = &VPTBlocks.back();
973     return true;
974   } else if (MI->getOpcode() == ARM::MVE_VPSEL ||
975              MI->getOpcode() == ARM::MVE_VPNOT) {
976     // TODO: Allow VPSEL and VPNOT, we currently cannot because:
977     // 1) It will use the VPR as a predicate operand, but doesn't have to be
978     //    instead a VPT block, which means we can assert while building up
979     //    the VPT block because we don't find another VPT or VPST to being a new
980     //    one.
981     // 2) VPSEL still requires a VPR operand even after tail predicating,
982     //    which means we can't remove it unless there is another
983     //    instruction, such as vcmp, that can provide the VPR def.
984     return false;
985   }
986 
987   bool IsUse = false;
988   bool IsDef = false;
989   const MCInstrDesc &MCID = MI->getDesc();
990   for (int i = MI->getNumOperands() - 1; i >= 0; --i) {
991     const MachineOperand &MO = MI->getOperand(i);
992     if (!MO.isReg() || MO.getReg() != ARM::VPR)
993       continue;
994 
995     if (MO.isDef()) {
996       CurrentPredicate.insert(MI);
997       IsDef = true;
998     } else if (ARM::isVpred(MCID.OpInfo[i].OperandType)) {
999       CurrentBlock->addInst(MI, CurrentPredicate);
1000       IsUse = true;
1001     } else {
1002       LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
1003       return false;
1004     }
1005   }
1006 
1007   // If we find a vpr def that is not already predicated on the vctp, we've
1008   // got disjoint predicates that may not be equivalent when we do the
1009   // conversion.
1010   if (IsDef && !IsUse && VCTP && !isVCTP(MI)) {
1011     LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI);
1012     return false;
1013   }
1014 
1015   uint64_t Flags = MCID.TSFlags;
1016   if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
1017     return true;
1018 
1019   // If we find an instruction that has been marked as not valid for tail
1020   // predication, only allow the instruction if it's contained within a valid
1021   // VPT block.
1022   if ((Flags & ARMII::ValidForTailPredication) == 0 && !IsUse) {
1023     LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
1024     return false;
1025   }
1026 
1027   // If the instruction is already explicitly predicated, then the conversion
1028   // will be fine, but ensure that all memory operations are predicated.
1029   return !IsUse && MI->mayLoadOrStore() ? false : true;
1030 }
1031 
runOnMachineFunction(MachineFunction & mf)1032 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
1033   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
1034   if (!ST.hasLOB())
1035     return false;
1036 
1037   MF = &mf;
1038   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
1039 
1040   MLI = &getAnalysis<MachineLoopInfo>();
1041   RDA = &getAnalysis<ReachingDefAnalysis>();
1042   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
1043   MRI = &MF->getRegInfo();
1044   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
1045   TRI = ST.getRegisterInfo();
1046   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
1047   BBUtils->computeAllBlockSizes();
1048   BBUtils->adjustBBOffsetsAfter(&MF->front());
1049 
1050   bool Changed = false;
1051   for (auto ML : *MLI) {
1052     if (!ML->getParentLoop())
1053       Changed |= ProcessLoop(ML);
1054   }
1055   Changed |= RevertNonLoops();
1056   return Changed;
1057 }
1058 
ProcessLoop(MachineLoop * ML)1059 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
1060 
1061   bool Changed = false;
1062 
1063   // Process inner loops first.
1064   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
1065     Changed |= ProcessLoop(*I);
1066 
1067   LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
1068              if (auto *Preheader = ML->getLoopPreheader())
1069                dbgs() << " - " << Preheader->getName() << "\n";
1070              else if (auto *Preheader = MLI->findLoopPreheader(ML))
1071                dbgs() << " - " << Preheader->getName() << "\n";
1072              else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
1073                dbgs() << " - " << Preheader->getName() << "\n";
1074              for (auto *MBB : ML->getBlocks())
1075                dbgs() << " - " << MBB->getName() << "\n";
1076             );
1077 
1078   // Search the given block for a loop start instruction. If one isn't found,
1079   // and there's only one predecessor block, search that one too.
1080   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
1081     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
1082     for (auto &MI : *MBB) {
1083       if (isLoopStart(MI))
1084         return &MI;
1085     }
1086     if (MBB->pred_size() == 1)
1087       return SearchForStart(*MBB->pred_begin());
1088     return nullptr;
1089   };
1090 
1091   LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI, *TII);
1092   // Search the preheader for the start intrinsic.
1093   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
1094   // with potentially multiple set.loop.iterations, so we need to enable this.
1095   if (LoLoop.Preheader)
1096     LoLoop.Start = SearchForStart(LoLoop.Preheader);
1097   else
1098     return false;
1099 
1100   // Find the low-overhead loop components and decide whether or not to fall
1101   // back to a normal loop. Also look for a vctp instructions and decide
1102   // whether we can convert that predicate using tail predication.
1103   for (auto *MBB : reverse(ML->getBlocks())) {
1104     for (auto &MI : *MBB) {
1105       if (MI.isDebugValue())
1106         continue;
1107       else if (MI.getOpcode() == ARM::t2LoopDec)
1108         LoLoop.Dec = &MI;
1109       else if (MI.getOpcode() == ARM::t2LoopEnd)
1110         LoLoop.End = &MI;
1111       else if (isLoopStart(MI))
1112         LoLoop.Start = &MI;
1113       else if (MI.getDesc().isCall()) {
1114         // TODO: Though the call will require LE to execute again, does this
1115         // mean we should revert? Always executing LE hopefully should be
1116         // faster than performing a sub,cmp,br or even subs,br.
1117         LoLoop.Revert = true;
1118         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
1119       } else {
1120         // Record VPR defs and build up their corresponding vpt blocks.
1121         // Check we know how to tail predicate any mve instructions.
1122         LoLoop.AnalyseMVEInst(&MI);
1123       }
1124     }
1125   }
1126 
1127   LLVM_DEBUG(LoLoop.dump());
1128   if (!LoLoop.FoundAllComponents()) {
1129     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find loop start, update, end\n");
1130     return false;
1131   }
1132 
1133   // Check that the only instruction using LoopDec is LoopEnd.
1134   // TODO: Check for copy chains that really have no effect.
1135   SmallPtrSet<MachineInstr*, 2> Uses;
1136   RDA->getReachingLocalUses(LoLoop.Dec, ARM::LR, Uses);
1137   if (Uses.size() > 1 || !Uses.count(LoLoop.End)) {
1138     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to remove LoopDec.\n");
1139     LoLoop.Revert = true;
1140   }
1141   LoLoop.CheckLegality(BBUtils.get());
1142   Expand(LoLoop);
1143   return true;
1144 }
1145 
1146 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
1147 // beq that branches to the exit branch.
1148 // TODO: We could also try to generate a cbz if the value in LR is also in
1149 // another low register.
RevertWhile(MachineInstr * MI) const1150 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
1151   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
1152   MachineBasicBlock *MBB = MI->getParent();
1153   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1154                                     TII->get(ARM::t2CMPri));
1155   MIB.add(MI->getOperand(0));
1156   MIB.addImm(0);
1157   MIB.addImm(ARMCC::AL);
1158   MIB.addReg(ARM::NoRegister);
1159 
1160   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
1161   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
1162     ARM::tBcc : ARM::t2Bcc;
1163 
1164   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
1165   MIB.add(MI->getOperand(1));   // branch target
1166   MIB.addImm(ARMCC::EQ);        // condition code
1167   MIB.addReg(ARM::CPSR);
1168   MI->eraseFromParent();
1169 }
1170 
RevertLoopDec(MachineInstr * MI) const1171 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI) const {
1172   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
1173   MachineBasicBlock *MBB = MI->getParent();
1174   SmallPtrSet<MachineInstr*, 1> Ignore;
1175   for (auto I = MachineBasicBlock::iterator(MI), E = MBB->end(); I != E; ++I) {
1176     if (I->getOpcode() == ARM::t2LoopEnd) {
1177       Ignore.insert(&*I);
1178       break;
1179     }
1180   }
1181 
1182   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
1183   bool SetFlags = RDA->isSafeToDefRegAt(MI, ARM::CPSR, Ignore);
1184 
1185   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1186                                     TII->get(ARM::t2SUBri));
1187   MIB.addDef(ARM::LR);
1188   MIB.add(MI->getOperand(1));
1189   MIB.add(MI->getOperand(2));
1190   MIB.addImm(ARMCC::AL);
1191   MIB.addReg(0);
1192 
1193   if (SetFlags) {
1194     MIB.addReg(ARM::CPSR);
1195     MIB->getOperand(5).setIsDef(true);
1196   } else
1197     MIB.addReg(0);
1198 
1199   MI->eraseFromParent();
1200   return SetFlags;
1201 }
1202 
1203 // Generate a subs, or sub and cmp, and a branch instead of an LE.
RevertLoopEnd(MachineInstr * MI,bool SkipCmp) const1204 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
1205   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
1206 
1207   MachineBasicBlock *MBB = MI->getParent();
1208   // Create cmp
1209   if (!SkipCmp) {
1210     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1211                                       TII->get(ARM::t2CMPri));
1212     MIB.addReg(ARM::LR);
1213     MIB.addImm(0);
1214     MIB.addImm(ARMCC::AL);
1215     MIB.addReg(ARM::NoRegister);
1216   }
1217 
1218   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
1219   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
1220     ARM::tBcc : ARM::t2Bcc;
1221 
1222   // Create bne
1223   MachineInstrBuilder MIB =
1224     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
1225   MIB.add(MI->getOperand(1));   // branch target
1226   MIB.addImm(ARMCC::NE);        // condition code
1227   MIB.addReg(ARM::CPSR);
1228   MI->eraseFromParent();
1229 }
1230 
1231 // Perform dead code elimation on the loop iteration count setup expression.
1232 // If we are tail-predicating, the number of elements to be processed is the
1233 // operand of the VCTP instruction in the vector body, see getCount(), which is
1234 // register $r3 in this example:
1235 //
1236 //   $lr = big-itercount-expression
1237 //   ..
1238 //   t2DoLoopStart renamable $lr
1239 //   vector.body:
1240 //     ..
1241 //     $vpr = MVE_VCTP32 renamable $r3
1242 //     renamable $lr = t2LoopDec killed renamable $lr, 1
1243 //     t2LoopEnd renamable $lr, %vector.body
1244 //     tB %end
1245 //
1246 // What we would like achieve here is to replace the do-loop start pseudo
1247 // instruction t2DoLoopStart with:
1248 //
1249 //    $lr = MVE_DLSTP_32 killed renamable $r3
1250 //
1251 // Thus, $r3 which defines the number of elements, is written to $lr,
1252 // and then we want to delete the whole chain that used to define $lr,
1253 // see the comment below how this chain could look like.
1254 //
IterationCountDCE(LowOverheadLoop & LoLoop)1255 void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) {
1256   if (!LoLoop.IsTailPredicationLegal())
1257     return;
1258 
1259   LLVM_DEBUG(dbgs() << "ARM Loops: Trying DCE on loop iteration count.\n");
1260 
1261   MachineInstr *Def = RDA->getMIOperand(LoLoop.Start, 0);
1262   if (!Def) {
1263     LLVM_DEBUG(dbgs() << "ARM Loops: Couldn't find iteration count.\n");
1264     return;
1265   }
1266 
1267   // Collect and remove the users of iteration count.
1268   SmallPtrSet<MachineInstr*, 4> Killed  = { LoLoop.Start, LoLoop.Dec,
1269                                             LoLoop.End, LoLoop.InsertPt };
1270   SmallPtrSet<MachineInstr*, 2> Remove;
1271   if (RDA->isSafeToRemove(Def, Remove, Killed))
1272     LoLoop.ToRemove.insert(Remove.begin(), Remove.end());
1273   else {
1274     LLVM_DEBUG(dbgs() << "ARM Loops: Unsafe to remove loop iteration count.\n");
1275     return;
1276   }
1277 
1278   // Collect the dead code and the MBBs in which they reside.
1279   RDA->collectKilledOperands(Def, Killed);
1280   SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks;
1281   for (auto *MI : Killed)
1282     BasicBlocks.insert(MI->getParent());
1283 
1284   // Collect IT blocks in all affected basic blocks.
1285   std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks;
1286   for (auto *MBB : BasicBlocks) {
1287     for (auto &MI : *MBB) {
1288       if (MI.getOpcode() != ARM::t2IT)
1289         continue;
1290       RDA->getReachingLocalUses(&MI, ARM::ITSTATE, ITBlocks[&MI]);
1291     }
1292   }
1293 
1294   // If we're removing all of the instructions within an IT block, then
1295   // also remove the IT instruction.
1296   SmallPtrSet<MachineInstr*, 2> ModifiedITs;
1297   for (auto *MI : Killed) {
1298     if (MachineOperand *MO = MI->findRegisterUseOperand(ARM::ITSTATE)) {
1299       MachineInstr *IT = RDA->getMIOperand(MI, *MO);
1300       auto &CurrentBlock = ITBlocks[IT];
1301       CurrentBlock.erase(MI);
1302       if (CurrentBlock.empty())
1303         ModifiedITs.erase(IT);
1304       else
1305         ModifiedITs.insert(IT);
1306     }
1307   }
1308 
1309   // Delete the killed instructions only if we don't have any IT blocks that
1310   // need to be modified because we need to fixup the mask.
1311   // TODO: Handle cases where IT blocks are modified.
1312   if (ModifiedITs.empty()) {
1313     LLVM_DEBUG(dbgs() << "ARM Loops: Will remove iteration count:\n";
1314                for (auto *MI : Killed)
1315                  dbgs() << " - " << *MI);
1316     LoLoop.ToRemove.insert(Killed.begin(), Killed.end());
1317   } else
1318     LLVM_DEBUG(dbgs() << "ARM Loops: Would need to modify IT block(s).\n");
1319 }
1320 
ExpandLoopStart(LowOverheadLoop & LoLoop)1321 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
1322   LLVM_DEBUG(dbgs() << "ARM Loops: Expanding LoopStart.\n");
1323   // When using tail-predication, try to delete the dead code that was used to
1324   // calculate the number of loop iterations.
1325   IterationCountDCE(LoLoop);
1326 
1327   MachineInstr *InsertPt = LoLoop.InsertPt;
1328   MachineInstr *Start = LoLoop.Start;
1329   MachineBasicBlock *MBB = InsertPt->getParent();
1330   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
1331   unsigned Opc = LoLoop.getStartOpcode();
1332   MachineOperand &Count = LoLoop.getCount();
1333 
1334   MachineInstrBuilder MIB =
1335     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
1336 
1337   MIB.addDef(ARM::LR);
1338   MIB.add(Count);
1339   if (!IsDo)
1340     MIB.add(Start->getOperand(1));
1341 
1342   // If we're inserting at a mov lr, then remove it as it's redundant.
1343   if (InsertPt != Start)
1344     LoLoop.ToRemove.insert(InsertPt);
1345   LoLoop.ToRemove.insert(Start);
1346   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
1347   return &*MIB;
1348 }
1349 
FixupReductions(LowOverheadLoop & LoLoop) const1350 void ARMLowOverheadLoops::FixupReductions(LowOverheadLoop &LoLoop) const {
1351   LLVM_DEBUG(dbgs() << "ARM Loops: Fixing up reduction(s).\n");
1352   auto BuildMov = [this](MachineInstr &InsertPt, Register To, Register From) {
1353     MachineBasicBlock *MBB = InsertPt.getParent();
1354     MachineInstrBuilder MIB =
1355       BuildMI(*MBB, &InsertPt, InsertPt.getDebugLoc(), TII->get(ARM::MVE_VORR));
1356     MIB.addDef(To);
1357     MIB.addReg(From);
1358     MIB.addReg(From);
1359     MIB.addImm(0);
1360     MIB.addReg(0);
1361     MIB.addReg(To);
1362     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted VMOV: " << *MIB);
1363   };
1364 
1365   for (auto &Reduction : LoLoop.Reductions) {
1366     MachineInstr &Copy = Reduction->Copy;
1367     MachineInstr &Reduce = Reduction->Reduce;
1368     Register DestReg = Copy.getOperand(0).getReg();
1369 
1370     // Change the initialiser if present
1371     if (Reduction->Init) {
1372       MachineInstr *Init = Reduction->Init;
1373 
1374       for (unsigned i = 0; i < Init->getNumOperands(); ++i) {
1375         MachineOperand &MO = Init->getOperand(i);
1376         if (MO.isReg() && MO.isUse() && MO.isTied() &&
1377             Init->findTiedOperandIdx(i) == 0)
1378           Init->getOperand(i).setReg(DestReg);
1379       }
1380       Init->getOperand(0).setReg(DestReg);
1381       LLVM_DEBUG(dbgs() << "ARM Loops: Changed init regs: " << *Init);
1382     } else
1383       BuildMov(LoLoop.Preheader->instr_back(), DestReg, Copy.getOperand(1).getReg());
1384 
1385     // Change the reducing op to write to the register that is used to copy
1386     // its value on the next iteration. Also update the tied-def operand.
1387     Reduce.getOperand(0).setReg(DestReg);
1388     Reduce.getOperand(5).setReg(DestReg);
1389     LLVM_DEBUG(dbgs() << "ARM Loops: Changed reduction regs: " << Reduce);
1390 
1391     // Instead of a vpsel, just copy the register into the necessary one.
1392     MachineInstr &VPSEL = Reduction->VPSEL;
1393     if (VPSEL.getOperand(0).getReg() != DestReg)
1394       BuildMov(VPSEL, VPSEL.getOperand(0).getReg(), DestReg);
1395 
1396     // Remove the unnecessary instructions.
1397     LLVM_DEBUG(dbgs() << "ARM Loops: Removing:\n"
1398                << " - " << Copy
1399                << " - " << VPSEL << "\n");
1400     Copy.eraseFromParent();
1401     VPSEL.eraseFromParent();
1402   }
1403 }
1404 
ConvertVPTBlocks(LowOverheadLoop & LoLoop)1405 void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
1406   auto RemovePredicate = [](MachineInstr *MI) {
1407     LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
1408     if (int PIdx = llvm::findFirstVPTPredOperandIdx(*MI)) {
1409       assert(MI->getOperand(PIdx).getImm() == ARMVCC::Then &&
1410              "Expected Then predicate!");
1411       MI->getOperand(PIdx).setImm(ARMVCC::None);
1412       MI->getOperand(PIdx+1).setReg(0);
1413     } else
1414       llvm_unreachable("trying to unpredicate a non-predicated instruction");
1415   };
1416 
1417   // There are a few scenarios which we have to fix up:
1418   // 1. VPT Blocks with non-uniform predicates:
1419   //    - a. When the divergent instruction is a vctp
1420   //    - b. When the block uses a vpst, and is only predicated on the vctp
1421   //    - c. When the block uses a vpt and (optionally) contains one or more
1422   //         vctp.
1423   // 2. VPT Blocks with uniform predicates:
1424   //    - a. The block uses a vpst, and is only predicated on the vctp
1425   for (auto &Block : LoLoop.getVPTBlocks()) {
1426     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
1427     if (Block.HasNonUniformPredicate()) {
1428       PredicatedMI *Divergent = Block.getDivergent();
1429       if (isVCTP(Divergent->MI)) {
1430         // The vctp will be removed, so the block mask of the vp(s)t will need
1431         // to be recomputed.
1432         LoLoop.BlockMasksToRecompute.insert(Block.getPredicateThen());
1433       } else if (Block.isVPST() && Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
1434         // The VPT block has a non-uniform predicate but it uses a vpst and its
1435         // entry is guarded only by a vctp, which means we:
1436         // - Need to remove the original vpst.
1437         // - Then need to unpredicate any following instructions, until
1438         //   we come across the divergent vpr def.
1439         // - Insert a new vpst to predicate the instruction(s) that following
1440         //   the divergent vpr def.
1441         // TODO: We could be producing more VPT blocks than necessary and could
1442         // fold the newly created one into a proceeding one.
1443         for (auto I = ++MachineBasicBlock::iterator(Block.getPredicateThen()),
1444              E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I)
1445           RemovePredicate(&*I);
1446 
1447         unsigned Size = 0;
1448         auto E = MachineBasicBlock::reverse_iterator(Divergent->MI);
1449         auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI);
1450         MachineInstr *InsertAt = nullptr;
1451         while (I != E) {
1452           InsertAt = &*I;
1453           ++Size;
1454           ++I;
1455         }
1456         // Create a VPST (with a null mask for now, we'll recompute it later).
1457         MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt,
1458                                           InsertAt->getDebugLoc(),
1459                                           TII->get(ARM::MVE_VPST));
1460         MIB.addImm(0);
1461         LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getPredicateThen());
1462         LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
1463         LoLoop.ToRemove.insert(Block.getPredicateThen());
1464         LoLoop.BlockMasksToRecompute.insert(MIB.getInstr());
1465       }
1466       // Else, if the block uses a vpt, iterate over the block, removing the
1467       // extra VCTPs it may contain.
1468       else if (Block.isVPT()) {
1469         bool RemovedVCTP = false;
1470         for (PredicatedMI &Elt : Block.getInsts()) {
1471           MachineInstr *MI = Elt.MI;
1472           if (isVCTP(MI)) {
1473             LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *MI);
1474             LoLoop.ToRemove.insert(MI);
1475             RemovedVCTP = true;
1476             continue;
1477           }
1478         }
1479         if (RemovedVCTP)
1480           LoLoop.BlockMasksToRecompute.insert(Block.getPredicateThen());
1481       }
1482     } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP) && Block.isVPST()) {
1483       // A vpt block starting with VPST, is only predicated upon vctp and has no
1484       // internal vpr defs:
1485       // - Remove vpst.
1486       // - Unpredicate the remaining instructions.
1487       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getPredicateThen());
1488       LoLoop.ToRemove.insert(Block.getPredicateThen());
1489       for (auto &PredMI : Insts)
1490         RemovePredicate(PredMI.MI);
1491     }
1492   }
1493   LLVM_DEBUG(dbgs() << "ARM Loops: Removing remaining VCTPs...\n");
1494   // Remove the "main" VCTP
1495   LoLoop.ToRemove.insert(LoLoop.VCTP);
1496   LLVM_DEBUG(dbgs() << "    " << *LoLoop.VCTP);
1497   // Remove remaining secondary VCTPs
1498   for (MachineInstr *VCTP : LoLoop.SecondaryVCTPs) {
1499     // All VCTPs that aren't marked for removal yet should be unpredicated ones.
1500     // The predicated ones should have already been marked for removal when
1501     // visiting the VPT blocks.
1502     if (LoLoop.ToRemove.insert(VCTP).second) {
1503       assert(getVPTInstrPredicate(*VCTP) == ARMVCC::None &&
1504              "Removing Predicated VCTP without updating the block mask!");
1505       LLVM_DEBUG(dbgs() << "    " << *VCTP);
1506     }
1507   }
1508 }
1509 
Expand(LowOverheadLoop & LoLoop)1510 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
1511 
1512   // Combine the LoopDec and LoopEnd instructions into LE(TP).
1513   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
1514     MachineInstr *End = LoLoop.End;
1515     MachineBasicBlock *MBB = End->getParent();
1516     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
1517       ARM::MVE_LETP : ARM::t2LEUpdate;
1518     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
1519                                       TII->get(Opc));
1520     MIB.addDef(ARM::LR);
1521     MIB.add(End->getOperand(0));
1522     MIB.add(End->getOperand(1));
1523     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
1524     LoLoop.ToRemove.insert(LoLoop.Dec);
1525     LoLoop.ToRemove.insert(End);
1526     return &*MIB;
1527   };
1528 
1529   // TODO: We should be able to automatically remove these branches before we
1530   // get here - probably by teaching analyzeBranch about the pseudo
1531   // instructions.
1532   // If there is an unconditional branch, after I, that just branches to the
1533   // next block, remove it.
1534   auto RemoveDeadBranch = [](MachineInstr *I) {
1535     MachineBasicBlock *BB = I->getParent();
1536     MachineInstr *Terminator = &BB->instr_back();
1537     if (Terminator->isUnconditionalBranch() && I != Terminator) {
1538       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
1539       if (BB->isLayoutSuccessor(Succ)) {
1540         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
1541         Terminator->eraseFromParent();
1542       }
1543     }
1544   };
1545 
1546   if (LoLoop.Revert) {
1547     if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
1548       RevertWhile(LoLoop.Start);
1549     else
1550       LoLoop.Start->eraseFromParent();
1551     bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec);
1552     RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
1553   } else {
1554     LoLoop.Start = ExpandLoopStart(LoLoop);
1555     RemoveDeadBranch(LoLoop.Start);
1556     LoLoop.End = ExpandLoopEnd(LoLoop);
1557     RemoveDeadBranch(LoLoop.End);
1558     if (LoLoop.IsTailPredicationLegal()) {
1559       ConvertVPTBlocks(LoLoop);
1560       FixupReductions(LoLoop);
1561     }
1562     for (auto *I : LoLoop.ToRemove) {
1563       LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I);
1564       I->eraseFromParent();
1565     }
1566     for (auto *I : LoLoop.BlockMasksToRecompute) {
1567       LLVM_DEBUG(dbgs() << "ARM Loops: Recomputing VPT/VPST Block Mask: " << *I);
1568       recomputeVPTBlockMask(*I);
1569       LLVM_DEBUG(dbgs() << "           ... done: " << *I);
1570     }
1571   }
1572 
1573   PostOrderLoopTraversal DFS(LoLoop.ML, *MLI);
1574   DFS.ProcessLoop();
1575   const SmallVectorImpl<MachineBasicBlock*> &PostOrder = DFS.getOrder();
1576   for (auto *MBB : PostOrder) {
1577     recomputeLiveIns(*MBB);
1578     // FIXME: For some reason, the live-in print order is non-deterministic for
1579     // our tests and I can't out why... So just sort them.
1580     MBB->sortUniqueLiveIns();
1581   }
1582 
1583   for (auto *MBB : reverse(PostOrder))
1584     recomputeLivenessFlags(*MBB);
1585 
1586   // We've moved, removed and inserted new instructions, so update RDA.
1587   RDA->reset();
1588 }
1589 
RevertNonLoops()1590 bool ARMLowOverheadLoops::RevertNonLoops() {
1591   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
1592   bool Changed = false;
1593 
1594   for (auto &MBB : *MF) {
1595     SmallVector<MachineInstr*, 4> Starts;
1596     SmallVector<MachineInstr*, 4> Decs;
1597     SmallVector<MachineInstr*, 4> Ends;
1598 
1599     for (auto &I : MBB) {
1600       if (isLoopStart(I))
1601         Starts.push_back(&I);
1602       else if (I.getOpcode() == ARM::t2LoopDec)
1603         Decs.push_back(&I);
1604       else if (I.getOpcode() == ARM::t2LoopEnd)
1605         Ends.push_back(&I);
1606     }
1607 
1608     if (Starts.empty() && Decs.empty() && Ends.empty())
1609       continue;
1610 
1611     Changed = true;
1612 
1613     for (auto *Start : Starts) {
1614       if (Start->getOpcode() == ARM::t2WhileLoopStart)
1615         RevertWhile(Start);
1616       else
1617         Start->eraseFromParent();
1618     }
1619     for (auto *Dec : Decs)
1620       RevertLoopDec(Dec);
1621 
1622     for (auto *End : Ends)
1623       RevertLoopEnd(End);
1624   }
1625   return Changed;
1626 }
1627 
createARMLowOverheadLoopsPass()1628 FunctionPass *llvm::createARMLowOverheadLoopsPass() {
1629   return new ARMLowOverheadLoops();
1630 }
1631