//=== lib/CodeGen/GlobalISel/AArch64PreLegalizerCombiner.cpp --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass does combining of machine instructions at the generic MI level,
// before the legalizer.
//
//===----------------------------------------------------------------------===//

#include "AArch64GlobalISelUtils.h"
#include "AArch64TargetMachine.h"
#include "llvm/CodeGen/GlobalISel/CSEInfo.h"
#include "llvm/CodeGen/GlobalISel/Combiner.h"
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/Debug.h"

#define GET_GICOMBINER_DEPS
#include "AArch64GenPreLegalizeGICombiner.inc"
#undef GET_GICOMBINER_DEPS

#define DEBUG_TYPE "aarch64-prelegalizer-combiner"

using namespace llvm;
using namespace MIPatternMatch;

namespace {

#define GET_GICOMBINER_TYPES
#include "AArch64GenPreLegalizeGICombiner.inc"
#undef GET_GICOMBINER_TYPES

/// Return true if a G_FCONSTANT instruction is known to be better-represented
/// as a G_CONSTANT.
bool matchFConstantToConstant(MachineInstr &MI, MachineRegisterInfo &MRI) {
  assert(MI.getOpcode() == TargetOpcode::G_FCONSTANT);
  Register DstReg = MI.getOperand(0).getReg();
  const unsigned DstSize = MRI.getType(DstReg).getSizeInBits();
  if (DstSize != 32 && DstSize != 64)
    return false;

  // When we're storing a value, it doesn't matter what register bank it's on.
  // Since not all floating point constants can be materialized using a fmov,
  // it makes more sense to just use a GPR.
  return all_of(MRI.use_nodbg_instructions(DstReg),
                [](const MachineInstr &Use) { return Use.mayStore(); });
}

/// Change a G_FCONSTANT into a G_CONSTANT.
void applyFConstantToConstant(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_FCONSTANT);
  MachineIRBuilder MIB(MI);
  const APFloat &ImmValAPF = MI.getOperand(1).getFPImm()->getValueAPF();
  MIB.buildConstant(MI.getOperand(0).getReg(), ImmValAPF.bitcastToAPInt());
  MI.eraseFromParent();
}

/// Try to match a G_ICMP of a G_TRUNC with zero, in which the truncated bits
/// are sign bits. In this case, we can transform the G_ICMP to directly
/// compare the wide value with a zero.
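///
/// e.g. (rough gMIR sketch; register names and the s64/s32 split are
/// illustrative, and the known-sign-bits check below must succeed):
///
///   %wide:_(s64) = ...
///   %narrow:_(s32) = G_TRUNC %wide
///   %zero:_(s32) = G_CONSTANT i32 0
///   %cmp:_(s1) = G_ICMP intpred(eq), %narrow, %zero
///
/// becomes
///
///   %wzero:_(s64) = G_CONSTANT i64 0
///   %cmp:_(s1) = G_ICMP intpred(eq), %wide, %wzero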
bool matchICmpRedundantTrunc(MachineInstr &MI, MachineRegisterInfo &MRI,
                             GISelKnownBits *KB, Register &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_ICMP && KB);

  auto Pred = (CmpInst::Predicate)MI.getOperand(1).getPredicate();
  if (!ICmpInst::isEquality(Pred))
    return false;

  Register LHS = MI.getOperand(2).getReg();
  LLT LHSTy = MRI.getType(LHS);
  if (!LHSTy.isScalar())
    return false;

  Register RHS = MI.getOperand(3).getReg();
  Register WideReg;

  if (!mi_match(LHS, MRI, m_GTrunc(m_Reg(WideReg))) ||
      !mi_match(RHS, MRI, m_SpecificICst(0)))
    return false;

  LLT WideTy = MRI.getType(WideReg);
  if (KB->computeNumSignBits(WideReg) <=
      WideTy.getSizeInBits() - LHSTy.getSizeInBits())
    return false;

  MatchInfo = WideReg;
  return true;
}

void applyICmpRedundantTrunc(MachineInstr &MI, MachineRegisterInfo &MRI,
                             MachineIRBuilder &Builder,
                             GISelChangeObserver &Observer, Register &WideReg) {
  assert(MI.getOpcode() == TargetOpcode::G_ICMP);

  LLT WideTy = MRI.getType(WideReg);
  // We're going to directly use the wide register as the LHS, and then use an
  // equivalent size zero for RHS.
  Builder.setInstrAndDebugLoc(MI);
  auto WideZero = Builder.buildConstant(WideTy, 0);
  Observer.changingInstr(MI);
  MI.getOperand(2).setReg(WideReg);
  MI.getOperand(3).setReg(WideZero.getReg(0));
  Observer.changedInstr(MI);
}

/// \returns true if it is possible to fold a constant into a G_GLOBAL_VALUE.
///
/// e.g.
///
/// %g = G_GLOBAL_VALUE @x -> %g = G_GLOBAL_VALUE @x + cst
bool matchFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
                           std::pair<uint64_t, uint64_t> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
  MachineFunction &MF = *MI.getMF();
  auto &GlobalOp = MI.getOperand(1);
  auto *GV = GlobalOp.getGlobal();
  if (GV->isThreadLocal())
    return false;

  // Don't allow anything that could represent offsets etc.
  if (MF.getSubtarget<AArch64Subtarget>().ClassifyGlobalReference(
          GV, MF.getTarget()) != AArch64II::MO_NO_FLAG)
    return false;

  // Look for a G_GLOBAL_VALUE only used by G_PTR_ADDs against constants:
  //
  //   %g = G_GLOBAL_VALUE @x
  //   %ptr1 = G_PTR_ADD %g, cst1
  //   %ptr2 = G_PTR_ADD %g, cst2
  //   ...
  //   %ptrN = G_PTR_ADD %g, cstN
  //
  // Identify the *smallest* constant. We want to be able to form this:
  //
  //   %offset_g = G_GLOBAL_VALUE @x + min_cst
  //   %g = G_PTR_ADD %offset_g, -min_cst
  //   %ptr1 = G_PTR_ADD %g, cst1
  //   ...
  Register Dst = MI.getOperand(0).getReg();
  uint64_t MinOffset = -1ull;
  for (auto &UseInstr : MRI.use_nodbg_instructions(Dst)) {
    if (UseInstr.getOpcode() != TargetOpcode::G_PTR_ADD)
      return false;
    auto Cst = getIConstantVRegValWithLookThrough(
        UseInstr.getOperand(2).getReg(), MRI);
    if (!Cst)
      return false;
    MinOffset = std::min(MinOffset, Cst->Value.getZExtValue());
  }

  // Require that the new offset is larger than the existing one to avoid
  // infinite loops.
  uint64_t CurrOffset = GlobalOp.getOffset();
  uint64_t NewOffset = MinOffset + CurrOffset;
  if (NewOffset <= CurrOffset)
    return false;

  // Check whether folding this offset is legal. It must not go out of bounds
  // of the referenced object to avoid violating the code model, and must be
  // smaller than 2^20 because this is the largest offset expressible in all
  // object formats.
  // (The IMAGE_REL_ARM64_PAGEBASE_REL21 relocation in COFF
  // stores an immediate signed 21 bit offset.)
  //
  // This check also prevents us from folding negative offsets, which will end
  // up being treated in the same way as large positive ones. They could also
  // cause code model violations, and aren't really common enough to matter.
  if (NewOffset >= (1 << 20))
    return false;

  Type *T = GV->getValueType();
  if (!T->isSized() ||
      NewOffset > GV->getParent()->getDataLayout().getTypeAllocSize(T))
    return false;
  MatchInfo = std::make_pair(NewOffset, MinOffset);
  return true;
}

void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
                           MachineIRBuilder &B, GISelChangeObserver &Observer,
                           std::pair<uint64_t, uint64_t> &MatchInfo) {
  // Change:
  //
  //   %g = G_GLOBAL_VALUE @x
  //   %ptr1 = G_PTR_ADD %g, cst1
  //   %ptr2 = G_PTR_ADD %g, cst2
  //   ...
  //   %ptrN = G_PTR_ADD %g, cstN
  //
  // To:
  //
  //   %offset_g = G_GLOBAL_VALUE @x + min_cst
  //   %g = G_PTR_ADD %offset_g, -min_cst
  //   %ptr1 = G_PTR_ADD %g, cst1
  //   ...
  //   %ptrN = G_PTR_ADD %g, cstN
  //
  // Then, the original G_PTR_ADDs should be folded later on so that they look
  // like this:
  //
  //   %ptrN = G_PTR_ADD %offset_g, cstN - min_cst
  uint64_t Offset, MinOffset;
  std::tie(Offset, MinOffset) = MatchInfo;
  B.setInstrAndDebugLoc(*std::next(MI.getIterator()));
  Observer.changingInstr(MI);
  auto &GlobalOp = MI.getOperand(1);
  auto *GV = GlobalOp.getGlobal();
  GlobalOp.ChangeToGA(GV, Offset, GlobalOp.getTargetFlags());
  Register Dst = MI.getOperand(0).getReg();
  Register NewGVDst = MRI.cloneVirtualRegister(Dst);
  MI.getOperand(0).setReg(NewGVDst);
  Observer.changedInstr(MI);
  B.buildPtrAdd(
      Dst, NewGVDst,
      B.buildConstant(LLT::scalar(64), -static_cast<int64_t>(MinOffset)));
}

// Combines vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add(udot(x, y))
// Or vecreduce_add(ext(x)) -> vecreduce_add(udot(x, 1))
// Similar to performVecReduceAddCombine in SelectionDAG
bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
                            const AArch64Subtarget &STI,
                            std::tuple<Register, Register, bool> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
         "Expected a G_VECREDUCE_ADD instruction");
  assert(STI.hasDotProd() && "Target should have Dot Product feature");

  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
  Register DstReg = MI.getOperand(0).getReg();
  Register MidReg = I1->getOperand(0).getReg();
  LLT DstTy = MRI.getType(DstReg);
  LLT MidTy = MRI.getType(MidReg);
  if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32)
    return false;

  LLT SrcTy;
  auto I1Opc = I1->getOpcode();
  if (I1Opc == TargetOpcode::G_MUL) {
    // If result of this has more than 1 use, then there is no point in
    // creating udot instruction
    if (!MRI.hasOneNonDBGUse(MidReg))
      return false;

    MachineInstr *ExtMI1 =
        getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
    MachineInstr *ExtMI2 =
        getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
    LLT Ext1DstTy = MRI.getType(ExtMI1->getOperand(0).getReg());
    LLT Ext2DstTy = MRI.getType(ExtMI2->getOperand(0).getReg());

    if (ExtMI1->getOpcode() != ExtMI2->getOpcode() || Ext1DstTy != Ext2DstTy)
      return false;
    I1Opc = ExtMI1->getOpcode();
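    // Both extends match; record their pre-extension (narrow) sources, which
    // the dot product will consume directly.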
    SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
    std::get<0>(MatchInfo) = ExtMI1->getOperand(1).getReg();
    std::get<1>(MatchInfo) = ExtMI2->getOperand(1).getReg();
  } else {
    SrcTy = MRI.getType(I1->getOperand(1).getReg());
    std::get<0>(MatchInfo) = I1->getOperand(1).getReg();
    std::get<1>(MatchInfo) = 0;
  }

  if (I1Opc == TargetOpcode::G_ZEXT)
    std::get<2>(MatchInfo) = 0;
  else if (I1Opc == TargetOpcode::G_SEXT)
    std::get<2>(MatchInfo) = 1;
  else
    return false;

  if (SrcTy.getScalarSizeInBits() != 8 || SrcTy.getNumElements() % 8 != 0)
    return false;

  return true;
}

void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
                            MachineIRBuilder &Builder,
                            GISelChangeObserver &Observer,
                            const AArch64Subtarget &STI,
                            std::tuple<Register, Register, bool> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
         "Expected a G_VECREDUCE_ADD instruction");
  assert(STI.hasDotProd() && "Target should have Dot Product feature");

  // Initialise the variables
  unsigned DotOpcode =
      std::get<2>(MatchInfo) ? AArch64::G_SDOT : AArch64::G_UDOT;
  Register Ext1SrcReg = std::get<0>(MatchInfo);

  // If there is only one source register, create a vector of 1s as the second
  // source register, so that the dot product reduces to a plain sum.
  Register Ext2SrcReg;
  if (std::get<1>(MatchInfo) == 0)
    Ext2SrcReg = Builder.buildConstant(MRI.getType(Ext1SrcReg), 1)
                     ->getOperand(0)
                     .getReg();
  else
    Ext2SrcReg = std::get<1>(MatchInfo);

  // Find out how many DOT instructions are needed
  LLT SrcTy = MRI.getType(Ext1SrcReg);
  LLT MidTy;
  unsigned NumOfDotMI;
  if (SrcTy.getNumElements() % 16 == 0) {
    NumOfDotMI = SrcTy.getNumElements() / 16;
    MidTy = LLT::fixed_vector(4, 32);
  } else if (SrcTy.getNumElements() % 8 == 0) {
    NumOfDotMI = SrcTy.getNumElements() / 8;
    MidTy = LLT::fixed_vector(2, 32);
  } else {
    llvm_unreachable("Source type number of elements is not multiple of 8");
  }

  // Handle case where one DOT instruction is needed
  if (NumOfDotMI == 1) {
    auto Zeroes = Builder.buildConstant(MidTy, 0)->getOperand(0).getReg();
    auto Dot = Builder.buildInstr(DotOpcode, {MidTy},
                                  {Zeroes, Ext1SrcReg, Ext2SrcReg});
    Builder.buildVecReduceAdd(MI.getOperand(0), Dot->getOperand(0));
  } else {
    // If not, pad the last v8 element with 0s to a v16
    SmallVector<Register, 4> Ext1UnmergeReg;
    SmallVector<Register, 4> Ext2UnmergeReg;
    if (SrcTy.getNumElements() % 16 != 0) {
      // Unmerge source to v8i8, append a new v8i8 of 0s and then merge to v16s
      SmallVector<Register, 4> PadUnmergeDstReg1;
      SmallVector<Register, 4> PadUnmergeDstReg2;
      unsigned NumOfVec = SrcTy.getNumElements() / 8;

      // Unmerge the source to v8i8
      MachineInstr *PadUnmerge1 =
          Builder.buildUnmerge(LLT::fixed_vector(8, 8), Ext1SrcReg);
      MachineInstr *PadUnmerge2 =
          Builder.buildUnmerge(LLT::fixed_vector(8, 8), Ext2SrcReg);
      for (unsigned i = 0; i < NumOfVec; i++) {
        PadUnmergeDstReg1.push_back(PadUnmerge1->getOperand(i).getReg());
        PadUnmergeDstReg2.push_back(PadUnmerge2->getOperand(i).getReg());
      }

      // Pad the vectors with a v8i8 constant of 0s
      MachineInstr *v8Zeroes =
          Builder.buildConstant(LLT::fixed_vector(8, 8), 0);
      PadUnmergeDstReg1.push_back(v8Zeroes->getOperand(0).getReg());
      PadUnmergeDstReg2.push_back(v8Zeroes->getOperand(0).getReg());

      // Merge them all back to v16i8
      NumOfVec = (NumOfVec + 1) / 2;
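      // Each adjacent pair of v8i8 pieces (the odd one out paired with the
      // zero padding) is concatenated into one v16i8 dot-product input.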
      for (unsigned i = 0; i < NumOfVec; i++) {
        Ext1UnmergeReg.push_back(
            Builder
                .buildMergeLikeInstr(
                    LLT::fixed_vector(16, 8),
                    {PadUnmergeDstReg1[i * 2], PadUnmergeDstReg1[(i * 2) + 1]})
                .getReg(0));
        Ext2UnmergeReg.push_back(
            Builder
                .buildMergeLikeInstr(
                    LLT::fixed_vector(16, 8),
                    {PadUnmergeDstReg2[i * 2], PadUnmergeDstReg2[(i * 2) + 1]})
                .getReg(0));
      }
    } else {
      // Unmerge the source vectors to v16i8
      MachineInstr *Ext1Unmerge =
          Builder.buildUnmerge(LLT::fixed_vector(16, 8), Ext1SrcReg);
      MachineInstr *Ext2Unmerge =
          Builder.buildUnmerge(LLT::fixed_vector(16, 8), Ext2SrcReg);
      for (unsigned i = 0, e = SrcTy.getNumElements() / 16; i < e; i++) {
        Ext1UnmergeReg.push_back(Ext1Unmerge->getOperand(i).getReg());
        Ext2UnmergeReg.push_back(Ext2Unmerge->getOperand(i).getReg());
      }
    }

    // Build the UDOT instructions
    SmallVector<Register, 2> DotReg;
    unsigned NumElements = 0;
    for (unsigned i = 0; i < Ext1UnmergeReg.size(); i++) {
      LLT ZeroesLLT;
      // Check if it is 16 or 8 elements. Set Zeroes to the corresponding size.
      if (MRI.getType(Ext1UnmergeReg[i]).getNumElements() == 16) {
        ZeroesLLT = LLT::fixed_vector(4, 32);
        NumElements += 4;
      } else {
        ZeroesLLT = LLT::fixed_vector(2, 32);
        NumElements += 2;
      }
      auto Zeroes = Builder.buildConstant(ZeroesLLT, 0)->getOperand(0).getReg();
      DotReg.push_back(
          Builder
              .buildInstr(DotOpcode, {MRI.getType(Zeroes)},
                          {Zeroes, Ext1UnmergeReg[i], Ext2UnmergeReg[i]})
              .getReg(0));
    }

    // Merge the output
    auto ConcatMI =
        Builder.buildConcatVectors(LLT::fixed_vector(NumElements, 32), DotReg);

    // Put it through a vector reduction
    Builder.buildVecReduceAdd(MI.getOperand(0).getReg(),
                              ConcatMI->getOperand(0).getReg());
  }

  // Erase the dead instructions
  MI.eraseFromParent();
}

bool tryToSimplifyUADDO(MachineInstr &MI, MachineIRBuilder &B,
                        CombinerHelper &Helper, GISelChangeObserver &Observer) {
  // Try to simplify G_UADDO with 8 or 16 bit operands to wide G_ADD and TBNZ
  // if the result is only used in the no-overflow case. It is restricted to
  // cases where we know that the high bits of the operands are 0. If there's
  // an overflow, then the 9th or 17th bit must be set, which can be checked
  // using TBNZ.
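  // (For 8-bit operands, the sum of two zero-extended 8-bit values fits in 9
  // bits, so bit 8 of the wide add is exactly the carry-out; likewise bit 16
  // for 16-bit operands.)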
  //
  // Change (for UADDOs on 8 and 16 bits):
  //
  //   %z0 = G_ASSERT_ZEXT _
  //   %op0 = G_TRUNC %z0
  //   %z1 = G_ASSERT_ZEXT _
  //   %op1 = G_TRUNC %z1
  //   %val, %cond = G_UADDO %op0, %op1
  //   G_BRCOND %cond, %error.bb
  //
  // error.bb:
  //   (no successors and no uses of %val)
  //
  // To:
  //
  //   %z0 = G_ASSERT_ZEXT _
  //   %z1 = G_ASSERT_ZEXT _
  //   %add = G_ADD %z0, %z1
  //   %val = G_TRUNC %add
  //   %bit = G_AND %add, 1 << scalar-size-in-bits(%op1)
  //   %cond = G_ICMP NE, %bit, 0
  //   G_BRCOND %cond, %error.bb

  auto &MRI = *B.getMRI();

  MachineOperand *DefOp0 = MRI.getOneDef(MI.getOperand(2).getReg());
  MachineOperand *DefOp1 = MRI.getOneDef(MI.getOperand(3).getReg());
  Register Op0Wide;
  Register Op1Wide;
  if (!mi_match(DefOp0->getParent(), MRI, m_GTrunc(m_Reg(Op0Wide))) ||
      !mi_match(DefOp1->getParent(), MRI, m_GTrunc(m_Reg(Op1Wide))))
    return false;
  LLT WideTy0 = MRI.getType(Op0Wide);
  LLT WideTy1 = MRI.getType(Op1Wide);
  Register ResVal = MI.getOperand(0).getReg();
  LLT OpTy = MRI.getType(ResVal);
  MachineInstr *Op0WideDef = MRI.getVRegDef(Op0Wide);
  MachineInstr *Op1WideDef = MRI.getVRegDef(Op1Wide);

  unsigned OpTySize = OpTy.getScalarSizeInBits();
  // First check that the G_TRUNC feeding the G_UADDO are no-ops, because the
  // inputs have been zero-extended.
  if (Op0WideDef->getOpcode() != TargetOpcode::G_ASSERT_ZEXT ||
      Op1WideDef->getOpcode() != TargetOpcode::G_ASSERT_ZEXT ||
      OpTySize != Op0WideDef->getOperand(2).getImm() ||
      OpTySize != Op1WideDef->getOperand(2).getImm())
    return false;

  // Only scalar UADDO with either 8 or 16 bit operands are handled.
  if (!WideTy0.isScalar() || !WideTy1.isScalar() || WideTy0 != WideTy1 ||
      OpTySize >= WideTy0.getScalarSizeInBits() ||
      (OpTySize != 8 && OpTySize != 16))
    return false;

  // The overflow-status result must be used by a branch only.
  Register ResStatus = MI.getOperand(1).getReg();
  if (!MRI.hasOneNonDBGUse(ResStatus))
    return false;
  MachineInstr *CondUser = &*MRI.use_instr_nodbg_begin(ResStatus);
  if (CondUser->getOpcode() != TargetOpcode::G_BRCOND)
    return false;

  // Make sure the computed result is only used in the no-overflow blocks.
  MachineBasicBlock *CurrentMBB = MI.getParent();
  MachineBasicBlock *FailMBB = CondUser->getOperand(1).getMBB();
  if (!FailMBB->succ_empty() || CondUser->getParent() != CurrentMBB)
    return false;
  if (any_of(MRI.use_nodbg_instructions(ResVal),
             [&MI, FailMBB, CurrentMBB](MachineInstr &I) {
               return &MI != &I &&
                      (I.getParent() == FailMBB || I.getParent() == CurrentMBB);
             }))
    return false;

  // Remove the G_UADDO.
  B.setInstrAndDebugLoc(*MI.getNextNode());
  MI.eraseFromParent();

  // Emit wide add.
  Register AddDst = MRI.cloneVirtualRegister(Op0Wide);
  B.buildInstr(TargetOpcode::G_ADD, {AddDst}, {Op0Wide, Op1Wide});

  // Emit check of the 9th or 17th bit and update users (the branch). This will
  // later be folded to TBNZ.
  Register CondBit = MRI.cloneVirtualRegister(Op0Wide);
  B.buildAnd(
      CondBit, AddDst,
      B.buildConstant(LLT::scalar(32), OpTySize == 8 ? 1 << 8 : 1 << 16));
  B.buildICmp(CmpInst::ICMP_NE, ResStatus, CondBit,
              B.buildConstant(LLT::scalar(32), 0));

  // Update ZEXts users of the result value. Because all uses are in the
  // no-overflow case, we know that the top bits are 0 and we can ignore ZExts.
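  // Re-derive the narrow result from the wide add, then point any G_ZEXT
  // users of it straight at the wide result, since its high bits are zero.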
  B.buildZExtOrTrunc(ResVal, AddDst);
  for (MachineOperand &U : make_early_inc_range(MRI.use_operands(ResVal))) {
    Register WideReg;
    if (mi_match(U.getParent(), MRI, m_GZExt(m_Reg(WideReg)))) {
      auto OldR = U.getParent()->getOperand(0).getReg();
      Observer.erasingInstr(*U.getParent());
      U.getParent()->eraseFromParent();
      Helper.replaceRegWith(MRI, OldR, AddDst);
    }
  }

  return true;
}

class AArch64PreLegalizerCombinerImpl : public Combiner {
protected:
  // TODO: Make CombinerHelper methods const.
  mutable CombinerHelper Helper;
  const AArch64PreLegalizerCombinerImplRuleConfig &RuleConfig;
  const AArch64Subtarget &STI;

public:
  AArch64PreLegalizerCombinerImpl(
      MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
      GISelKnownBits &KB, GISelCSEInfo *CSEInfo,
      const AArch64PreLegalizerCombinerImplRuleConfig &RuleConfig,
      const AArch64Subtarget &STI, MachineDominatorTree *MDT,
      const LegalizerInfo *LI);

  static const char *getName() { return "AArch64PreLegalizerCombiner"; }

  bool tryCombineAll(MachineInstr &I) const override;

  bool tryCombineAllImpl(MachineInstr &I) const;

private:
#define GET_GICOMBINER_CLASS_MEMBERS
#include "AArch64GenPreLegalizeGICombiner.inc"
#undef GET_GICOMBINER_CLASS_MEMBERS
};

#define GET_GICOMBINER_IMPL
#include "AArch64GenPreLegalizeGICombiner.inc"
#undef GET_GICOMBINER_IMPL

AArch64PreLegalizerCombinerImpl::AArch64PreLegalizerCombinerImpl(
    MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
    GISelKnownBits &KB, GISelCSEInfo *CSEInfo,
    const AArch64PreLegalizerCombinerImplRuleConfig &RuleConfig,
    const AArch64Subtarget &STI, MachineDominatorTree *MDT,
    const LegalizerInfo *LI)
    : Combiner(MF, CInfo, TPC, &KB, CSEInfo),
      Helper(Observer, B, /*IsPreLegalize*/ true, &KB, MDT, LI),
      RuleConfig(RuleConfig), STI(STI),
#define GET_GICOMBINER_CONSTRUCTOR_INITS
#include "AArch64GenPreLegalizeGICombiner.inc"
#undef GET_GICOMBINER_CONSTRUCTOR_INITS
{
}

bool AArch64PreLegalizerCombinerImpl::tryCombineAll(MachineInstr &MI) const {
  if (tryCombineAllImpl(MI))
    return true;

  unsigned Opc = MI.getOpcode();
  switch (Opc) {
  case TargetOpcode::G_CONCAT_VECTORS:
    return Helper.tryCombineConcatVectors(MI);
  case TargetOpcode::G_SHUFFLE_VECTOR:
    return Helper.tryCombineShuffleVector(MI);
  case TargetOpcode::G_UADDO:
    return tryToSimplifyUADDO(MI, B, Helper, Observer);
  case TargetOpcode::G_MEMCPY_INLINE:
    return Helper.tryEmitMemcpyInline(MI);
  case TargetOpcode::G_MEMCPY:
  case TargetOpcode::G_MEMMOVE:
  case TargetOpcode::G_MEMSET: {
    // If we're at -O0 set a maxlen of 32 to inline, otherwise let the other
    // heuristics decide.
    unsigned MaxLen = CInfo.EnableOpt ? 0 : 32;
    // Try to inline memcpy type calls if optimizations are enabled.
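    // (A MaxLen of 0 places no cap on the inlined size; the helper's own
    // profitability heuristics still apply.)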
    if (Helper.tryCombineMemCpyFamily(MI, MaxLen))
      return true;
    if (Opc == TargetOpcode::G_MEMSET)
      return llvm::AArch64GISelUtils::tryEmitBZero(MI, B, CInfo.EnableMinSize);
    return false;
  }
  }

  return false;
}

// Pass boilerplate
// ================

class AArch64PreLegalizerCombiner : public MachineFunctionPass {
public:
  static char ID;

  AArch64PreLegalizerCombiner();

  StringRef getPassName() const override {
    return "AArch64PreLegalizerCombiner";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  AArch64PreLegalizerCombinerImplRuleConfig RuleConfig;
};
} // end anonymous namespace

void AArch64PreLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.setPreservesCFG();
  getSelectionDAGFallbackAnalysisUsage(AU);
  AU.addRequired<GISelKnownBitsAnalysis>();
  AU.addPreserved<GISelKnownBitsAnalysis>();
  AU.addRequired<MachineDominatorTree>();
  AU.addPreserved<MachineDominatorTree>();
  AU.addRequired<GISelCSEAnalysisWrapperPass>();
  AU.addPreserved<GISelCSEAnalysisWrapperPass>();
  MachineFunctionPass::getAnalysisUsage(AU);
}

AArch64PreLegalizerCombiner::AArch64PreLegalizerCombiner()
    : MachineFunctionPass(ID) {
  initializeAArch64PreLegalizerCombinerPass(*PassRegistry::getPassRegistry());

  if (!RuleConfig.parseCommandLineOption())
    report_fatal_error("Invalid rule identifier");
}

bool AArch64PreLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
  if (MF.getProperties().hasProperty(
          MachineFunctionProperties::Property::FailedISel))
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();

  // Enable CSE.
  GISelCSEAnalysisWrapper &Wrapper =
      getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
  auto *CSEInfo = &Wrapper.get(TPC.getCSEConfig());

  const AArch64Subtarget &ST = MF.getSubtarget<AArch64Subtarget>();
  const auto *LI = ST.getLegalizerInfo();

  const Function &F = MF.getFunction();
  bool EnableOpt =
      MF.getTarget().getOptLevel() != CodeGenOptLevel::None && !skipFunction(F);
  GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
  MachineDominatorTree *MDT = &getAnalysis<MachineDominatorTree>();
  CombinerInfo CInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
                     /*LegalizerInfo*/ nullptr, EnableOpt, F.hasOptSize(),
                     F.hasMinSize());
  AArch64PreLegalizerCombinerImpl Impl(MF, CInfo, &TPC, *KB, CSEInfo,
                                       RuleConfig, ST, MDT, LI);
  return Impl.combineMachineInstrs();
}

char AArch64PreLegalizerCombiner::ID = 0;
INITIALIZE_PASS_BEGIN(AArch64PreLegalizerCombiner, DEBUG_TYPE,
                      "Combine AArch64 machine instrs before legalization",
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis)
INITIALIZE_PASS_DEPENDENCY(GISelCSEAnalysisWrapperPass)
INITIALIZE_PASS_END(AArch64PreLegalizerCombiner, DEBUG_TYPE,
                    "Combine AArch64 machine instrs before legalization", false,
                    false)

namespace llvm {
FunctionPass *createAArch64PreLegalizerCombiner() {
  return new AArch64PreLegalizerCombiner();
}
} // end namespace llvm