1 //===- AggressiveInstCombine.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the aggressive expression pattern combiner classes.
10 // Currently, it handles expression patterns for:
11 //  * Truncate instruction
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
16 #include "AggressiveInstCombineInternal.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/Analysis/AliasAnalysis.h"
19 #include "llvm/Analysis/AssumptionCache.h"
20 #include "llvm/Analysis/BasicAliasAnalysis.h"
21 #include "llvm/Analysis/ConstantFolding.h"
22 #include "llvm/Analysis/GlobalsModRef.h"
23 #include "llvm/Analysis/TargetLibraryInfo.h"
24 #include "llvm/Analysis/TargetTransformInfo.h"
25 #include "llvm/Analysis/ValueTracking.h"
26 #include "llvm/IR/DataLayout.h"
27 #include "llvm/IR/Dominators.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/IRBuilder.h"
30 #include "llvm/IR/PatternMatch.h"
31 #include "llvm/Transforms/Utils/BuildLibCalls.h"
32 #include "llvm/Transforms/Utils/Local.h"
33 
34 using namespace llvm;
35 using namespace PatternMatch;
36 
37 #define DEBUG_TYPE "aggressive-instcombine"
38 
39 STATISTIC(NumAnyOrAllBitsSet, "Number of any/all-bits-set patterns folded");
40 STATISTIC(NumGuardedRotates,
41           "Number of guarded rotates transformed into funnel shifts");
42 STATISTIC(NumGuardedFunnelShifts,
43           "Number of guarded funnel shifts transformed into funnel shifts");
44 STATISTIC(NumPopCountRecognized, "Number of popcount idioms recognized");
45 
46 static cl::opt<unsigned> MaxInstrsToScan(
47     "aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden,
48     cl::desc("Max number of instructions to scan for aggressive instcombine."));
49 
50 /// Match a pattern for a bitwise funnel/rotate operation that partially guards
51 /// against undefined behavior by branching around the funnel-shift/rotation
52 /// when the shift amount is 0.
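///
/// A typical source pattern (illustrative C; the guard avoids an undefined
/// shift by the full bit width):
///
///   r = shamt ? (x << shamt) | (x >> (32 - shamt)) : x;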
53 static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
54   if (I.getOpcode() != Instruction::PHI || I.getNumOperands() != 2)
55     return false;
56 
57   // As with the one-use checks below, this is not strictly necessary, but we
58   // are being cautious to avoid potential perf regressions on targets that
59   // do not actually have a funnel/rotate instruction (where the funnel shift
60   // would be expanded back into math/shift/logic ops).
61   if (!isPowerOf2_32(I.getType()->getScalarSizeInBits()))
62     return false;
63 
64   // Match V to funnel shift left/right and capture the source operands and
65   // shift amount.
66   auto matchFunnelShift = [](Value *V, Value *&ShVal0, Value *&ShVal1,
67                              Value *&ShAmt) {
68     unsigned Width = V->getType()->getScalarSizeInBits();
69 
70     // fshl(ShVal0, ShVal1, ShAmt)
    //  == (ShVal0 << ShAmt) | (ShVal1 >> (Width - ShAmt))
    if (match(V, m_OneUse(m_c_Or(
                     m_Shl(m_Value(ShVal0), m_Value(ShAmt)),
                     m_LShr(m_Value(ShVal1),
                            m_Sub(m_SpecificInt(Width), m_Deferred(ShAmt)))))))
      return Intrinsic::fshl;
78 
    // fshr(ShVal0, ShVal1, ShAmt)
    //  == (ShVal0 << (Width - ShAmt)) | (ShVal1 >> ShAmt)
    if (match(V,
              m_OneUse(m_c_Or(m_Shl(m_Value(ShVal0), m_Sub(m_SpecificInt(Width),
                                                           m_Value(ShAmt))),
                              m_LShr(m_Value(ShVal1), m_Deferred(ShAmt))))))
      return Intrinsic::fshr;
87 
88     return Intrinsic::not_intrinsic;
89   };
90 
91   // One phi operand must be a funnel/rotate operation, and the other phi
92   // operand must be the source value of that funnel/rotate operation:
93   // phi [ rotate(RotSrc, ShAmt), FunnelBB ], [ RotSrc, GuardBB ]
94   // phi [ fshl(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal0, GuardBB ]
95   // phi [ fshr(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal1, GuardBB ]
96   PHINode &Phi = cast<PHINode>(I);
97   unsigned FunnelOp = 0, GuardOp = 1;
98   Value *P0 = Phi.getOperand(0), *P1 = Phi.getOperand(1);
99   Value *ShVal0, *ShVal1, *ShAmt;
100   Intrinsic::ID IID = matchFunnelShift(P0, ShVal0, ShVal1, ShAmt);
101   if (IID == Intrinsic::not_intrinsic ||
102       (IID == Intrinsic::fshl && ShVal0 != P1) ||
103       (IID == Intrinsic::fshr && ShVal1 != P1)) {
104     IID = matchFunnelShift(P1, ShVal0, ShVal1, ShAmt);
105     if (IID == Intrinsic::not_intrinsic ||
106         (IID == Intrinsic::fshl && ShVal0 != P0) ||
107         (IID == Intrinsic::fshr && ShVal1 != P0))
108       return false;
109     assert((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
110            "Pattern must match funnel shift left or right");
111     std::swap(FunnelOp, GuardOp);
112   }
113 
114   // The incoming block with our source operand must be the "guard" block.
115   // That must contain a cmp+branch to avoid the funnel/rotate when the shift
116   // amount is equal to 0. The other incoming block is the block with the
117   // funnel/rotate.
118   BasicBlock *GuardBB = Phi.getIncomingBlock(GuardOp);
119   BasicBlock *FunnelBB = Phi.getIncomingBlock(FunnelOp);
120   Instruction *TermI = GuardBB->getTerminator();
121 
  // Ensure that both shift input values dominate the guard block's
  // terminator, so they are available to the rewritten code.
123   if (!DT.dominates(ShVal0, TermI) || !DT.dominates(ShVal1, TermI))
124     return false;
125 
126   ICmpInst::Predicate Pred;
127   BasicBlock *PhiBB = Phi.getParent();
128   if (!match(TermI, m_Br(m_ICmp(Pred, m_Specific(ShAmt), m_ZeroInt()),
129                          m_SpecificBB(PhiBB), m_SpecificBB(FunnelBB))))
130     return false;
131 
132   if (Pred != CmpInst::ICMP_EQ)
133     return false;
134 
135   IRBuilder<> Builder(PhiBB, PhiBB->getFirstInsertionPt());
136 
137   if (ShVal0 == ShVal1)
138     ++NumGuardedRotates;
139   else
140     ++NumGuardedFunnelShifts;
141 
  // If this is not a rotate then the guarding branch was blocking poison from
  // the 'shift-by-zero' operand, but a funnel shift won't - so freeze it.
144   bool IsFshl = IID == Intrinsic::fshl;
145   if (ShVal0 != ShVal1) {
146     if (IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal1))
147       ShVal1 = Builder.CreateFreeze(ShVal1);
148     else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal0))
149       ShVal0 = Builder.CreateFreeze(ShVal0);
150   }
151 
152   // We matched a variation of this IR pattern:
153   // GuardBB:
154   //   %cmp = icmp eq i32 %ShAmt, 0
155   //   br i1 %cmp, label %PhiBB, label %FunnelBB
156   // FunnelBB:
157   //   %sub = sub i32 32, %ShAmt
158   //   %shr = lshr i32 %ShVal1, %sub
159   //   %shl = shl i32 %ShVal0, %ShAmt
160   //   %fsh = or i32 %shr, %shl
161   //   br label %PhiBB
162   // PhiBB:
163   //   %cond = phi i32 [ %fsh, %FunnelBB ], [ %ShVal0, %GuardBB ]
164   // -->
165   // llvm.fshl.i32(i32 %ShVal0, i32 %ShVal1, i32 %ShAmt)
166   Function *F = Intrinsic::getDeclaration(Phi.getModule(), IID, Phi.getType());
167   Phi.replaceAllUsesWith(Builder.CreateCall(F, {ShVal0, ShVal1, ShAmt}));
168   return true;
169 }
170 
171 /// This is used by foldAnyOrAllBitsSet() to capture a source value (Root) and
172 /// the bit indexes (Mask) needed by a masked compare. If we're matching a chain
173 /// of 'and' ops, then we also need to capture the fact that we saw an
174 /// "and X, 1", so that's an extra return value for that case.
175 struct MaskOps {
176   Value *Root = nullptr;
177   APInt Mask;
178   bool MatchAndChain;
179   bool FoundAnd1 = false;
180 
181   MaskOps(unsigned BitWidth, bool MatchAnds)
182       : Mask(APInt::getZero(BitWidth)), MatchAndChain(MatchAnds) {}
183 };
184 
185 /// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a
186 /// chain of 'and' or 'or' instructions looking for shift ops of a common source
187 /// value. Examples:
188 ///   or (or (or X, (X >> 3)), (X >> 5)), (X >> 8)
189 /// returns { X, 0x129 }
190 ///   and (and (X >> 1), 1), (X >> 4)
191 /// returns { X, 0x12 }
192 static bool matchAndOrChain(Value *V, MaskOps &MOps) {
193   Value *Op0, *Op1;
194   if (MOps.MatchAndChain) {
195     // Recurse through a chain of 'and' operands. This requires an extra check
196     // vs. the 'or' matcher: we must find an "and X, 1" instruction somewhere
197     // in the chain to know that all of the high bits are cleared.
198     if (match(V, m_And(m_Value(Op0), m_One()))) {
199       MOps.FoundAnd1 = true;
200       return matchAndOrChain(Op0, MOps);
201     }
202     if (match(V, m_And(m_Value(Op0), m_Value(Op1))))
203       return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
204   } else {
205     // Recurse through a chain of 'or' operands.
206     if (match(V, m_Or(m_Value(Op0), m_Value(Op1))))
207       return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
208   }
209 
210   // We need a shift-right or a bare value representing a compare of bit 0 of
211   // the original source operand.
212   Value *Candidate;
213   const APInt *BitIndex = nullptr;
214   if (!match(V, m_LShr(m_Value(Candidate), m_APInt(BitIndex))))
215     Candidate = V;
216 
217   // Initialize result source operand.
218   if (!MOps.Root)
219     MOps.Root = Candidate;
220 
  // If the shift constant is out of range, this code hasn't been simplified;
  // bail out.
222   if (BitIndex && BitIndex->uge(MOps.Mask.getBitWidth()))
223     return false;
224 
225   // Fill in the mask bit derived from the shift constant.
226   MOps.Mask.setBit(BitIndex ? BitIndex->getZExtValue() : 0);
227   return MOps.Root == Candidate;
228 }
229 
230 /// Match patterns that correspond to "any-bits-set" and "all-bits-set".
231 /// These will include a chain of 'or' or 'and'-shifted bits from a
232 /// common source value:
233 /// and (or  (lshr X, C), ...), 1 --> (X & CMask) != 0
234 /// and (and (lshr X, C), ...), 1 --> (X & CMask) == CMask
235 /// Note: "any-bits-clear" and "all-bits-clear" are variations of these patterns
236 /// that differ only with a final 'not' of the result. We expect that final
237 /// 'not' to be folded with the compare that we create here (invert predicate).
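///
/// For example (illustrative IR), an 'or' chain such as:
///   %s = lshr i32 %x, 3
///   %o = or i32 %x, %s
///   %r = and i32 %o, 1
/// becomes:
///   %m = and i32 %x, 9
///   %c = icmp ne i32 %m, 0
///   %r = zext i1 %c to i32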
238 static bool foldAnyOrAllBitsSet(Instruction &I) {
  // The 'any-bits-set' ('or' chain) pattern is simpler to match because the
  // "and X, 1" instruction must be the final op in the sequence.
241   bool MatchAllBitsSet;
242   if (match(&I, m_c_And(m_OneUse(m_And(m_Value(), m_Value())), m_Value())))
243     MatchAllBitsSet = true;
244   else if (match(&I, m_And(m_OneUse(m_Or(m_Value(), m_Value())), m_One())))
245     MatchAllBitsSet = false;
246   else
247     return false;
248 
249   MaskOps MOps(I.getType()->getScalarSizeInBits(), MatchAllBitsSet);
250   if (MatchAllBitsSet) {
251     if (!matchAndOrChain(cast<BinaryOperator>(&I), MOps) || !MOps.FoundAnd1)
252       return false;
253   } else {
254     if (!matchAndOrChain(cast<BinaryOperator>(&I)->getOperand(0), MOps))
255       return false;
256   }
257 
258   // The pattern was found. Create a masked compare that replaces all of the
259   // shift and logic ops.
260   IRBuilder<> Builder(&I);
261   Constant *Mask = ConstantInt::get(I.getType(), MOps.Mask);
262   Value *And = Builder.CreateAnd(MOps.Root, Mask);
263   Value *Cmp = MatchAllBitsSet ? Builder.CreateICmpEQ(And, Mask)
264                                : Builder.CreateIsNotNull(And);
265   Value *Zext = Builder.CreateZExt(Cmp, I.getType());
266   I.replaceAllUsesWith(Zext);
267   ++NumAnyOrAllBitsSet;
268   return true;
269 }
270 
// Try to recognize the function below as a popcount intrinsic.
272 // This is the "best" algorithm from
273 // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
274 // Also used in TargetLowering::expandCTPOP().
275 //
276 // int popcount(unsigned int i) {
277 //   i = i - ((i >> 1) & 0x55555555);
278 //   i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
279 //   i = ((i + (i >> 4)) & 0x0F0F0F0F);
280 //   return (i * 0x01010101) >> 24;
281 // }
282 static bool tryToRecognizePopCount(Instruction &I) {
283   if (I.getOpcode() != Instruction::LShr)
284     return false;
285 
286   Type *Ty = I.getType();
287   if (!Ty->isIntOrIntVectorTy())
288     return false;
289 
290   unsigned Len = Ty->getScalarSizeInBits();
291   // FIXME: fix Len == 8 and other irregular type lengths.
292   if (!(Len <= 128 && Len > 8 && Len % 8 == 0))
293     return false;
294 
295   APInt Mask55 = APInt::getSplat(Len, APInt(8, 0x55));
296   APInt Mask33 = APInt::getSplat(Len, APInt(8, 0x33));
297   APInt Mask0F = APInt::getSplat(Len, APInt(8, 0x0F));
298   APInt Mask01 = APInt::getSplat(Len, APInt(8, 0x01));
299   APInt MaskShift = APInt(Len, Len - 8);
300 
301   Value *Op0 = I.getOperand(0);
302   Value *Op1 = I.getOperand(1);
303   Value *MulOp0;
304   // Matching "(i * 0x01010101...) >> 24".
305   if ((match(Op0, m_Mul(m_Value(MulOp0), m_SpecificInt(Mask01)))) &&
306       match(Op1, m_SpecificInt(MaskShift))) {
307     Value *ShiftOp0;
308     // Matching "((i + (i >> 4)) & 0x0F0F0F0F...)".
309     if (match(MulOp0, m_And(m_c_Add(m_LShr(m_Value(ShiftOp0), m_SpecificInt(4)),
310                                     m_Deferred(ShiftOp0)),
311                             m_SpecificInt(Mask0F)))) {
312       Value *AndOp0;
313       // Matching "(i & 0x33333333...) + ((i >> 2) & 0x33333333...)".
314       if (match(ShiftOp0,
315                 m_c_Add(m_And(m_Value(AndOp0), m_SpecificInt(Mask33)),
316                         m_And(m_LShr(m_Deferred(AndOp0), m_SpecificInt(2)),
317                               m_SpecificInt(Mask33))))) {
318         Value *Root, *SubOp1;
319         // Matching "i - ((i >> 1) & 0x55555555...)".
320         if (match(AndOp0, m_Sub(m_Value(Root), m_Value(SubOp1))) &&
321             match(SubOp1, m_And(m_LShr(m_Specific(Root), m_SpecificInt(1)),
322                                 m_SpecificInt(Mask55)))) {
323           LLVM_DEBUG(dbgs() << "Recognized popcount intrinsic\n");
324           IRBuilder<> Builder(&I);
325           Function *Func = Intrinsic::getDeclaration(
326               I.getModule(), Intrinsic::ctpop, I.getType());
327           I.replaceAllUsesWith(Builder.CreateCall(Func, {Root}));
328           ++NumPopCountRecognized;
329           return true;
330         }
331       }
332     }
333   }
334 
335   return false;
336 }
337 
/// Fold smin(smax(fptosi(x), C1), C2) to llvm.fptosi.sat(x), providing C1 and
/// C2 saturate the value of the fp conversion. The transform is not reversible
/// as the fptosi.sat is more defined than the input - all values produce a
/// valid value for the fptosi.sat, whereas some inputs that were out of range
/// of the integer conversion produce poison in the original. The reversed
/// pattern may use fmax and fmin instead. As we cannot directly reverse the
/// transform, and it is not always profitable, we make it conditional on the
/// cost being reported as lower by TTI.
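///
/// For example (illustrative IR, saturating to i8 and widening back to i32):
///   %conv = fptosi float %x to i32
///   %min  = call i32 @llvm.smin.i32(i32 %conv, i32 127)
///   %max  = call i32 @llvm.smax.i32(i32 %min, i32 -128)
/// -->
///   %sat  = call i8 @llvm.fptosi.sat.i8.f32(float %x)
///   %max  = sext i8 %sat to i32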
346 static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) {
  // Look for smin(smax(fptosi(x))) or smax(smin(fptosi(x))) clamps that can
  // be converted to fptosi_sat.
348   Value *In;
349   const APInt *MinC, *MaxC;
350   if (!match(&I, m_SMax(m_OneUse(m_SMin(m_OneUse(m_FPToSI(m_Value(In))),
351                                         m_APInt(MinC))),
352                         m_APInt(MaxC))) &&
353       !match(&I, m_SMin(m_OneUse(m_SMax(m_OneUse(m_FPToSI(m_Value(In))),
354                                         m_APInt(MaxC))),
355                         m_APInt(MinC))))
356     return false;
357 
  // Check that the constants form a signed power-of-2 saturation range,
  // i.e. [-2^(N-1), 2^(N-1) - 1].
359   if (!(*MinC + 1).isPowerOf2() || -*MaxC != *MinC + 1)
360     return false;
361 
362   Type *IntTy = I.getType();
363   Type *FpTy = In->getType();
364   Type *SatTy =
365       IntegerType::get(IntTy->getContext(), (*MinC + 1).exactLogBase2() + 1);
366   if (auto *VecTy = dyn_cast<VectorType>(IntTy))
367     SatTy = VectorType::get(SatTy, VecTy->getElementCount());
368 
369   // Get the cost of the intrinsic, and check that against the cost of
370   // fptosi+smin+smax
371   InstructionCost SatCost = TTI.getIntrinsicInstrCost(
372       IntrinsicCostAttributes(Intrinsic::fptosi_sat, SatTy, {In}, {FpTy}),
373       TTI::TCK_RecipThroughput);
374   SatCost += TTI.getCastInstrCost(Instruction::SExt, SatTy, IntTy,
375                                   TTI::CastContextHint::None,
376                                   TTI::TCK_RecipThroughput);
377 
378   InstructionCost MinMaxCost = TTI.getCastInstrCost(
379       Instruction::FPToSI, IntTy, FpTy, TTI::CastContextHint::None,
380       TTI::TCK_RecipThroughput);
381   MinMaxCost += TTI.getIntrinsicInstrCost(
382       IntrinsicCostAttributes(Intrinsic::smin, IntTy, {IntTy}),
383       TTI::TCK_RecipThroughput);
384   MinMaxCost += TTI.getIntrinsicInstrCost(
385       IntrinsicCostAttributes(Intrinsic::smax, IntTy, {IntTy}),
386       TTI::TCK_RecipThroughput);
387 
388   if (SatCost >= MinMaxCost)
389     return false;
390 
391   IRBuilder<> Builder(&I);
392   Function *Fn = Intrinsic::getDeclaration(I.getModule(), Intrinsic::fptosi_sat,
393                                            {SatTy, FpTy});
394   Value *Sat = Builder.CreateCall(Fn, In);
395   I.replaceAllUsesWith(Builder.CreateSExt(Sat, IntTy));
396   return true;
397 }
398 
399 /// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
400 /// pessimistic codegen that has to account for setting errno and can enable
401 /// vectorization.
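///
/// For example (illustrative IR):
///   %r = call nnan double @sqrt(double %x)
/// -->
///   %r = call nnan double @llvm.sqrt.f64(double %x)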
402 static bool foldSqrt(Instruction &I, TargetTransformInfo &TTI,
403                      TargetLibraryInfo &TLI, AssumptionCache &AC,
404                      DominatorTree &DT) {
405   // Match a call to sqrt mathlib function.
406   auto *Call = dyn_cast<CallInst>(&I);
407   if (!Call)
408     return false;
409 
410   Module *M = Call->getModule();
411   LibFunc Func;
412   if (!TLI.getLibFunc(*Call, Func) || !isLibFuncEmittable(M, &TLI, Func))
413     return false;
414 
415   if (Func != LibFunc_sqrt && Func != LibFunc_sqrtf && Func != LibFunc_sqrtl)
416     return false;
417 
  // If (1) this is a sqrt libcall, (2) we can assume that NaN is not created
  // (because the call has NNAN or the operand is known not to be less than
  // -0.0), and (3) we would not end up lowering to a libcall anyway (which
  // could change the value of errno), then:
  // (1) errno won't be set.
  // (2) it is safe to convert this to an intrinsic call.
424   Type *Ty = Call->getType();
425   Value *Arg = Call->getArgOperand(0);
426   if (TTI.haveFastSqrt(Ty) &&
427       (Call->hasNoNaNs() ||
428        cannotBeOrderedLessThanZero(Arg, M->getDataLayout(), &TLI, 0, &AC, &I,
429                                    &DT))) {
430     IRBuilder<> Builder(&I);
431     IRBuilderBase::FastMathFlagGuard Guard(Builder);
432     Builder.setFastMathFlags(Call->getFastMathFlags());
433 
434     Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty);
435     Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt");
436     I.replaceAllUsesWith(NewSqrt);
437 
438     // Explicitly erase the old call because a call with side effects is not
439     // trivially dead.
440     I.eraseFromParent();
441     return true;
442   }
443 
444   return false;
445 }
446 
447 // Check if this array of constants represents a cttz table.
// Iterate over the elements of \p Table, trying to match every value from
// 0 to \p InputBits - 1 that a cttz result can take.
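//
// For example (illustrative): with the 32-bit de Bruijn constant
// Mul = 0x077CB531 and Shift = 27, every table index i that holds a value
// Element < 32 must satisfy (((Mul << Element) mod 2^32) >> 27) == i, and
// exactly 32 such entries must match.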
450 static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul,
451                         uint64_t Shift, uint64_t InputBits) {
452   unsigned Length = Table.getNumElements();
453   if (Length < InputBits || Length > InputBits * 2)
454     return false;
455 
456   APInt Mask = APInt::getBitsSetFrom(InputBits, Shift);
457   unsigned Matched = 0;
458 
459   for (unsigned i = 0; i < Length; i++) {
460     uint64_t Element = Table.getElementAsInteger(i);
461     if (Element >= InputBits)
462       continue;
463 
464     // Check if \p Element matches a concrete answer. It could fail for some
465     // elements that are never accessed, so we keep iterating over each element
466     // from the table. The number of matched elements should be equal to the
467     // number of potential right answers which is \p InputBits actually.
468     if ((((Mul << Element) & Mask.getZExtValue()) >> Shift) == i)
469       Matched++;
470   }
471 
472   return Matched == InputBits;
473 }
474 
475 // Try to recognize table-based ctz implementation.
476 // E.g., an example in C (for more cases please see the llvm/tests):
477 // int f(unsigned x) {
478 //    static const char table[32] =
479 //      {0, 1, 28, 2, 29, 14, 24, 3, 30,
480 //       22, 20, 15, 25, 17, 4, 8, 31, 27,
481 //       13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9};
482 //    return table[((unsigned)((x & -x) * 0x077CB531U)) >> 27];
483 // }
// this can be lowered to the `cttz` intrinsic.
485 // There is also a special case when the element is 0.
486 //
// Here are some examples of LLVM IR for a 64-bit target:
488 //
489 // CASE 1:
490 // %sub = sub i32 0, %x
491 // %and = and i32 %sub, %x
492 // %mul = mul i32 %and, 125613361
493 // %shr = lshr i32 %mul, 27
494 // %idxprom = zext i32 %shr to i64
// %arrayidx = getelementptr inbounds [32 x i8], [32 x i8]* @ctz1.table,
//             i64 0, i64 %idxprom
// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
497 //
498 // CASE 2:
499 // %sub = sub i32 0, %x
500 // %and = and i32 %sub, %x
501 // %mul = mul i32 %and, 72416175
502 // %shr = lshr i32 %mul, 26
503 // %idxprom = zext i32 %shr to i64
// %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table,
//             i64 0, i64 %idxprom
// %0 = load i16, i16* %arrayidx, align 2, !tbaa !8
506 //
507 // CASE 3:
508 // %sub = sub i32 0, %x
509 // %and = and i32 %sub, %x
510 // %mul = mul i32 %and, 81224991
511 // %shr = lshr i32 %mul, 27
512 // %idxprom = zext i32 %shr to i64
// %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table,
//             i64 0, i64 %idxprom
// %0 = load i32, i32* %arrayidx, align 4, !tbaa !8
515 //
516 // CASE 4:
517 // %sub = sub i64 0, %x
518 // %and = and i64 %sub, %x
519 // %mul = mul i64 %and, 283881067100198605
520 // %shr = lshr i64 %mul, 58
// %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table,
//             i64 0, i64 %shr
// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
523 //
// All of this can be lowered to the @llvm.cttz.i32/i64 intrinsic.
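//
// For CASE 1 the load is replaced with (illustrative IR):
// %ctz = call i32 @llvm.cttz.i32(i32 %x, i1 true)
// %iszero = icmp eq i32 %x, 0
// %sel = select i1 %iszero, i32 0, i32 %ctz
// %0 = trunc i32 %sel to i8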
525 static bool tryToRecognizeTableBasedCttz(Instruction &I) {
526   LoadInst *LI = dyn_cast<LoadInst>(&I);
527   if (!LI)
528     return false;
529 
530   Type *AccessType = LI->getType();
531   if (!AccessType->isIntegerTy())
532     return false;
533 
534   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand());
535   if (!GEP || !GEP->isInBounds() || GEP->getNumIndices() != 2)
536     return false;
537 
538   if (!GEP->getSourceElementType()->isArrayTy())
539     return false;
540 
541   uint64_t ArraySize = GEP->getSourceElementType()->getArrayNumElements();
542   if (ArraySize != 32 && ArraySize != 64)
543     return false;
544 
545   GlobalVariable *GVTable = dyn_cast<GlobalVariable>(GEP->getPointerOperand());
546   if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant())
547     return false;
548 
549   ConstantDataArray *ConstData =
550       dyn_cast<ConstantDataArray>(GVTable->getInitializer());
551   if (!ConstData)
552     return false;
553 
554   if (!match(GEP->idx_begin()->get(), m_ZeroInt()))
555     return false;
556 
557   Value *Idx2 = std::next(GEP->idx_begin())->get();
558   Value *X1;
559   uint64_t MulConst, ShiftConst;
560   // FIXME: 64-bit targets have `i64` type for the GEP index, so this match will
561   // probably fail for other (e.g. 32-bit) targets.
562   if (!match(Idx2, m_ZExtOrSelf(
563                        m_LShr(m_Mul(m_c_And(m_Neg(m_Value(X1)), m_Deferred(X1)),
564                                     m_ConstantInt(MulConst)),
565                               m_ConstantInt(ShiftConst)))))
566     return false;
567 
568   unsigned InputBits = X1->getType()->getScalarSizeInBits();
569   if (InputBits != 32 && InputBits != 64)
570     return false;
571 
572   // Shift should extract top 5..7 bits.
573   if (InputBits - Log2_32(InputBits) != ShiftConst &&
574       InputBits - Log2_32(InputBits) - 1 != ShiftConst)
575     return false;
576 
577   if (!isCTTZTable(*ConstData, MulConst, ShiftConst, InputBits))
578     return false;
579 
580   auto ZeroTableElem = ConstData->getElementAsInteger(0);
581   bool DefinedForZero = ZeroTableElem == InputBits;
582 
583   IRBuilder<> B(LI);
584   ConstantInt *BoolConst = B.getInt1(!DefinedForZero);
585   Type *XType = X1->getType();
586   auto Cttz = B.CreateIntrinsic(Intrinsic::cttz, {XType}, {X1, BoolConst});
587   Value *ZExtOrTrunc = nullptr;
588 
589   if (DefinedForZero) {
590     ZExtOrTrunc = B.CreateZExtOrTrunc(Cttz, AccessType);
591   } else {
592     // If the value in elem 0 isn't the same as InputBits, we still want to
593     // produce the value from the table.
594     auto Cmp = B.CreateICmpEQ(X1, ConstantInt::get(XType, 0));
595     auto Select =
596         B.CreateSelect(Cmp, ConstantInt::get(XType, ZeroTableElem), Cttz);
597 
    // NOTE: If table[0] is 0 but cttz(0) is defined by the target, this could
    // instead be handled as `cttz(x) & (typeSize - 1)`.
600 
601     ZExtOrTrunc = B.CreateZExtOrTrunc(Select, AccessType);
602   }
603 
604   LI->replaceAllUsesWith(ZExtOrTrunc);
605 
606   return true;
607 }
608 
/// This is used by foldLoadsRecursive() to capture the root load of an
/// or(load, load) chain while the wide load is built recursively. It also
/// captures the shift amount, zero-extend type and load size.
612 struct LoadOps {
613   LoadInst *Root = nullptr;
614   LoadInst *RootInsert = nullptr;
615   bool FoundRoot = false;
616   uint64_t LoadSize = 0;
617   const APInt *Shift = nullptr;
618   Type *ZextType;
619   AAMDNodes AATags;
620 };
621 
// Recursively identify and merge consecutive loads that form the pattern
// (ZExt(L1) << shift1) | (ZExt(L2) << shift2) -> ZExt(L3) << shift1
// (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
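//
// For example (illustrative IR, little-endian):
//   %v1 = load i8, ptr %p
//   %p1 = getelementptr i8, ptr %p, i64 1
//   %v2 = load i8, ptr %p1
//   %e1 = zext i8 %v1 to i16
//   %e2 = zext i8 %v2 to i16
//   %s2 = shl i16 %e2, 8
//   %or = or i16 %e1, %s2
// -->
//   %or = load i16, ptr %p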
625 static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
626                                AliasAnalysis &AA) {
627   const APInt *ShAmt2 = nullptr;
628   Value *X;
629   Instruction *L1, *L2;
630 
631   // Go to the last node with loads.
632   if (match(V, m_OneUse(m_c_Or(
633                    m_Value(X),
634                    m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
635                                   m_APInt(ShAmt2)))))) ||
636       match(V, m_OneUse(m_Or(m_Value(X),
637                              m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) {
638     if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot)
      // Avoid a partial chain merge.
640       return false;
641   } else
642     return false;
643 
644   // Check if the pattern has loads
645   LoadInst *LI1 = LOps.Root;
646   const APInt *ShAmt1 = LOps.Shift;
647   if (LOps.FoundRoot == false &&
648       (match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
649        match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
650                                m_APInt(ShAmt1)))))) {
651     LI1 = dyn_cast<LoadInst>(L1);
652   }
653   LoadInst *LI2 = dyn_cast<LoadInst>(L2);
654 
  // Reject identical loads, non-simple (atomic/volatile) loads, and loads
  // from different address spaces.
656   if (LI1 == LI2 || !LI1 || !LI2 || !LI1->isSimple() || !LI2->isSimple() ||
657       LI1->getPointerAddressSpace() != LI2->getPointerAddressSpace())
658     return false;
659 
  // Check that both loads are in the same basic block.
661   if (LI1->getParent() != LI2->getParent())
662     return false;
663 
  // Determine the target's endianness from the data layout.
665   bool IsBigEndian = DL.isBigEndian();
666 
667   // Check if loads are consecutive and same size.
668   Value *Load1Ptr = LI1->getPointerOperand();
669   APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
670   Load1Ptr =
671       Load1Ptr->stripAndAccumulateConstantOffsets(DL, Offset1,
672                                                   /* AllowNonInbounds */ true);
673 
674   Value *Load2Ptr = LI2->getPointerOperand();
675   APInt Offset2(DL.getIndexTypeSizeInBits(Load2Ptr->getType()), 0);
676   Load2Ptr =
677       Load2Ptr->stripAndAccumulateConstantOffsets(DL, Offset2,
678                                                   /* AllowNonInbounds */ true);
679 
  // Verify that both loads have the same base pointer and load size.
681   uint64_t LoadSize1 = LI1->getType()->getPrimitiveSizeInBits();
682   uint64_t LoadSize2 = LI2->getType()->getPrimitiveSizeInBits();
683   if (Load1Ptr != Load2Ptr || LoadSize1 != LoadSize2)
684     return false;
685 
  // Only support load sizes that are at least 8 bits and a power of 2.
687   if (LoadSize1 < 8 || !isPowerOf2_64(LoadSize1))
688     return false;
689 
  // Use alias analysis to check for stores between the loads.
691   LoadInst *Start = LOps.FoundRoot ? LOps.RootInsert : LI1, *End = LI2;
692   MemoryLocation Loc;
693   if (!Start->comesBefore(End)) {
694     std::swap(Start, End);
695     Loc = MemoryLocation::get(End);
696     if (LOps.FoundRoot)
697       Loc = Loc.getWithNewSize(LOps.LoadSize);
698   } else
699     Loc = MemoryLocation::get(End);
700   unsigned NumScanned = 0;
701   for (Instruction &Inst :
702        make_range(Start->getIterator(), End->getIterator())) {
703     if (Inst.mayWriteToMemory() && isModSet(AA.getModRefInfo(&Inst, Loc)))
704       return false;
705     if (++NumScanned > MaxInstrsToScan)
706       return false;
707   }
708 
  // Make sure the load with the lower offset is LI1.
710   bool Reverse = false;
711   if (Offset2.slt(Offset1)) {
712     std::swap(LI1, LI2);
713     std::swap(ShAmt1, ShAmt2);
714     std::swap(Offset1, Offset2);
715     std::swap(Load1Ptr, Load2Ptr);
716     std::swap(LoadSize1, LoadSize2);
717     Reverse = true;
718   }
719 
  // On big-endian targets, swap the shift amounts.
721   if (IsBigEndian)
722     std::swap(ShAmt1, ShAmt2);
723 
  // Extract the shift amounts.
725   uint64_t Shift1 = 0, Shift2 = 0;
726   if (ShAmt1)
727     Shift1 = ShAmt1->getZExtValue();
728   if (ShAmt2)
729     Shift2 = ShAmt2->getZExtValue();
730 
731   // First load is always LI1. This is where we put the new load.
732   // Use the merged load size available from LI1 for forward loads.
733   if (LOps.FoundRoot) {
734     if (!Reverse)
735       LoadSize1 = LOps.LoadSize;
736     else
737       LoadSize2 = LOps.LoadSize;
738   }
739 
  // Verify that the difference of the shift amounts matches the load size and
  // that the loads are consecutive in memory.
742   uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1;
743   uint64_t PrevSize =
744       DL.getTypeStoreSize(IntegerType::get(LI1->getContext(), LoadSize1));
745   if ((Shift2 - Shift1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
746     return false;
747 
748   // Update LOps
749   AAMDNodes AATags1 = LOps.AATags;
750   AAMDNodes AATags2 = LI2->getAAMetadata();
751   if (LOps.FoundRoot == false) {
752     LOps.FoundRoot = true;
753     AATags1 = LI1->getAAMetadata();
754   }
755   LOps.LoadSize = LoadSize1 + LoadSize2;
756   LOps.RootInsert = Start;
757 
758   // Concatenate the AATags of the Merged Loads.
759   LOps.AATags = AATags1.concat(AATags2);
760 
761   LOps.Root = LI1;
762   LOps.Shift = ShAmt1;
763   LOps.ZextType = X->getType();
764   return true;
765 }
766 
767 // For a given BB instruction, evaluate all loads in the chain that form a
768 // pattern which suggests that the loads can be combined. The one and only use
769 // of the loads is to form a wider load.
770 static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
771                                  TargetTransformInfo &TTI, AliasAnalysis &AA,
772                                  const DominatorTree &DT) {
773   // Only consider load chains of scalar values.
774   if (isa<VectorType>(I.getType()))
775     return false;
776 
777   LoadOps LOps;
778   if (!foldLoadsRecursive(&I, LOps, DL, AA) || !LOps.FoundRoot)
779     return false;
780 
781   IRBuilder<> Builder(&I);
782   LoadInst *NewLoad = nullptr, *LI1 = LOps.Root;
783 
784   IntegerType *WiderType = IntegerType::get(I.getContext(), LOps.LoadSize);
  // TTI-based check to decide whether to proceed with the wider load.
786   bool Allowed = TTI.isTypeLegal(WiderType);
787   if (!Allowed)
788     return false;
789 
790   unsigned AS = LI1->getPointerAddressSpace();
791   unsigned Fast = 0;
792   Allowed = TTI.allowsMisalignedMemoryAccesses(I.getContext(), LOps.LoadSize,
793                                                AS, LI1->getAlign(), &Fast);
794   if (!Allowed || !Fast)
795     return false;
796 
797   // Get the Index and Ptr for the new GEP.
798   Value *Load1Ptr = LI1->getPointerOperand();
799   Builder.SetInsertPoint(LOps.RootInsert);
800   if (!DT.dominates(Load1Ptr, LOps.RootInsert)) {
801     APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
802     Load1Ptr = Load1Ptr->stripAndAccumulateConstantOffsets(
803         DL, Offset1, /* AllowNonInbounds */ true);
804     Load1Ptr = Builder.CreateGEP(Builder.getInt8Ty(), Load1Ptr,
805                                  Builder.getInt32(Offset1.getZExtValue()));
806   }
807   // Generate wider load.
808   NewLoad = Builder.CreateAlignedLoad(WiderType, Load1Ptr, LI1->getAlign(),
809                                       LI1->isVolatile(), "");
810   NewLoad->takeName(LI1);
811   // Set the New Load AATags Metadata.
812   if (LOps.AATags)
813     NewLoad->setAAMetadata(LOps.AATags);
814 
815   Value *NewOp = NewLoad;
816   // Check if zero extend needed.
817   if (LOps.ZextType)
818     NewOp = Builder.CreateZExt(NewOp, LOps.ZextType);
819 
  // Check if a shift is needed. If so, shift by the original shift amount of
  // the first load.
822   if (LOps.Shift)
    NewOp = Builder.CreateShl(NewOp,
                              ConstantInt::get(I.getContext(), *LOps.Shift));
824   I.replaceAllUsesWith(NewOp);
825 
826   return true;
827 }
828 
// Calculate the GEP stride and the accumulated constant offset (ModOffset),
// and return the pair {Stride, ModOffset}.
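//
// For example (illustrative): if %p = gep inbounds i16, ptr @g, i64 %i and
// the queried pointer is gep inbounds i8, ptr %p, i64 3, the result is
// {Stride = 2, ModOffset = 1}: every reachable offset has the form 2 * k + 1.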
831 static std::pair<APInt, APInt>
832 getStrideAndModOffsetOfGEP(Value *PtrOp, const DataLayout &DL) {
833   unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType());
834   std::optional<APInt> Stride;
835   APInt ModOffset(BW, 0);
  // Return a minimum GEP stride, the greatest common divisor of consecutive
  // GEP index scales (cf. Bézout's identity).
838   while (auto *GEP = dyn_cast<GEPOperator>(PtrOp)) {
839     MapVector<Value *, APInt> VarOffsets;
840     if (!GEP->collectOffset(DL, BW, VarOffsets, ModOffset))
841       break;
842 
843     for (auto [V, Scale] : VarOffsets) {
      // Only keep a power-of-two factor for non-inbounds GEPs.
845       if (!GEP->isInBounds())
846         Scale = APInt::getOneBitSet(Scale.getBitWidth(), Scale.countr_zero());
847 
848       if (!Stride)
849         Stride = Scale;
850       else
851         Stride = APIntOps::GreatestCommonDivisor(*Stride, Scale);
852     }
853 
854     PtrOp = GEP->getPointerOperand();
855   }
856 
  // Check whether the pointer traces back to a global variable through at
  // least one GEP. If it doesn't, the caller can still reason from the load
  // alignment alone.
859   if (!isa<GlobalVariable>(PtrOp) || !Stride)
860     return {APInt(BW, 1), APInt(BW, 0)};
861 
  // GEP indices may be signed, so normalize the accumulated constant offset
  // to a non-negative remainder modulo the minimum GEP stride.
864   ModOffset = ModOffset.srem(*Stride);
865   if (ModOffset.isNegative())
866     ModOffset += *Stride;
867 
868   return {*Stride, ModOffset};
869 }
870 
/// If C is a constant patterned array and every result that can validly be
/// loaded for the given alignment is the same constant, fold the load to that
/// constant.
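///
/// For example (illustrative, little-endian target):
///   @g = constant [8 x i8] c"\01\00\01\00\01\00\01\00", align 2
/// Every 2-byte-aligned in-bounds 'load i16' from @g reads the bytes 01 00,
/// so such a load folds to the constant i16 1.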
873 static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) {
874   auto *LI = dyn_cast<LoadInst>(&I);
875   if (!LI || LI->isVolatile())
876     return false;
877 
878   // We can only fold the load if it is from a constant global with definitive
879   // initializer. Skip expensive logic if this is not the case.
880   auto *PtrOp = LI->getPointerOperand();
881   auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(PtrOp));
882   if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
883     return false;
884 
885   // Bail for large initializers in excess of 4K to avoid too many scans.
886   Constant *C = GV->getInitializer();
887   uint64_t GVSize = DL.getTypeAllocSize(C->getType());
888   if (!GVSize || 4096 < GVSize)
889     return false;
890 
891   Type *LoadTy = LI->getType();
892   unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType());
893   auto [Stride, ConstOffset] = getStrideAndModOffsetOfGEP(PtrOp, DL);
894 
  // Any possible offset is a multiple of the GEP stride, and any valid offset
  // is a multiple of the load alignment, so checking only multiples of the
  // larger of the two is sufficient to establish that all results are equal.
898   if (auto LA = LI->getAlign();
899       LA <= GV->getAlign().valueOrOne() && Stride.getZExtValue() < LA.value()) {
900     ConstOffset = APInt(BW, 0);
901     Stride = APInt(BW, LA.value());
902   }
903 
904   Constant *Ca = ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL);
905   if (!Ca)
906     return false;
907 
908   unsigned E = GVSize - DL.getTypeStoreSize(LoadTy);
909   for (; ConstOffset.getZExtValue() <= E; ConstOffset += Stride)
910     if (Ca != ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL))
911       return false;
912 
913   I.replaceAllUsesWith(Ca);
914 
915   return true;
916 }
917 
918 /// This is the entry point for folds that could be implemented in regular
919 /// InstCombine, but they are separated because they are not expected to
920 /// occur frequently and/or have more than a constant-length pattern match.
921 static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
922                                 TargetTransformInfo &TTI,
923                                 TargetLibraryInfo &TLI, AliasAnalysis &AA,
924                                 AssumptionCache &AC) {
925   bool MadeChange = false;
926   for (BasicBlock &BB : F) {
927     // Ignore unreachable basic blocks.
928     if (!DT.isReachableFromEntry(&BB))
929       continue;
930 
931     const DataLayout &DL = F.getParent()->getDataLayout();
932 
933     // Walk the block backwards for efficiency. We're matching a chain of
934     // use->defs, so we're more likely to succeed by starting from the bottom.
935     // Also, we want to avoid matching partial patterns.
936     // TODO: It would be more efficient if we removed dead instructions
937     // iteratively in this loop rather than waiting until the end.
938     for (Instruction &I : make_early_inc_range(llvm::reverse(BB))) {
939       MadeChange |= foldAnyOrAllBitsSet(I);
940       MadeChange |= foldGuardedFunnelShift(I, DT);
941       MadeChange |= tryToRecognizePopCount(I);
942       MadeChange |= tryToFPToSat(I, TTI);
943       MadeChange |= tryToRecognizeTableBasedCttz(I);
944       MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
945       MadeChange |= foldPatternedLoads(I, DL);
      // NOTE: This function may erase the instruction `I`, so it needs to be
      // called last in this sequence; otherwise, a later fold could operate on
      // a deleted instruction.
949       MadeChange |= foldSqrt(I, TTI, TLI, AC, DT);
950     }
951   }
952 
953   // We're done with transforms, so remove dead instructions.
954   if (MadeChange)
955     for (BasicBlock &BB : F)
956       SimplifyInstructionsInBlock(&BB);
957 
958   return MadeChange;
959 }
960 
961 /// This is the entry point for all transforms. Pass manager differences are
962 /// handled in the callers of this function.
963 static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
964                     TargetLibraryInfo &TLI, DominatorTree &DT,
965                     AliasAnalysis &AA) {
966   bool MadeChange = false;
967   const DataLayout &DL = F.getParent()->getDataLayout();
968   TruncInstCombine TIC(AC, TLI, DL, DT);
969   MadeChange |= TIC.run(F);
970   MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC);
971   return MadeChange;
972 }
973 
974 PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
975                                                  FunctionAnalysisManager &AM) {
976   auto &AC = AM.getResult<AssumptionAnalysis>(F);
977   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
978   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
979   auto &TTI = AM.getResult<TargetIRAnalysis>(F);
980   auto &AA = AM.getResult<AAManager>(F);
981   if (!runImpl(F, AC, TTI, TLI, DT, AA)) {
982     // No changes, all analyses are preserved.
983     return PreservedAnalyses::all();
984   }
985   // Mark all the analyses that instcombine updates as preserved.
986   PreservedAnalyses PA;
987   PA.preserveSet<CFGAnalyses>();
988   return PA;
989 }
990