//===- AArch64LoopIdiomTransform.cpp - Loop idiom recognition -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass recognizes certain loop idioms and transforms them into more
// optimized versions of the same loop. In cases where this happens, it can be
// a significant performance win.
//
// We currently only recognize one loop that finds the first mismatched byte
// in an array and returns the index, i.e. something like:
//
//  while (++i != n) {
//    if (a[i] != b[i])
//      break;
//  }
//
// In this example we can actually vectorize the loop despite the early exit,
// although the loop vectorizer does not support it. It requires some extra
// checks to deal with the possibility of faulting loads when crossing page
// boundaries. However, even with these checks it is still profitable to do the
// transformation.
//
//===----------------------------------------------------------------------===//
//
// TODO List:
//
// * Add support for the inverse case where we scan for a matching element.
// * Permit 64-bit induction variable types.
// * Recognize loops that increment the IV *after* comparing bytes.
// * Allow 32-bit sign-extends of the IV used by the GEP.
//
//===----------------------------------------------------------------------===//

#include "AArch64LoopIdiomTransform.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "aarch64-loop-idiom-transform"

// Debugging knobs: disable the whole pass, disable just the byte-compare
// idiom, or run extra loop/LCSSA verification after transforming.
static cl::opt<bool>
    DisableAll("disable-aarch64-lit-all", cl::Hidden, cl::init(false),
               cl::desc("Disable AArch64 Loop Idiom Transform Pass."));

static cl::opt<bool> DisableByteCmp(
    "disable-aarch64-lit-bytecmp", cl::Hidden, cl::init(false),
    cl::desc("Proceed with AArch64 Loop Idiom Transform Pass, but do "
             "not convert byte-compare loop(s)."));

static cl::opt<bool> VerifyLoops(
    "aarch64-lit-verify", cl::Hidden, cl::init(false),
    cl::desc("Verify loops generated AArch64 Loop Idiom Transform Pass."));

namespace llvm {

void initializeAArch64LoopIdiomTransformLegacyPassPass(PassRegistry &);
Pass *createAArch64LoopIdiomTransformPass();

} // end namespace llvm

namespace {

/// Implements the idiom recognition and rewriting. This is shared between the
/// legacy pass manager wrapper below and the new pass manager entry point
/// (AArch64LoopIdiomTransformPass::run).
class AArch64LoopIdiomTransform {
  Loop *CurLoop = nullptr;
  DominatorTree *DT;
  LoopInfo *LI;
  const TargetTransformInfo *TTI;
  const DataLayout *DL; // NOTE(review): not referenced in this chunk — confirm use elsewhere.

public:
  explicit AArch64LoopIdiomTransform(DominatorTree *DT, LoopInfo *LI,
                                     const TargetTransformInfo *TTI,
                                     const DataLayout *DL)
      : DT(DT), LI(LI), TTI(TTI), DL(DL) {}

  /// Entry point: returns true if the loop was transformed.
  bool run(Loop *L);

private:
  /// \name Countable Loop Idiom Handling
  /// @{

  // NOTE(review): declared but no definitions appear in this chunk — verify
  // they exist elsewhere or remove.
  bool runOnCountableLoop();
  bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
                      SmallVectorImpl<BasicBlock *> &ExitBlocks);

  bool recognizeByteCompare();
  Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
                            GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                            Instruction *Index, Value *Start, Value *MaxLen);
  void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                            PHINode *IndPhi, Value *MaxLen, Instruction *Index,
                            Value *Start, bool IncIdx, BasicBlock *FoundBB,
                            BasicBlock *EndBB);
  /// @}
};

/// Legacy pass manager wrapper around AArch64LoopIdiomTransform.
class AArch64LoopIdiomTransformLegacyPass : public LoopPass {
public:
  static char ID;

  explicit AArch64LoopIdiomTransformLegacyPass() : LoopPass(ID) {
    initializeAArch64LoopIdiomTransformLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  StringRef getPassName() const override {
    return "Transform AArch64-specific loop idioms";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
};

bool AArch64LoopIdiomTransformLegacyPass::runOnLoop(Loop *L,
                                                    LPPassManager &LPM) {

  if (skipLoop(L))
    return false;

  // Gather the analyses required by the shared implementation and run it.
  auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
      *L->getHeader()->getParent());
  return AArch64LoopIdiomTransform(
             DT, LI, &TTI, &L->getHeader()->getModule()->getDataLayout())
      .run(L);
}

} // end anonymous namespace

