1 //===- AArch64LoopIdiomTransform.cpp - Loop idiom recognition -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass recognizes certain loop idioms and transforms them into more
// optimized versions of the same loop. In cases where this happens, it can
// be a significant performance win.
12 //
13 // We currently only recognize one loop that finds the first mismatched byte
14 // in an array and returns the index, i.e. something like:
15 //
16 //  while (++i != n) {
17 //    if (a[i] != b[i])
18 //      break;
19 //  }
20 //
21 // In this example we can actually vectorize the loop despite the early exit,
22 // although the loop vectorizer does not support it. It requires some extra
23 // checks to deal with the possibility of faulting loads when crossing page
24 // boundaries. However, even with these checks it is still profitable to do the
25 // transformation.
26 //
27 //===----------------------------------------------------------------------===//
28 //
29 // TODO List:
30 //
31 // * Add support for the inverse case where we scan for a matching element.
32 // * Permit 64-bit induction variable types.
33 // * Recognize loops that increment the IV *after* comparing bytes.
34 // * Allow 32-bit sign-extends of the IV used by the GEP.
35 //
36 //===----------------------------------------------------------------------===//
37 
38 #include "AArch64LoopIdiomTransform.h"
39 #include "llvm/Analysis/DomTreeUpdater.h"
40 #include "llvm/Analysis/LoopPass.h"
41 #include "llvm/Analysis/TargetTransformInfo.h"
42 #include "llvm/IR/Dominators.h"
43 #include "llvm/IR/IRBuilder.h"
44 #include "llvm/IR/Intrinsics.h"
45 #include "llvm/IR/MDBuilder.h"
46 #include "llvm/IR/PatternMatch.h"
47 #include "llvm/InitializePasses.h"
48 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
49 
50 using namespace llvm;
51 using namespace PatternMatch;
52 
53 #define DEBUG_TYPE "aarch64-loop-idiom-transform"
54 
55 static cl::opt<bool>
56     DisableAll("disable-aarch64-lit-all", cl::Hidden, cl::init(false),
57                cl::desc("Disable AArch64 Loop Idiom Transform Pass."));
58 
59 static cl::opt<bool> DisableByteCmp(
60     "disable-aarch64-lit-bytecmp", cl::Hidden, cl::init(false),
61     cl::desc("Proceed with AArch64 Loop Idiom Transform Pass, but do "
62              "not convert byte-compare loop(s)."));
63 
64 static cl::opt<bool> VerifyLoops(
65     "aarch64-lit-verify", cl::Hidden, cl::init(false),
66     cl::desc("Verify loops generated AArch64 Loop Idiom Transform Pass."));
67 
namespace llvm {

// Forward declarations of the entry points defined later in this file; the
// actual registration happens via the INITIALIZE_PASS machinery below.
void initializeAArch64LoopIdiomTransformLegacyPassPass(PassRegistry &);
Pass *createAArch64LoopIdiomTransformPass();

} // end namespace llvm
74 
75 namespace {
76 
// Shared implementation used by both the legacy and new pass-manager
// wrappers. Holds the analyses required to recognize and rewrite the loop.
class AArch64LoopIdiomTransform {
  Loop *CurLoop = nullptr;        // Loop currently being analyzed.
  DominatorTree *DT;              // Kept up to date as new blocks are created.
  LoopInfo *LI;                   // Updated with the new SVE/scalar loops.
  const TargetTransformInfo *TTI; // Queried for SVE and min-page-size support.
  const DataLayout *DL;

public:
  explicit AArch64LoopIdiomTransform(DominatorTree *DT, LoopInfo *LI,
                                     const TargetTransformInfo *TTI,
                                     const DataLayout *DL)
      : DT(DT), LI(LI), TTI(TTI), DL(DL) {}

  /// Attempt to recognize and transform idioms in \p L.
  /// \returns true if the IR was changed.
  bool run(Loop *L);

private:
  /// \name Countable Loop Idiom Handling
  /// @{

  bool runOnCountableLoop();
  bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
                      SmallVectorImpl<BasicBlock *> &ExitBlocks);

  /// Match the byte-compare loop described in the file header comment and
  /// transform it if all structural checks pass. \returns true on success.
  bool recognizeByteCompare();
  /// Emit the runtime page checks, the predicated SVE loop and the scalar
  /// fallback loop that together compute the index of the first mismatching
  /// byte, and return that index value (of the original 32-bit index type).
  Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
                            GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                            Instruction *Index, Value *Start, Value *MaxLen);
  /// Rewire the recognized loop so its users consume the result of
  /// expandFindMismatch instead of the scalar byte-compare loop.
  void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                            PHINode *IndPhi, Value *MaxLen, Instruction *Index,
                            Value *Start, bool IncIdx, BasicBlock *FoundBB,
                            BasicBlock *EndBB);
  /// @}
};
110 
// Legacy pass-manager wrapper: collects the required analyses and forwards
// each loop to AArch64LoopIdiomTransform.
class AArch64LoopIdiomTransformLegacyPass : public LoopPass {
public:
  static char ID;