char AArch64LoopIdiomTransformLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(
    AArch64LoopIdiomTransformLegacyPass, "aarch64-lit",
    "Transform specific loop idioms into optimized vector forms", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(
    AArch64LoopIdiomTransformLegacyPass, "aarch64-lit",
    "Transform specific loop idioms into optimized vector forms", false, false)

Pass *llvm::createAArch64LoopIdiomTransformPass() {
  return new AArch64LoopIdiomTransformLegacyPass();
}

/// New pass manager entry point. Reports all analyses preserved when nothing
/// was changed, and none preserved otherwise (the transform rewrites the CFG,
/// loop structure and dominator tree).
PreservedAnalyses
AArch64LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM,
                                   LoopStandardAnalysisResults &AR,
                                   LPMUpdater &) {
  if (DisableAll)
    return PreservedAnalyses::all();

  const auto *DL = &L.getHeader()->getModule()->getDataLayout();

  AArch64LoopIdiomTransform LIT(&AR.DT, &AR.LI, &AR.TTI, DL);
  if (!LIT.run(&L))
    return PreservedAnalyses::all();

  return PreservedAnalyses::none();
}

//===----------------------------------------------------------------------===//
//
// Implementation of AArch64LoopIdiomTransform
//
//===----------------------------------------------------------------------===//

bool AArch64LoopIdiomTransform::run(Loop *L) {
  CurLoop = L;

  // Bail out when the pass is disabled or the function is optimized for size
  // (the transform inserts extra blocks and runtime checks).
  if (DisableAll || L->getHeader()->getParent()->hasOptSize())
    return false;

  // If the loop could not be converted to canonical form, it must have an
  // indirectbr in it, just give up.
  if (!L->getLoopPreheader())
    return false;

  LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F["
                    << CurLoop->getHeader()->getParent()->getName()
                    << "] Loop %" << CurLoop->getHeader()->getName() << "\n");

  return recognizeByteCompare();
}

/// Match a two-block loop that searches for the first mismatching byte
/// between two arrays, and if found, rewrite it via transformByteCompare().
/// Returns true if the loop was transformed.
bool AArch64LoopIdiomTransform::recognizeByteCompare() {
  // Currently the transformation only works on scalable vector types, although
  // there is no fundamental reason why it cannot be made to work for fixed
  // width too.

  // We also need to know the minimum page size for the target in order to
  // generate runtime memory checks to ensure the vector version won't fault.
  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
      DisableByteCmp)
    return false;

  BasicBlock *Header = CurLoop->getHeader();

  // In AArch64LoopIdiomTransform::run we have already checked that the loop
  // has a preheader so we can assume it's in a canonical form.
  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
    return false;

  // The loop result / induction variable must be the first PHI in the header
  // and have exactly one incoming value from outside and one from inside.
  PHINode *PN = dyn_cast<PHINode>(&Header->front());
  if (!PN || PN->getNumIncomingValues() != 2)
    return false;

  auto LoopBlocks = CurLoop->getBlocks();
  // The first block in the loop should contain only 4 instructions, e.g.
  //
  //  while.cond:
  //   %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ]
  //   %inc = add i32 %res.phi, 1
  //   %cmp.not = icmp eq i32 %inc, %n
  //   br i1 %cmp.not, label %while.end, label %while.body
  //
  auto CondBBInsts = LoopBlocks[0]->instructionsWithoutDebug();
  if (std::distance(CondBBInsts.begin(), CondBBInsts.end()) > 4)
    return false;

  // The second block should contain 7 instructions, e.g.
  //
  // while.body:
  //   %idx = zext i32 %inc to i64
  //   %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx
  //   %load.a = load i8, ptr %idx.a
  //   %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx
  //   %load.b = load i8, ptr %idx.b
  //   %cmp.not.ld = icmp eq i8 %load.a, %load.b
  //   br i1 %cmp.not.ld, label %while.cond, label %while.end
  //
  auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
  if (std::distance(LoopBBInsts.begin(), LoopBBInsts.end()) > 7)
    return false;

  // The incoming value to the PHI node from the loop should be an add of 1.
  Value *StartIdx = nullptr;
  Instruction *Index = nullptr;
  if (!CurLoop->contains(PN->getIncomingBlock(0))) {
    StartIdx = PN->getIncomingValue(0);
    Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
  } else {
    StartIdx = PN->getIncomingValue(1);
    Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
  }

  // Limit to 32-bit types for now
  if (!Index || !Index->getType()->isIntegerTy(32) ||
      !match(Index, m_c_Add(m_Specific(PN), m_One())))
    return false;

  // If we match the pattern, PN and Index will be replaced with the result of
  // the cttz.elts intrinsic. If any other instructions are used outside of
  // the loop, we cannot replace it.
  for (BasicBlock *BB : LoopBlocks)
    for (Instruction &I : *BB)
      if (&I != PN && &I != Index)
        for (User *U : I.users())
          if (!CurLoop->contains(cast<Instruction>(U)))
            return false;

  // Match the branch instruction for the header
  ICmpInst::Predicate Pred;
  Value *MaxLen;
  BasicBlock *EndBB, *WhileBB;
  if (!match(Header->getTerminator(),
             m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)),
                  m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
      Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB))
    return false;

  // WhileBB should contain the pattern of load & compare instructions. Match
  // the pattern and find the GEP instructions used by the loads.
  ICmpInst::Predicate WhilePred;
  BasicBlock *FoundBB;
  BasicBlock *TrueBB;
  Value *LoadA, *LoadB;
  if (!match(WhileBB->getTerminator(),
             m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)),
                  m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
      WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB))
    return false;

  Value *A, *B;
  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
    return false;

  // Only simple (non-volatile, non-atomic) loads can be widened.
  LoadInst *LoadAI = cast<LoadInst>(LoadA);
  LoadInst *LoadBI = cast<LoadInst>(LoadB);
  if (!LoadAI->isSimple() || !LoadBI->isSimple())
    return false;

  GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
  GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);

  if (!GEPA || !GEPB)
    return false;

  Value *PtrA = GEPA->getPointerOperand();
  Value *PtrB = GEPB->getPointerOperand();

  // Check we are loading i8 values from two loop invariant pointers
  if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
      !GEPA->getResultElementType()->isIntegerTy(8) ||
      !GEPB->getResultElementType()->isIntegerTy(8) ||
      !LoadAI->getType()->isIntegerTy(8) ||
      !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB)
    return false;

  // Check that the index to the GEPs is the index we found earlier
  if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
    return false;

  Value *IdxA = GEPA->getOperand(GEPA->getNumIndices());
  Value *IdxB = GEPB->getOperand(GEPB->getNumIndices());
  if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index))))
    return false;

  // We only ever expect the pre-incremented index value to be used inside the
  // loop.
  if (!PN->hasOneUse())
    return false;

  // Ensure that when the Found and End blocks are identical the PHIs have the
  // supported format. We don't currently allow cases like this:
  // while.cond:
  //   ...
  //   br i1 %cmp.not, label %while.end, label %while.body
  //
  // while.body:
  //   ...
  //   br i1 %cmp.not2, label %while.cond, label %while.end
  //
  // while.end:
  //   %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
  //
  // Where the incoming values for %final_ptr are unique and from each of the
  // loop blocks, but not actually defined in the loop. This requires extra
  // work setting up the byte.compare block, i.e. by introducing a select to
  // choose the correct value.
  // TODO: We could add support for this in future.
  if (FoundBB == EndBB) {
    for (PHINode &EndPN : EndBB->phis()) {
      Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header);
      Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB);

      // The value of the index when leaving the while.cond block is always the
      // same as the end value (MaxLen) so we permit either. The value when
      // leaving the while.body block should only be the index. Otherwise for
      // any other values we only allow ones that are same for both blocks.
      if (WhileCondVal != WhileBodyVal &&
          ((WhileCondVal != Index && WhileCondVal != MaxLen) ||
           (WhileBodyVal != Index)))
        return false;
    }
  }

  LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n"
                    << *(EndBB->getParent()) << "\n\n");

  // The index is incremented before the GEP/Load pair so we need to
  // add 1 to the start value.
  transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx, /*IncIdx=*/true,
                       FoundBB, EndBB);
  return true;
}