  explicit AArch64LoopIdiomTransformLegacyPass() : LoopPass(ID) {
    initializeAArch64LoopIdiomTransformLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  StringRef getPassName() const override {
    return "Transform AArch64-specific loop idioms";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // The transform needs loop structure, dominance and target information.
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
};
132 
133 bool AArch64LoopIdiomTransformLegacyPass::runOnLoop(Loop *L,
134                                                     LPPassManager &LPM) {
135 
136   if (skipLoop(L))
137     return false;
138 
139   auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
140   auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
141   auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
142       *L->getHeader()->getParent());
143   return AArch64LoopIdiomTransform(
144              DT, LI, &TTI, &L->getHeader()->getModule()->getDataLayout())
145       .run(L);
146 }
147 
148 } // end anonymous namespace
149 
char AArch64LoopIdiomTransformLegacyPass::ID = 0;

// Register the legacy pass and its analysis dependencies under the
// "aarch64-lit" command-line name.
INITIALIZE_PASS_BEGIN(
    AArch64LoopIdiomTransformLegacyPass, "aarch64-lit",
    "Transform specific loop idioms into optimized vector forms", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(
    AArch64LoopIdiomTransformLegacyPass, "aarch64-lit",
    "Transform specific loop idioms into optimized vector forms", false, false)

// Factory used by the AArch64 target to add this pass to the pipeline.
Pass *llvm::createAArch64LoopIdiomTransformPass() {
  return new AArch64LoopIdiomTransformLegacyPass();
}
167 
168 PreservedAnalyses
169 AArch64LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM,
170                                    LoopStandardAnalysisResults &AR,
171                                    LPMUpdater &) {
172   if (DisableAll)
173     return PreservedAnalyses::all();
174 
175   const auto *DL = &L.getHeader()->getModule()->getDataLayout();
176 
177   AArch64LoopIdiomTransform LIT(&AR.DT, &AR.LI, &AR.TTI, DL);
178   if (!LIT.run(&L))
179     return PreservedAnalyses::all();
180 
181   return PreservedAnalyses::none();
182 }
183 
184 //===----------------------------------------------------------------------===//
185 //
186 //          Implementation of AArch64LoopIdiomTransform
187 //
188 //===----------------------------------------------------------------------===//
189 
190 bool AArch64LoopIdiomTransform::run(Loop *L) {
191   CurLoop = L;
192 
193   if (DisableAll || L->getHeader()->getParent()->hasOptSize())
194     return false;
195 
196   // If the loop could not be converted to canonical form, it must have an
197   // indirectbr in it, just give up.
198   if (!L->getLoopPreheader())
199     return false;
200 
201   LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F["
202                     << CurLoop->getHeader()->getParent()->getName()
203                     << "] Loop %" << CurLoop->getHeader()->getName() << "\n");
204 
205   return recognizeByteCompare();
206 }
207 
// Match the "find first mismatching byte" loop described in the file header
// comment. All checks are structural; on success the loop is handed to
// transformByteCompare and this returns true. The checks are intentionally
// strict (exact block counts, instruction counts and operand shapes) so
// that the replacement code is guaranteed to compute the same result.
bool AArch64LoopIdiomTransform::recognizeByteCompare() {
  // Currently the transformation only works on scalable vector types, although
  // there is no fundamental reason why it cannot be made to work for fixed
  // width too.

  // We also need to know the minimum page size for the target in order to
  // generate runtime memory checks to ensure the vector version won't fault.
  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
      DisableByteCmp)
    return false;

  BasicBlock *Header = CurLoop->getHeader();

  // In AArch64LoopIdiomTransform::run we have already checked that the loop
  // has a preheader so we can assume it's in a canonical form.
  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
    return false;

  // The index/result PHI must be the first instruction of the header and have
  // exactly one incoming value from outside and one from inside the loop.
  PHINode *PN = dyn_cast<PHINode>(&Header->front());
  if (!PN || PN->getNumIncomingValues() != 2)
    return false;

  auto LoopBlocks = CurLoop->getBlocks();
  // The first block in the loop should contain only 4 instructions, e.g.
  //
  //  while.cond:
  //   %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ]
  //   %inc = add i32 %res.phi, 1
  //   %cmp.not = icmp eq i32 %inc, %n
  //   br i1 %cmp.not, label %while.end, label %while.body
  //
  auto CondBBInsts = LoopBlocks[0]->instructionsWithoutDebug();
  if (std::distance(CondBBInsts.begin(), CondBBInsts.end()) > 4)
    return false;

  // The second block should contain 7 instructions, e.g.
  //
  // while.body:
  //   %idx = zext i32 %inc to i64
  //   %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx
  //   %load.a = load i8, ptr %idx.a
  //   %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx
  //   %load.b = load i8, ptr %idx.b
  //   %cmp.not.ld = icmp eq i8 %load.a, %load.b
  //   br i1 %cmp.not.ld, label %while.cond, label %while.end
  //
  auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
  if (std::distance(LoopBBInsts.begin(), LoopBBInsts.end()) > 7)
    return false;

  // The incoming value to the PHI node from the loop should be an add of 1.
  Value *StartIdx = nullptr;
  Instruction *Index = nullptr;
  if (!CurLoop->contains(PN->getIncomingBlock(0))) {
    StartIdx = PN->getIncomingValue(0);
    Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
  } else {
    StartIdx = PN->getIncomingValue(1);
    Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
  }

  // Limit to 32-bit types for now
  if (!Index || !Index->getType()->isIntegerTy(32) ||
      !match(Index, m_c_Add(m_Specific(PN), m_One())))
    return false;

  // If we match the pattern, PN and Index will be replaced with the result of
  // the cttz.elts intrinsic. If any other instructions are used outside of
  // the loop, we cannot replace it.
  for (BasicBlock *BB : LoopBlocks)
    for (Instruction &I : *BB)
      if (&I != PN && &I != Index)
        for (User *U : I.users())
          if (!CurLoop->contains(cast<Instruction>(U)))
            return false;

  // Match the branch instruction for the header: an equality compare of the
  // incremented index against the loop bound, exiting on equality.
  ICmpInst::Predicate Pred;
  Value *MaxLen;
  BasicBlock *EndBB, *WhileBB;
  if (!match(Header->getTerminator(),
             m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)),
                  m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
      Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB))
    return false;

  // WhileBB should contain the pattern of load & compare instructions. Match
  // the pattern and find the GEP instructions used by the loads.
  ICmpInst::Predicate WhilePred;
  BasicBlock *FoundBB;
  BasicBlock *TrueBB;
  Value *LoadA, *LoadB;
  if (!match(WhileBB->getTerminator(),
             m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)),
                  m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
      WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB))
    return false;

  Value *A, *B;
  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
    return false;

  // Only simple (non-volatile, non-atomic, unordered) loads are safe to
  // widen into masked vector loads.
  LoadInst *LoadAI = cast<LoadInst>(LoadA);
  LoadInst *LoadBI = cast<LoadInst>(LoadB);
  if (!LoadAI->isSimple() || !LoadBI->isSimple())
    return false;

  GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
  GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);

  if (!GEPA || !GEPB)
    return false;

  Value *PtrA = GEPA->getPointerOperand();
  Value *PtrB = GEPB->getPointerOperand();

  // Check we are loading i8 values from two loop invariant pointers
  if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
      !GEPA->getResultElementType()->isIntegerTy(8) ||
      !GEPB->getResultElementType()->isIntegerTy(8) ||
      !LoadAI->getType()->isIntegerTy(8) ||
      !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB)
    return false;

  // Check that the index to the GEPs is the index we found earlier
  if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
    return false;

  // Both GEPs must index with the same zero-extension of Index.
  Value *IdxA = GEPA->getOperand(GEPA->getNumIndices());
  Value *IdxB = GEPB->getOperand(GEPB->getNumIndices());
  if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index))))
    return false;

  // We only ever expect the pre-incremented index value to be used inside the
  // loop.
  if (!PN->hasOneUse())
    return false;

  // Ensure that when the Found and End blocks are identical the PHIs have the
  // supported format. We don't currently allow cases like this:
  // while.cond:
  //   ...
  //   br i1 %cmp.not, label %while.end, label %while.body
  //
  // while.body:
  //   ...
  //   br i1 %cmp.not2, label %while.cond, label %while.end
  //
  // while.end:
  //   %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
  //
  // Where the incoming values for %final_ptr are unique and from each of the
  // loop blocks, but not actually defined in the loop. This requires extra
  // work setting up the byte.compare block, i.e. by introducing a select to
  // choose the correct value.
  // TODO: We could add support for this in future.
  if (FoundBB == EndBB) {
    for (PHINode &EndPN : EndBB->phis()) {
      Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header);
      Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB);

      // The value of the index when leaving the while.cond block is always the
      // same as the end value (MaxLen) so we permit either. The value when
      // leaving the while.body block should only be the index. Otherwise for
      // any other values we only allow ones that are same for both blocks.
      if (WhileCondVal != WhileBodyVal &&
          ((WhileCondVal != Index && WhileCondVal != MaxLen) ||
           (WhileBodyVal != Index)))
        return false;
    }
  }

  LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n"
                    << *(EndBB->getParent()) << "\n\n");

  // The index is incremented before the GEP/Load pair so we need to
  // add 1 to the start value.
  transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx, /*IncIdx=*/true,
                       FoundBB, EndBB);
  return true;
}
389 
// Emit, in the loop's preheader, the full replacement for the byte-compare
// loop: a minimum-iteration check, runtime same-page checks for both arrays,
// a predicated SVE loop that compares 16 bytes per vscale chunk, and a
// scalar fallback loop used when the page checks fail. Returns the merged
// mismatch index (a value of the 32-bit result type) available in the block
// the original loop's preheader now branches to.
Value *AArch64LoopIdiomTransform::expandFindMismatch(
    IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
    GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
  Value *PtrA = GEPA->getPointerOperand();
  Value *PtrB = GEPB->getPointerOperand();

  // Get the arguments and types for the intrinsic.
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
  LLVMContext &Ctx = PHBranch->getContext();
  Type *LoadType = Type::getInt8Ty(Ctx);
  Type *ResType = Builder.getInt32Ty();

  // Split block in the original loop preheader.
  BasicBlock *EndBlock =
      SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");

  // Create the blocks that we're going to need:
  //  1. A block for checking there is at least one iteration to run
  //     (i.e. Start <= MaxLen).
  //  2. A block to check that the start and end addresses of a given array
  //     lie on the same page.
  //  3. The SVE loop preheader.
  //  4. The first SVE loop block.
  //  5. The SVE loop increment block.
  //  6. A block we can jump to from the SVE loop when a mismatch is found.
  //  7. The first block of the scalar loop itself, containing PHIs, loads
  //  and cmp.
  //  8. A scalar loop increment block to increment the PHIs and go back
  //  around the loop.

  BasicBlock *MinItCheckBlock = BasicBlock::Create(
      Ctx, "mismatch_min_it_check", EndBlock->getParent(), EndBlock);

  // Update the terminator added by SplitBlock to branch to the first block
  Preheader->getTerminator()->setSuccessor(0, MinItCheckBlock);

  BasicBlock *MemCheckBlock = BasicBlock::Create(
      Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);

  BasicBlock *SVELoopPreheaderBlock = BasicBlock::Create(
      Ctx, "mismatch_sve_loop_preheader", EndBlock->getParent(), EndBlock);

  BasicBlock *SVELoopStartBlock = BasicBlock::Create(
      Ctx, "mismatch_sve_loop", EndBlock->getParent(), EndBlock);

  BasicBlock *SVELoopIncBlock = BasicBlock::Create(
      Ctx, "mismatch_sve_loop_inc", EndBlock->getParent(), EndBlock);

  BasicBlock *SVELoopMismatchBlock = BasicBlock::Create(
      Ctx, "mismatch_sve_loop_found", EndBlock->getParent(), EndBlock);

  BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
      Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);

  BasicBlock *LoopStartBlock =
      BasicBlock::Create(Ctx, "mismatch_loop", EndBlock->getParent(), EndBlock);

  BasicBlock *LoopIncBlock = BasicBlock::Create(
      Ctx, "mismatch_loop_inc", EndBlock->getParent(), EndBlock);

  DTU.applyUpdates({{DominatorTree::Insert, Preheader, MinItCheckBlock},
                    {DominatorTree::Delete, Preheader, EndBlock}});

  // Update LoopInfo with the new SVE & scalar loops.
  auto SVELoop = LI->AllocateLoop();
  auto ScalarLoop = LI->AllocateLoop();

  if (CurLoop->getParentLoop()) {
    CurLoop->getParentLoop()->addBasicBlockToLoop(MinItCheckBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(MemCheckBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(SVELoopPreheaderBlock, *LI);
    CurLoop->getParentLoop()->addChildLoop(SVELoop);
    CurLoop->getParentLoop()->addBasicBlockToLoop(SVELoopMismatchBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(LoopPreHeaderBlock, *LI);
    CurLoop->getParentLoop()->addChildLoop(ScalarLoop);
  } else {
    LI->addTopLevelLoop(SVELoop);
    LI->addTopLevelLoop(ScalarLoop);
  }

  // Add the new basic blocks to their associated loops.
  SVELoop->addBasicBlockToLoop(SVELoopStartBlock, *LI);
  SVELoop->addBasicBlockToLoop(SVELoopIncBlock, *LI);

  ScalarLoop->addBasicBlockToLoop(LoopStartBlock, *LI);
  ScalarLoop->addBasicBlockToLoop(LoopIncBlock, *LI);

  // Set up some types and constants that we intend to reuse.
  Type *I64Type = Builder.getInt64Ty();

  // Check there is at least one iteration to perform, i.e. Start <= MaxLen;
  // otherwise fall straight into the scalar loop preheader. The 64-bit
  // zero-extensions computed here are reused by the later blocks.
  Builder.SetInsertPoint(MinItCheckBlock);
  Value *ExtStart = Builder.CreateZExt(Start, I64Type);
  Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type);
  // This check doesn't really cost us very much.

  Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen);
  BranchInst *MinItCheckBr =
      BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck);
  MinItCheckBr->setMetadata(
      LLVMContext::MD_prof,
      MDBuilder(MinItCheckBr->getContext()).createBranchWeights(99, 1));
  Builder.Insert(MinItCheckBr);

  DTU.applyUpdates(
      {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock},
       {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}});

  // For each of the arrays, check the start/end addresses are on the same
  // page.
  Builder.SetInsertPoint(MemCheckBlock);

  // The early exit in the original loop means that when performing vector
  // loads we are potentially reading ahead of the early exit. So we could
  // fault if crossing a page boundary. Therefore, we create runtime memory
  // checks based on the minimum page size as follows:
  //   1. Calculate the addresses of the first memory accesses in the loop,
  //      i.e. LhsStart and RhsStart.
  //   2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd.
  //   3. Determine which pages correspond to all the memory accesses, i.e
  //      LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage.
  //   4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then
  //      we know we won't cross any page boundaries in the loop so we can
  //      enter the vector loop! Otherwise we fall back on the scalar loop.
  Value *LhsStartGEP = Builder.CreateGEP(LoadType, PtrA, ExtStart);
  Value *RhsStartGEP = Builder.CreateGEP(LoadType, PtrB, ExtStart);
  Value *RhsStart = Builder.CreatePtrToInt(RhsStartGEP, I64Type);
  Value *LhsStart = Builder.CreatePtrToInt(LhsStartGEP, I64Type);
  Value *LhsEndGEP = Builder.CreateGEP(LoadType, PtrA, ExtEnd);
  Value *RhsEndGEP = Builder.CreateGEP(LoadType, PtrB, ExtEnd);
  Value *LhsEnd = Builder.CreatePtrToInt(LhsEndGEP, I64Type);
  Value *RhsEnd = Builder.CreatePtrToInt(RhsEndGEP, I64Type);

  // Shifting an address right by log2(page size) yields its page number.
  const uint64_t MinPageSize = TTI->getMinPageSize().value();
  const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize);
  Value *LhsStartPage = Builder.CreateLShr(LhsStart, AddrShiftAmt);
  Value *LhsEndPage = Builder.CreateLShr(LhsEnd, AddrShiftAmt);
  Value *RhsStartPage = Builder.CreateLShr(RhsStart, AddrShiftAmt);
  Value *RhsEndPage = Builder.CreateLShr(RhsEnd, AddrShiftAmt);
  Value *LhsPageCmp = Builder.CreateICmpNE(LhsStartPage, LhsEndPage);
  Value *RhsPageCmp = Builder.CreateICmpNE(RhsStartPage, RhsEndPage);

  Value *CombinedPageCmp = Builder.CreateOr(LhsPageCmp, RhsPageCmp);
  BranchInst *CombinedPageCmpCmpBr = BranchInst::Create(
      LoopPreHeaderBlock, SVELoopPreheaderBlock, CombinedPageCmp);
  CombinedPageCmpCmpBr->setMetadata(
      LLVMContext::MD_prof, MDBuilder(CombinedPageCmpCmpBr->getContext())
                                .createBranchWeights(10, 90));
  Builder.Insert(CombinedPageCmpCmpBr);

  DTU.applyUpdates(
      {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock},
       {DominatorTree::Insert, MemCheckBlock, SVELoopPreheaderBlock}});

  // Set up the SVE loop preheader, i.e. calculate initial loop predicate,
  // zero-extend MaxLen to 64-bits, determine the number of vector elements
  // processed in each iteration, etc.
  Builder.SetInsertPoint(SVELoopPreheaderBlock);

  // At this point we know two things must be true:
  //  1. Start <= End
  //  2. ExtMaxLen <= MinPageSize due to the page checks.
  // Therefore, we know that we can use a 64-bit induction variable that
  // starts from 0 -> ExtMaxLen and it will not overflow.
  ScalableVectorType *PredVTy =
      ScalableVectorType::get(Builder.getInt1Ty(), 16);

  Value *InitialPred = Builder.CreateIntrinsic(
      Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});

  // Bytes processed per iteration = vscale * 16 (one <vscale x 16 x i8>).
  Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
  VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
                             /*HasNUW=*/true, /*HasNSW=*/true);

  Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
                                            Builder.getInt1(false));

  BranchInst *JumpToSVELoop = BranchInst::Create(SVELoopStartBlock);
  Builder.Insert(JumpToSVELoop);

  DTU.applyUpdates(
      {{DominatorTree::Insert, SVELoopPreheaderBlock, SVELoopStartBlock}});

  // Set up the first SVE loop block by creating the PHIs, doing the vector
  // loads and comparing the vectors.
  Builder.SetInsertPoint(SVELoopStartBlock);
  PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_sve_loop_pred");
  LoopPred->addIncoming(InitialPred, SVELoopPreheaderBlock);
  PHINode *SVEIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_sve_index");
  SVEIndexPhi->addIncoming(ExtStart, SVELoopPreheaderBlock);
  Type *SVELoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
  Value *Passthru = ConstantInt::getNullValue(SVELoadType);

  // Propagate inbounds from the original GEPs onto the new ones where legal.
  Value *SVELhsGep = Builder.CreateGEP(LoadType, PtrA, SVEIndexPhi);
  if (GEPA->isInBounds())
    cast<GetElementPtrInst>(SVELhsGep)->setIsInBounds(true);
  Value *SVELhsLoad = Builder.CreateMaskedLoad(SVELoadType, SVELhsGep, Align(1),
                                               LoopPred, Passthru);

  Value *SVERhsGep = Builder.CreateGEP(LoadType, PtrB, SVEIndexPhi);
  if (GEPB->isInBounds())
    cast<GetElementPtrInst>(SVERhsGep)->setIsInBounds(true);
  Value *SVERhsLoad = Builder.CreateMaskedLoad(SVELoadType, SVERhsGep, Align(1),
                                               LoopPred, Passthru);

  // Mask the comparison with the loop predicate so lanes beyond the current
  // bounds can never signal a (false) mismatch.
  Value *SVEMatchCmp = Builder.CreateICmpNE(SVELhsLoad, SVERhsLoad);
  SVEMatchCmp = Builder.CreateSelect(LoopPred, SVEMatchCmp, PFalse);
  Value *SVEMatchHasActiveLanes = Builder.CreateOrReduce(SVEMatchCmp);
  BranchInst *SVEEarlyExit = BranchInst::Create(
      SVELoopMismatchBlock, SVELoopIncBlock, SVEMatchHasActiveLanes);
  Builder.Insert(SVEEarlyExit);

  DTU.applyUpdates(
      {{DominatorTree::Insert, SVELoopStartBlock, SVELoopMismatchBlock},
       {DominatorTree::Insert, SVELoopStartBlock, SVELoopIncBlock}});

  // Increment the index counter and calculate the predicate for the next
  // iteration of the loop. We branch back to the start of the loop if there
  // is at least one active lane.
  Builder.SetInsertPoint(SVELoopIncBlock);
  Value *NewSVEIndexPhi = Builder.CreateAdd(SVEIndexPhi, VecLen, "",
                                            /*HasNUW=*/true, /*HasNSW=*/true);
  SVEIndexPhi->addIncoming(NewSVEIndexPhi, SVELoopIncBlock);
  Value *NewPred =
      Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
                              {PredVTy, I64Type}, {NewSVEIndexPhi, ExtEnd});
  LoopPred->addIncoming(NewPred, SVELoopIncBlock);

  // Lane 0 of the next predicate tells us whether any work remains.
  Value *PredHasActiveLanes =
      Builder.CreateExtractElement(NewPred, uint64_t(0));
  BranchInst *SVELoopBranchBack =
      BranchInst::Create(SVELoopStartBlock, EndBlock, PredHasActiveLanes);
  Builder.Insert(SVELoopBranchBack);

  DTU.applyUpdates({{DominatorTree::Insert, SVELoopIncBlock, SVELoopStartBlock},
                    {DominatorTree::Insert, SVELoopIncBlock, EndBlock}});

  // If we found a mismatch then we need to calculate which lane in the vector
  // had a mismatch and add that on to the current loop index.
  Builder.SetInsertPoint(SVELoopMismatchBlock);
  PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_sve_found_pred");
  FoundPred->addIncoming(SVEMatchCmp, SVELoopStartBlock);
  PHINode *LastLoopPred =
      Builder.CreatePHI(PredVTy, 1, "mismatch_sve_last_loop_pred");
  LastLoopPred->addIncoming(LoopPred, SVELoopStartBlock);
  PHINode *SVEFoundIndex =
      Builder.CreatePHI(I64Type, 1, "mismatch_sve_found_index");
  SVEFoundIndex->addIncoming(SVEIndexPhi, SVELoopStartBlock);

  // cttz.elts counts the trailing inactive lanes, i.e. finds the first lane
  // that is both active and mismatching.
  Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
  Value *Ctz = Builder.CreateIntrinsic(
      Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
      {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
  Ctz = Builder.CreateZExt(Ctz, I64Type);
  Value *SVELoopRes64 = Builder.CreateAdd(SVEFoundIndex, Ctz, "",
                                          /*HasNUW=*/true, /*HasNSW=*/true);
  Value *SVELoopRes = Builder.CreateTrunc(SVELoopRes64, ResType);

  Builder.Insert(BranchInst::Create(EndBlock));

  DTU.applyUpdates({{DominatorTree::Insert, SVELoopMismatchBlock, EndBlock}});

  // Generate code for scalar loop.
  Builder.SetInsertPoint(LoopPreHeaderBlock);
  Builder.Insert(BranchInst::Create(LoopStartBlock));

  DTU.applyUpdates(
      {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}});

  Builder.SetInsertPoint(LoopStartBlock);
  PHINode *IndexPhi = Builder.CreatePHI(ResType, 2, "mismatch_index");
  IndexPhi->addIncoming(Start, LoopPreHeaderBlock);

  // Otherwise compare the values
  // Load bytes from each array and compare them.
  Value *GepOffset = Builder.CreateZExt(IndexPhi, I64Type);

  Value *LhsGep = Builder.CreateGEP(LoadType, PtrA, GepOffset);
  if (GEPA->isInBounds())
    cast<GetElementPtrInst>(LhsGep)->setIsInBounds(true);
  Value *LhsLoad = Builder.CreateLoad(LoadType, LhsGep);

  Value *RhsGep = Builder.CreateGEP(LoadType, PtrB, GepOffset);
  if (GEPB->isInBounds())
    cast<GetElementPtrInst>(RhsGep)->setIsInBounds(true);
  Value *RhsLoad = Builder.CreateLoad(LoadType, RhsGep);

  Value *MatchCmp = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
  // If we have a mismatch then exit the loop ...
  BranchInst *MatchCmpBr = BranchInst::Create(LoopIncBlock, EndBlock, MatchCmp);
  Builder.Insert(MatchCmpBr);

  DTU.applyUpdates({{DominatorTree::Insert, LoopStartBlock, LoopIncBlock},
                    {DominatorTree::Insert, LoopStartBlock, EndBlock}});

  // Have we reached the maximum permitted length for the loop?
  Builder.SetInsertPoint(LoopIncBlock);
  // Carry over the original increment's wrap flags where they held.
  Value *PhiInc = Builder.CreateAdd(IndexPhi, ConstantInt::get(ResType, 1), "",
                                    /*HasNUW=*/Index->hasNoUnsignedWrap(),
                                    /*HasNSW=*/Index->hasNoSignedWrap());
  IndexPhi->addIncoming(PhiInc, LoopIncBlock);
  Value *IVCmp = Builder.CreateICmpEQ(PhiInc, MaxLen);
  BranchInst *IVCmpBr = BranchInst::Create(EndBlock, LoopStartBlock, IVCmp);
  Builder.Insert(IVCmpBr);

  DTU.applyUpdates({{DominatorTree::Insert, LoopIncBlock, EndBlock},
                    {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}});

  // In the end block we need to insert a PHI node to deal with four cases:
  //  1. We didn't find a mismatch in the scalar loop, so we return MaxLen.
  //  2. We exited the scalar loop early due to a mismatch and need to return
  //  the index that we found.
  //  3. We didn't find a mismatch in the SVE loop, so we return MaxLen.
  //  4. We exited the SVE loop early due to a mismatch and need to return
  //  the index that we found.
  Builder.SetInsertPoint(EndBlock, EndBlock->getFirstInsertionPt());
  PHINode *ResPhi = Builder.CreatePHI(ResType, 4, "mismatch_result");
  ResPhi->addIncoming(MaxLen, LoopIncBlock);
  ResPhi->addIncoming(IndexPhi, LoopStartBlock);
  ResPhi->addIncoming(MaxLen, SVELoopIncBlock);
  ResPhi->addIncoming(SVELoopRes, SVELoopMismatchBlock);

  Value *FinalRes = Builder.CreateTrunc(ResPhi, ResType);

  if (VerifyLoops) {
    ScalarLoop->verifyLoop();
    SVELoop->verifyLoop();
    if (!SVELoop->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
    if (!ScalarLoop->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
  }

  return FinalRes;
}
725 
// Rewrite the recognized byte-compare loop: materialize the expanded
// find-mismatch code (via expandFindMismatch) in place of the original loop
// and dispatch to FoundBB/EndBB depending on the mismatch index returned.
//
// \param GEPA, GEPB  GEPs feeding the two loads whose bytes are compared.
// \param IndPhi      The loop's induction PHI; asserted to have one use.
// \param MaxLen      Upper bound on the index; returned when no mismatch.
// \param Index       IV increment instruction; all its uses are replaced by
//                    the mismatch result.
// \param Start       Initial index value on loop entry.
// \param IncIdx      True if the IV was incremented before the loads in the
//                    original loop, in which case Start is pre-incremented.
// \param FoundBB     Successor taken when a mismatch was found.
// \param EndBB       Successor taken when no mismatch was found (may be the
//                    same block as FoundBB).
void AArch64LoopIdiomTransform::transformByteCompare(
    GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, PHINode *IndPhi,
    Value *MaxLen, Instruction *Index, Value *Start, bool IncIdx,
    BasicBlock *FoundBB, BasicBlock *EndBB) {

  // Insert the byte compare code at the end of the preheader block
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  BasicBlock *Header = CurLoop->getHeader();
  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
  IRBuilder<> Builder(PHBranch);
  // Lazy update strategy: CFG edge insertions queued via applyUpdates are
  // batched rather than applied to the dominator tree immediately.
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());

  // Increment the pointer if this was done before the loads in the loop.
  if (IncIdx)
    Start = Builder.CreateAdd(Start, ConstantInt::get(Start->getType(), 1));

  Value *ByteCmpRes =
      expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen);

  // Replaces uses of index & induction Phi with intrinsic (we already
  // checked that the first instruction of Header is the Phi above).
  assert(IndPhi->hasOneUse() && "Index phi node has more than one use!");
  Index->replaceAllUsesWith(ByteCmpRes);

  assert(PHBranch->isUnconditional() &&
         "Expected preheader to terminate with an unconditional branch.");

  // If no mismatch was found, we can jump to the end block. Create a
  // new basic block for the compare instruction.
  auto *CmpBB = BasicBlock::Create(Preheader->getContext(), "byte.compare",
                                   Preheader->getParent());
  CmpBB->moveBefore(EndBB);

  // Replace the branch in the preheader with an always-true conditional branch.
  // This ensures there is still a reference to the original loop.
  Builder.CreateCondBr(Builder.getTrue(), CmpBB, Header);
  PHBranch->eraseFromParent();

  // The expanded mismatch code ends in the block that produced ByteCmpRes;
  // record the new edge from there into CmpBB.
  BasicBlock *MismatchEnd = cast<Instruction>(ByteCmpRes)->getParent();
  DTU.applyUpdates({{DominatorTree::Insert, MismatchEnd, CmpBB}});

  // Create the branch to either the end or found block depending on the value
  // returned by the intrinsic.
  Builder.SetInsertPoint(CmpBB);
  if (FoundBB != EndBB) {
    // ByteCmpRes == MaxLen means the whole range matched, i.e. no mismatch.
    Value *FoundCmp = Builder.CreateICmpEQ(ByteCmpRes, MaxLen);
    Builder.CreateCondBr(FoundCmp, EndBB, FoundBB);
    DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB},
                      {DominatorTree::Insert, CmpBB, EndBB}});

  } else {
    Builder.CreateBr(FoundBB);
    DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB}});
  }

  // Give each PHI in SuccBB an incoming value for the new CmpBB predecessor,
  // since CmpBB now branches to SuccBB in place of the original loop blocks.
  auto fixSuccessorPhis = [&](BasicBlock *SuccBB) {
    for (PHINode &PN : SuccBB->phis()) {
      // At this point we've already replaced all uses of the result from the
      // loop with ByteCmp. Look through the incoming values to find ByteCmp,
      // meaning this is a Phi collecting the results of the byte compare.
      bool ResPhi = false;
      for (Value *Op : PN.incoming_values())
        if (Op == ByteCmpRes) {
          ResPhi = true;
          break;
        }

      // Any PHI that depended upon the result of the byte compare needs a new
      // incoming value from CmpBB. This is because the original loop will get
      // deleted.
      if (ResPhi)
        PN.addIncoming(ByteCmpRes, CmpBB);
      else {
        // There should be no other outside uses of other values in the
        // original loop. Any incoming values should either:
        //   1. Be for blocks outside the loop, which aren't interesting. Or ..
        //   2. These are from blocks in the loop with values defined outside
        //      the loop. We should add a similar incoming value from CmpBB.
        for (BasicBlock *BB : PN.blocks())
          if (CurLoop->contains(BB)) {
            PN.addIncoming(PN.getIncomingValueForBlock(BB), CmpBB);
            break;
          }
      }
    }
  };

  // Ensure all Phis in the successors of CmpBB have an incoming value from it.
  fixSuccessorPhis(EndBB);
  if (EndBB != FoundBB)
    fixSuccessorPhis(FoundBB);

  // The new CmpBB block isn't part of the loop, but will need to be added to
  // the outer loop if there is one.
  if (!CurLoop->isOutermost())
    CurLoop->getParentLoop()->addBasicBlockToLoop(CmpBB, *LI);

  // Optional sanity checking (controlled by the VerifyLoops flag): the parent
  // loop, if any, must still be well-formed and in LCSSA form.
  if (VerifyLoops && CurLoop->getParentLoop()) {
    CurLoop->getParentLoop()->verifyLoop();
    if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
  }
}
830