/// Expand an inline "find mismatch" sequence starting in the loop preheader:
/// runtime checks select between a predicated SVE loop and a scalar fallback
/// loop, both of which compute the index of the first byte that differs
/// between the arrays at PtrA and PtrB (MaxLen when no mismatch is found).
/// Returns the i32 result, defined in the new "mismatch_end" block.
Value *AArch64LoopIdiomTransform::expandFindMismatch(
    IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
    GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
  Value *PtrA = GEPA->getPointerOperand();
  Value *PtrB = GEPB->getPointerOperand();

  // Get the arguments and types for the intrinsic.
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
  LLVMContext &Ctx = PHBranch->getContext();
  Type *LoadType = Type::getInt8Ty(Ctx);
  Type *ResType = Builder.getInt32Ty();

  // Split block in the original loop preheader.
  BasicBlock *EndBlock =
      SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");

  // Create the blocks that we're going to need:
  //  1. A block for checking the zero-extended length exceeds 0
  //  2. A block to check that the start and end addresses of a given array
  //     lie on the same page.
  //  3. The SVE loop preheader.
  //  4. The first SVE loop block.
  //  5. The SVE loop increment block.
  //  6. A block we can jump to from the SVE loop when a mismatch is found.
  //  7. The first block of the scalar loop itself, containing PHIs, loads
  //     and cmp.
  //  8. A scalar loop increment block to increment the PHIs and go back
  //     around the loop.

  BasicBlock *MinItCheckBlock = BasicBlock::Create(
      Ctx, "mismatch_min_it_check", EndBlock->getParent(), EndBlock);

  // Update the terminator added by SplitBlock to branch to the first block
  Preheader->getTerminator()->setSuccessor(0, MinItCheckBlock);

  BasicBlock *MemCheckBlock = BasicBlock::Create(
      Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);

  BasicBlock *SVELoopPreheaderBlock = BasicBlock::Create(
      Ctx, "mismatch_sve_loop_preheader", EndBlock->getParent(), EndBlock);

  BasicBlock *SVELoopStartBlock = BasicBlock::Create(
      Ctx, "mismatch_sve_loop", EndBlock->getParent(), EndBlock);

  BasicBlock *SVELoopIncBlock = BasicBlock::Create(
      Ctx, "mismatch_sve_loop_inc", EndBlock->getParent(), EndBlock);

  BasicBlock *SVELoopMismatchBlock = BasicBlock::Create(
      Ctx, "mismatch_sve_loop_found", EndBlock->getParent(), EndBlock);

  BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
      Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);

  BasicBlock *LoopStartBlock =
      BasicBlock::Create(Ctx, "mismatch_loop", EndBlock->getParent(), EndBlock);

  BasicBlock *LoopIncBlock = BasicBlock::Create(
      Ctx, "mismatch_loop_inc", EndBlock->getParent(), EndBlock);

  DTU.applyUpdates({{DominatorTree::Insert, Preheader, MinItCheckBlock},
                    {DominatorTree::Delete, Preheader, EndBlock}});

  // Update LoopInfo with the new SVE & scalar loops.
  auto SVELoop = LI->AllocateLoop();
  auto ScalarLoop = LI->AllocateLoop();

  // Register the new blocks and loops with any enclosing loop, or at the top
  // level if the original loop was outermost.
  if (CurLoop->getParentLoop()) {
    CurLoop->getParentLoop()->addBasicBlockToLoop(MinItCheckBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(MemCheckBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(SVELoopPreheaderBlock, *LI);
    CurLoop->getParentLoop()->addChildLoop(SVELoop);
    CurLoop->getParentLoop()->addBasicBlockToLoop(SVELoopMismatchBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(LoopPreHeaderBlock, *LI);
    CurLoop->getParentLoop()->addChildLoop(ScalarLoop);
  } else {
    LI->addTopLevelLoop(SVELoop);
    LI->addTopLevelLoop(ScalarLoop);
  }

  // Add the new basic blocks to their associated loops.
  SVELoop->addBasicBlockToLoop(SVELoopStartBlock, *LI);
  SVELoop->addBasicBlockToLoop(SVELoopIncBlock, *LI);

  ScalarLoop->addBasicBlockToLoop(LoopStartBlock, *LI);
  ScalarLoop->addBasicBlockToLoop(LoopIncBlock, *LI);

  // Set up some types and constants that we intend to reuse.
  Type *I64Type = Builder.getInt64Ty();

  // Check the zero-extended iteration count > 0
  Builder.SetInsertPoint(MinItCheckBlock);
  Value *ExtStart = Builder.CreateZExt(Start, I64Type);
  Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type);
  // This check doesn't really cost us very much.

  Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen);
  BranchInst *MinItCheckBr =
      BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck);
  // Bias the branch heavily towards the memory-check path; the empty-range
  // case is expected to be rare.
  MinItCheckBr->setMetadata(
      LLVMContext::MD_prof,
      MDBuilder(MinItCheckBr->getContext()).createBranchWeights(99, 1));
  Builder.Insert(MinItCheckBr);

  DTU.applyUpdates(
      {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock},
       {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}});

  // For each of the arrays, check the start/end addresses are on the same
  // page.
  Builder.SetInsertPoint(MemCheckBlock);

  // The early exit in the original loop means that when performing vector
  // loads we are potentially reading ahead of the early exit. So we could
  // fault if crossing a page boundary. Therefore, we create runtime memory
  // checks based on the minimum page size as follows:
  //   1. Calculate the addresses of the first memory accesses in the loop,
  //      i.e. LhsStart and RhsStart.
  //   2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd.
  //   3. Determine which pages correspond to all the memory accesses, i.e
  //      LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage.
  //   4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then
  //      we know we won't cross any page boundaries in the loop so we can
  //      enter the vector loop! Otherwise we fall back on the scalar loop.
  Value *LhsStartGEP = Builder.CreateGEP(LoadType, PtrA, ExtStart);
  Value *RhsStartGEP = Builder.CreateGEP(LoadType, PtrB, ExtStart);
  Value *RhsStart = Builder.CreatePtrToInt(RhsStartGEP, I64Type);
  Value *LhsStart = Builder.CreatePtrToInt(LhsStartGEP, I64Type);
  Value *LhsEndGEP = Builder.CreateGEP(LoadType, PtrA, ExtEnd);
  Value *RhsEndGEP = Builder.CreateGEP(LoadType, PtrB, ExtEnd);
  Value *LhsEnd = Builder.CreatePtrToInt(LhsEndGEP, I64Type);
  Value *RhsEnd = Builder.CreatePtrToInt(RhsEndGEP, I64Type);

  // Page number = address >> log2(page size).
  const uint64_t MinPageSize = TTI->getMinPageSize().value();
  const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize);
  Value *LhsStartPage = Builder.CreateLShr(LhsStart, AddrShiftAmt);
  Value *LhsEndPage = Builder.CreateLShr(LhsEnd, AddrShiftAmt);
  Value *RhsStartPage = Builder.CreateLShr(RhsStart, AddrShiftAmt);
  Value *RhsEndPage = Builder.CreateLShr(RhsEnd, AddrShiftAmt);
  Value *LhsPageCmp = Builder.CreateICmpNE(LhsStartPage, LhsEndPage);
  Value *RhsPageCmp = Builder.CreateICmpNE(RhsStartPage, RhsEndPage);

  Value *CombinedPageCmp = Builder.CreateOr(LhsPageCmp, RhsPageCmp);
  BranchInst *CombinedPageCmpCmpBr = BranchInst::Create(
      LoopPreHeaderBlock, SVELoopPreheaderBlock, CombinedPageCmp);
  CombinedPageCmpCmpBr->setMetadata(
      LLVMContext::MD_prof, MDBuilder(CombinedPageCmpCmpBr->getContext())
                                .createBranchWeights(10, 90));
  Builder.Insert(CombinedPageCmpCmpBr);

  DTU.applyUpdates(
      {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock},
       {DominatorTree::Insert, MemCheckBlock, SVELoopPreheaderBlock}});

  // Set up the SVE loop preheader, i.e. calculate initial loop predicate,
  // zero-extend MaxLen to 64-bits, determine the number of vector elements
  // processed in each iteration, etc.
  Builder.SetInsertPoint(SVELoopPreheaderBlock);

  // At this point we know two things must be true:
  //  1. Start <= End
  //  2. ExtMaxLen <= MinPageSize due to the page checks.
  // Therefore, we know that we can use a 64-bit induction variable that
  // starts from 0 -> ExtMaxLen and it will not overflow.
  ScalableVectorType *PredVTy =
      ScalableVectorType::get(Builder.getInt1Ty(), 16);

  Value *InitialPred = Builder.CreateIntrinsic(
      Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});

  // VecLen = vscale * 16, i.e. the number of i8 lanes per iteration.
  Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
  VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
                             /*HasNUW=*/true, /*HasNSW=*/true);

  Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
                                            Builder.getInt1(false));

  BranchInst *JumpToSVELoop = BranchInst::Create(SVELoopStartBlock);
  Builder.Insert(JumpToSVELoop);

  DTU.applyUpdates(
      {{DominatorTree::Insert, SVELoopPreheaderBlock, SVELoopStartBlock}});

  // Set up the first SVE loop block by creating the PHIs, doing the vector
  // loads and comparing the vectors.
  Builder.SetInsertPoint(SVELoopStartBlock);
  PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_sve_loop_pred");
  LoopPred->addIncoming(InitialPred, SVELoopPreheaderBlock);
  PHINode *SVEIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_sve_index");
  SVEIndexPhi->addIncoming(ExtStart, SVELoopPreheaderBlock);
  Type *SVELoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
  Value *Passthru = ConstantInt::getNullValue(SVELoadType);

  // Masked loads, so that inactive lanes (outside the predicate) never touch
  // memory.
  Value *SVELhsGep = Builder.CreateGEP(LoadType, PtrA, SVEIndexPhi);
  if (GEPA->isInBounds())
    cast<GetElementPtrInst>(SVELhsGep)->setIsInBounds(true);
  Value *SVELhsLoad = Builder.CreateMaskedLoad(SVELoadType, SVELhsGep, Align(1),
                                               LoopPred, Passthru);

  Value *SVERhsGep = Builder.CreateGEP(LoadType, PtrB, SVEIndexPhi);
  if (GEPB->isInBounds())
    cast<GetElementPtrInst>(SVERhsGep)->setIsInBounds(true);
  Value *SVERhsLoad = Builder.CreateMaskedLoad(SVELoadType, SVERhsGep, Align(1),
                                               LoopPred, Passthru);

  // Compare lane-wise; mask off inactive lanes before reducing so they cannot
  // signal a spurious mismatch.
  Value *SVEMatchCmp = Builder.CreateICmpNE(SVELhsLoad, SVERhsLoad);
  SVEMatchCmp = Builder.CreateSelect(LoopPred, SVEMatchCmp, PFalse);
  Value *SVEMatchHasActiveLanes = Builder.CreateOrReduce(SVEMatchCmp);
  BranchInst *SVEEarlyExit = BranchInst::Create(
      SVELoopMismatchBlock, SVELoopIncBlock, SVEMatchHasActiveLanes);
  Builder.Insert(SVEEarlyExit);

  DTU.applyUpdates(
      {{DominatorTree::Insert, SVELoopStartBlock, SVELoopMismatchBlock},
       {DominatorTree::Insert, SVELoopStartBlock, SVELoopIncBlock}});

  // Increment the index counter and calculate the predicate for the next
  // iteration of the loop. We branch back to the start of the loop if there
  // is at least one active lane.
  Builder.SetInsertPoint(SVELoopIncBlock);
  Value *NewSVEIndexPhi = Builder.CreateAdd(SVEIndexPhi, VecLen, "",
                                            /*HasNUW=*/true, /*HasNSW=*/true);
  SVEIndexPhi->addIncoming(NewSVEIndexPhi, SVELoopIncBlock);
  Value *NewPred =
      Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
                              {PredVTy, I64Type}, {NewSVEIndexPhi, ExtEnd});
  LoopPred->addIncoming(NewPred, SVELoopIncBlock);

  // Lane 0 of the next-iteration mask tells us whether any work remains.
  Value *PredHasActiveLanes =
      Builder.CreateExtractElement(NewPred, uint64_t(0));
  BranchInst *SVELoopBranchBack =
      BranchInst::Create(SVELoopStartBlock, EndBlock, PredHasActiveLanes);
  Builder.Insert(SVELoopBranchBack);

  DTU.applyUpdates({{DominatorTree::Insert, SVELoopIncBlock, SVELoopStartBlock},
                    {DominatorTree::Insert, SVELoopIncBlock, EndBlock}});

  // If we found a mismatch then we need to calculate which lane in the vector
  // had a mismatch and add that on to the current loop index.
  Builder.SetInsertPoint(SVELoopMismatchBlock);
  PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_sve_found_pred");
  FoundPred->addIncoming(SVEMatchCmp, SVELoopStartBlock);
  PHINode *LastLoopPred =
      Builder.CreatePHI(PredVTy, 1, "mismatch_sve_last_loop_pred");
  LastLoopPred->addIncoming(LoopPred, SVELoopStartBlock);
  PHINode *SVEFoundIndex =
      Builder.CreatePHI(I64Type, 1, "mismatch_sve_found_index");
  SVEFoundIndex->addIncoming(SVEIndexPhi, SVELoopStartBlock);

  // Count trailing zero elements of the (active & mismatching) mask to get
  // the offset of the first mismatching lane.
  Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
  Value *Ctz = Builder.CreateIntrinsic(
      Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
      {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
  Ctz = Builder.CreateZExt(Ctz, I64Type);
  Value *SVELoopRes64 = Builder.CreateAdd(SVEFoundIndex, Ctz, "",
                                          /*HasNUW=*/true, /*HasNSW=*/true);
  Value *SVELoopRes = Builder.CreateTrunc(SVELoopRes64, ResType);

  Builder.Insert(BranchInst::Create(EndBlock));

  DTU.applyUpdates({{DominatorTree::Insert, SVELoopMismatchBlock, EndBlock}});

  // Generate code for scalar loop.
  Builder.SetInsertPoint(LoopPreHeaderBlock);
  Builder.Insert(BranchInst::Create(LoopStartBlock));

  DTU.applyUpdates(
      {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}});

  Builder.SetInsertPoint(LoopStartBlock);
  PHINode *IndexPhi = Builder.CreatePHI(ResType, 2, "mismatch_index");
  IndexPhi->addIncoming(Start, LoopPreHeaderBlock);

  // Otherwise compare the values
  // Load bytes from each array and compare them.
  Value *GepOffset = Builder.CreateZExt(IndexPhi, I64Type);

  Value *LhsGep = Builder.CreateGEP(LoadType, PtrA, GepOffset);
  if (GEPA->isInBounds())
    cast<GetElementPtrInst>(LhsGep)->setIsInBounds(true);
  Value *LhsLoad = Builder.CreateLoad(LoadType, LhsGep);

  Value *RhsGep = Builder.CreateGEP(LoadType, PtrB, GepOffset);
  if (GEPB->isInBounds())
    cast<GetElementPtrInst>(RhsGep)->setIsInBounds(true);
  Value *RhsLoad = Builder.CreateLoad(LoadType, RhsGep);

  Value *MatchCmp = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
  // If we have a mismatch then exit the loop ...
  BranchInst *MatchCmpBr = BranchInst::Create(LoopIncBlock, EndBlock, MatchCmp);
  Builder.Insert(MatchCmpBr);

  DTU.applyUpdates({{DominatorTree::Insert, LoopStartBlock, LoopIncBlock},
                    {DominatorTree::Insert, LoopStartBlock, EndBlock}});

  // Have we reached the maximum permitted length for the loop?
  Builder.SetInsertPoint(LoopIncBlock);
  // Preserve the original increment's wrap flags on the scalar clone.
  Value *PhiInc = Builder.CreateAdd(IndexPhi, ConstantInt::get(ResType, 1), "",
                                    /*HasNUW=*/Index->hasNoUnsignedWrap(),
                                    /*HasNSW=*/Index->hasNoSignedWrap());
  IndexPhi->addIncoming(PhiInc, LoopIncBlock);
  Value *IVCmp = Builder.CreateICmpEQ(PhiInc, MaxLen);
  BranchInst *IVCmpBr = BranchInst::Create(EndBlock, LoopStartBlock, IVCmp);
  Builder.Insert(IVCmpBr);

  DTU.applyUpdates({{DominatorTree::Insert, LoopIncBlock, EndBlock},
                    {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}});

  // In the end block we need to insert a PHI node to deal with four cases:
  //  1. We didn't find a mismatch in the scalar loop, so we return MaxLen.
  //  2. We exited the scalar loop early due to a mismatch and need to return
  //     the index that we found.
  //  3. We didn't find a mismatch in the SVE loop, so we return MaxLen.
  //  4. We exited the SVE loop early due to a mismatch and need to return
  //     the index that we found.
  Builder.SetInsertPoint(EndBlock, EndBlock->getFirstInsertionPt());
  PHINode *ResPhi = Builder.CreatePHI(ResType, 4, "mismatch_result");
  ResPhi->addIncoming(MaxLen, LoopIncBlock);
  ResPhi->addIncoming(IndexPhi, LoopStartBlock);
  ResPhi->addIncoming(MaxLen, SVELoopIncBlock);
  ResPhi->addIncoming(SVELoopRes, SVELoopMismatchBlock);

  Value *FinalRes = Builder.CreateTrunc(ResPhi, ResType);

  if (VerifyLoops) {
    ScalarLoop->verifyLoop();
    SVELoop->verifyLoop();
    if (!SVELoop->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
    if (!ScalarLoop->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
  }

  return FinalRes;
}

/// Rewrite the matched byte-compare loop: expand the find-mismatch code in
/// the preheader, replace all uses of the loop's result with it, and route
/// control flow through a new "byte.compare" block that dispatches to FoundBB
/// or EndBB based on the computed index. The original loop body is left in
/// place (now unreachable) for later cleanup.
void AArch64LoopIdiomTransform::transformByteCompare(
    GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, PHINode *IndPhi,
    Value *MaxLen, Instruction *Index, Value *Start, bool IncIdx,
    BasicBlock *FoundBB, BasicBlock *EndBB) {

  // Insert the byte compare code at the end of the preheader block
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  BasicBlock *Header = CurLoop->getHeader();
  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
  IRBuilder<> Builder(PHBranch);
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());

  // The original loop increments the index before the loads, so begin the
  // search at Start + 1 when IncIdx is set.
  if (IncIdx)
    Start = Builder.CreateAdd(Start, ConstantInt::get(Start->getType(), 1));

  Value *ByteCmpRes =
      expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen);

  // Replaces uses of index & induction Phi with intrinsic (we already
  // checked that the first instruction of Header is the Phi above).
  assert(IndPhi->hasOneUse() && "Index phi node has more than one use!");
  Index->replaceAllUsesWith(ByteCmpRes);

  assert(PHBranch->isUnconditional() &&
         "Expected preheader to terminate with an unconditional branch.");

  // If no mismatch was found, we can jump to the end block. Create a
  // new basic block for the compare instruction.
  auto *CmpBB = BasicBlock::Create(Preheader->getContext(), "byte.compare",
                                   Preheader->getParent());
  CmpBB->moveBefore(EndBB);

  // Replace the branch in the preheader with an always-true conditional branch.
  // This ensures there is still a reference to the original loop.
  Builder.CreateCondBr(Builder.getTrue(), CmpBB, Header);
  PHBranch->eraseFromParent();

  BasicBlock *MismatchEnd = cast<Instruction>(ByteCmpRes)->getParent();
  DTU.applyUpdates({{DominatorTree::Insert, MismatchEnd, CmpBB}});

  // Create the branch to either the end or found block depending on the value
  // returned by the intrinsic.
  Builder.SetInsertPoint(CmpBB);
  if (FoundBB != EndBB) {
    // Result == MaxLen means no mismatch was found.
    Value *FoundCmp = Builder.CreateICmpEQ(ByteCmpRes, MaxLen);
    Builder.CreateCondBr(FoundCmp, EndBB, FoundBB);
    DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB},
                      {DominatorTree::Insert, CmpBB, EndBB}});

  } else {
    Builder.CreateBr(FoundBB);
    DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB}});
  }

  auto fixSuccessorPhis = [&](BasicBlock *SuccBB) {
    for (PHINode &PN : SuccBB->phis()) {
      // At this point we've already replaced all uses of the result from the
      // loop with ByteCmp. Look through the incoming values to find ByteCmp,
      // meaning this is a Phi collecting the results of the byte compare.
      bool ResPhi = false;
      for (Value *Op : PN.incoming_values())
        if (Op == ByteCmpRes) {
          ResPhi = true;
          break;
        }

      // Any PHI that depended upon the result of the byte compare needs a new
      // incoming value from CmpBB. This is because the original loop will get
      // deleted.
      if (ResPhi)
        PN.addIncoming(ByteCmpRes, CmpBB);
      else {
        // There should be no other outside uses of other values in the
        // original loop. Any incoming values should either:
        //  1. Be for blocks outside the loop, which aren't interesting. Or ..
        //  2. These are from blocks in the loop with values defined outside
        //     the loop. We should add a similar incoming value from CmpBB.
        for (BasicBlock *BB : PN.blocks())
          if (CurLoop->contains(BB)) {
            PN.addIncoming(PN.getIncomingValueForBlock(BB), CmpBB);
            break;
          }
      }
    }
  };

  // Ensure all Phis in the successors of CmpBB have an incoming value from it.
  fixSuccessorPhis(EndBB);
  if (EndBB != FoundBB)
    fixSuccessorPhis(FoundBB);

  // The new CmpBB block isn't part of the loop, but will need to be added to
  // the outer loop if there is one.
  if (!CurLoop->isOutermost())
    CurLoop->getParentLoop()->addBasicBlockToLoop(CmpBB, *LI);

  if (VerifyLoops && CurLoop->getParentLoop()) {
    CurLoop->getParentLoop()->verifyLoop();
    if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
  }
}