1 //===- InstCombiner.h - InstCombine implementation --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 ///
10 /// This file provides the interface for the instcombine pass implementation.
11 /// The interface is used for generic transformations in this folder and
12 /// target specific combinations in the targets.
13 /// The visitor implementation is in \c InstCombinerImpl in
14 /// \c InstCombineInternal.h.
15 ///
16 //===----------------------------------------------------------------------===//
17 
18 #ifndef LLVM_TRANSFORMS_INSTCOMBINE_INSTCOMBINER_H
19 #define LLVM_TRANSFORMS_INSTCOMBINE_INSTCOMBINER_H
20 
21 #include "llvm/Analysis/InstructionSimplify.h"
22 #include "llvm/Analysis/TargetFolder.h"
23 #include "llvm/Analysis/ValueTracking.h"
24 #include "llvm/IR/IRBuilder.h"
25 #include "llvm/IR/PatternMatch.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/KnownBits.h"
28 #include <cassert>
29 
30 #define DEBUG_TYPE "instcombine"
31 #include "llvm/Transforms/Utils/InstructionWorklist.h"
32 
33 namespace llvm {
34 
35 class AAResults;
36 class AssumptionCache;
37 class OptimizationRemarkEmitter;
38 class ProfileSummaryInfo;
39 class TargetLibraryInfo;
40 class TargetTransformInfo;
41 
42 /// The core instruction combiner logic.
43 ///
/// This class provides the logic to recursively visit instructions and
/// combine them.
class LLVM_LIBRARY_VISIBILITY InstCombiner {
  /// Only used to call target specific intrinsic combining.
  /// It must **NOT** be used for any other purpose, as InstCombine is a
  /// target-independent canonicalization transform.
  TargetTransformInfo &TTI;

public:
  /// Maximum size of array considered when transforming.
  uint64_t MaxArraySizeForCombine = 0;

  /// An IRBuilder that automatically inserts new instructions into the
  /// worklist.
  using BuilderTy = IRBuilder<TargetFolder, IRBuilderCallbackInserter>;
  BuilderTy &Builder;

protected:
  /// A worklist of the instructions that need to be simplified.
  InstructionWorklist &Worklist;

  // Mode in which we are running the combiner.
  const bool MinimizeSize;

  AAResults *AA;

  // Required analyses.
  AssumptionCache &AC;
  TargetLibraryInfo &TLI;
  DominatorTree &DT;
  const DataLayout &DL;
  const SimplifyQuery SQ;
  OptimizationRemarkEmitter &ORE;
  BlockFrequencyInfo *BFI;
  ProfileSummaryInfo *PSI;

  // Optional analyses. When non-null, these can both be used to do better
  // combining and will be updated to reflect any changes.
  LoopInfo *LI;

  // True once the combiner has changed the IR.
  bool MadeIRChange = false;

public:
  /// Construct a combiner over \p Worklist. All reference parameters must
  /// outlive this object; the SimplifyQuery is built once here from DL, TLI,
  /// DT, and AC and reused for the lifetime of the combiner.
  InstCombiner(InstructionWorklist &Worklist, BuilderTy &Builder,
               bool MinimizeSize, AAResults *AA, AssumptionCache &AC,
               TargetLibraryInfo &TLI, TargetTransformInfo &TTI,
               DominatorTree &DT, OptimizationRemarkEmitter &ORE,
               BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI,
               const DataLayout &DL, LoopInfo *LI)
      : TTI(TTI), Builder(Builder), Worklist(Worklist),
        MinimizeSize(MinimizeSize), AA(AA), AC(AC), TLI(TLI), DT(DT), DL(DL),
        SQ(DL, &TLI, &DT, &AC), ORE(ORE), BFI(BFI), PSI(PSI), LI(LI) {}

  virtual ~InstCombiner() = default;

  /// Return the source operand of a potentially bitcasted value while
  /// optionally checking if it has one use. If there is no bitcast or the one
  /// use check is not met, return the input value itself.
  static Value *peekThroughBitcast(Value *V, bool OneUseOnly = false) {
    if (auto *BitCast = dyn_cast<BitCastInst>(V))
      if (!OneUseOnly || BitCast->hasOneUse())
        return BitCast->getOperand(0);

    // V is not a bitcast or V has more than one use and OneUseOnly is true.
    return V;
  }

  /// Assign a complexity or rank value to LLVM Values. This is used to reduce
  /// the amount of pattern matching needed for compares and commutative
  /// instructions. For example, if we have:
  ///   icmp ugt X, Constant
  /// or
  ///   xor (add X, Constant), cast Z
  ///
  /// We do not have to consider the commuted variants of these patterns because
  /// canonicalization based on complexity guarantees the above ordering.
  ///
  /// This routine maps IR values to various complexity ranks:
  ///   0 -> undef
  ///   1 -> Constants
  ///   2 -> Other non-instructions
  ///   3 -> Arguments
  ///   4 -> Cast and (f)neg/not instructions
  ///   5 -> Other instructions
  static unsigned getComplexity(Value *V) {
    if (isa<Instruction>(V)) {
      // Casts and the unary-style patterns (neg/not/fneg) rank just below
      // other instructions so they are canonicalized toward the RHS.
      if (isa<CastInst>(V) || match(V, m_Neg(PatternMatch::m_Value())) ||
          match(V, m_Not(PatternMatch::m_Value())) ||
          match(V, m_FNeg(PatternMatch::m_Value())))
        return 4;
      return 5;
    }
    if (isa<Argument>(V))
      return 3;
    // Non-instruction, non-argument: constants rank 1 (undef ranks 0); any
    // other value kind ranks 2.
    return isa<Constant>(V) ? (isa<UndefValue>(V) ? 0 : 1) : 2;
  }

  /// Predicate canonicalization reduces the number of patterns that need to be
  /// matched by other transforms. For example, we may swap the operands of a
  /// conditional branch or select to create a compare with a canonical
  /// (inverted) predicate which is then more likely to be matched with other
  /// values.
  static bool isCanonicalPredicate(CmpInst::Predicate Pred) {
    switch (Pred) {
    case CmpInst::ICMP_NE:
    case CmpInst::ICMP_ULE:
    case CmpInst::ICMP_SLE:
    case CmpInst::ICMP_UGE:
    case CmpInst::ICMP_SGE:
    // TODO: There are 16 FCMP predicates. Should others be (not) canonical?
    case CmpInst::FCMP_ONE:
    case CmpInst::FCMP_OLE:
    case CmpInst::FCMP_OGE:
      return false;
    default:
      return true;
    }
  }

  /// Given an exploded icmp instruction, return true if the comparison only
  /// checks the sign bit. If it only checks the sign bit, set TrueIfSigned if
  /// the result of the comparison is true when the input value is signed.
  static bool isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS,
                             bool &TrueIfSigned) {
    switch (Pred) {
    case ICmpInst::ICMP_SLT: // True if LHS s< 0
      TrueIfSigned = true;
      return RHS.isZero();
    case ICmpInst::ICMP_SLE: // True if LHS s<= -1
      TrueIfSigned = true;
      return RHS.isAllOnes();
    case ICmpInst::ICMP_SGT: // True if LHS s> -1
      TrueIfSigned = false;
      return RHS.isAllOnes();
    case ICmpInst::ICMP_SGE: // True if LHS s>= 0
      TrueIfSigned = false;
      return RHS.isZero();
    case ICmpInst::ICMP_UGT:
      // True if LHS u> RHS and RHS == sign-bit-mask - 1
      TrueIfSigned = true;
      return RHS.isMaxSignedValue();
    case ICmpInst::ICMP_UGE:
      // True if LHS u>= RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc)
      TrueIfSigned = true;
      return RHS.isMinSignedValue();
    case ICmpInst::ICMP_ULT:
      // True if LHS u< RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc)
      TrueIfSigned = false;
      return RHS.isMinSignedValue();
    case ICmpInst::ICMP_ULE:
      // True if LHS u<= RHS and RHS == sign-bit-mask - 1
      TrueIfSigned = false;
      return RHS.isMaxSignedValue();
    default:
      // Note: TrueIfSigned is left unmodified on this path; callers must only
      // read it when this function returns true.
      return false;
    }
  }

  /// Add one to a Constant
  static Constant *AddOne(Constant *C) {
    return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1));
  }

  /// Subtract one from a Constant
  static Constant *SubOne(Constant *C) {
    return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1));
  }

  /// Given a comparison predicate \p Pred with constant RHS \p C, presumably
  /// returns the equivalent comparison of the opposite strictness (e.g.
  /// strict -> non-strict) together with the adjusted constant, or
  /// std::nullopt when no such rewrite exists. Defined out of line --
  /// TODO(review): confirm exact semantics against the implementation.
  std::optional<std::pair<
      CmpInst::Predicate,
      Constant *>> static getFlippedStrictnessPredicateAndConstant(CmpInst::
                                                                       Predicate
                                                                           Pred,
                                                                   Constant *C);

  static bool shouldAvoidAbsorbingNotIntoSelect(const SelectInst &SI) {
    // a ? b : false and a ? true : b are the canonical form of logical and/or.
    // This includes !a ? b : false and !a ? true : b. Absorbing the not into
    // the select by swapping operands would break recognition of this pattern
    // in other analyses, so don't do that.
    return match(&SI, PatternMatch::m_LogicalAnd(PatternMatch::m_Value(),
                                                 PatternMatch::m_Value())) ||
           match(&SI, PatternMatch::m_LogicalOr(PatternMatch::m_Value(),
                                                PatternMatch::m_Value()));
  }

  /// Return true if the specified value is free to invert (apply ~ to).
  /// This happens in cases where the ~ can be eliminated.  If WillInvertAllUses
  /// is true, work under the assumption that the caller intends to remove all
  /// uses of V and only keep uses of ~V.
  ///
  /// See also: canFreelyInvertAllUsersOf()
  static bool isFreeToInvert(Value *V, bool WillInvertAllUses) {
    // ~(~(X)) -> X.
    if (match(V, m_Not(PatternMatch::m_Value())))
      return true;

    // Constants can be considered to be not'ed values.
    if (match(V, PatternMatch::m_AnyIntegralConstant()))
      return true;

    // Compares can be inverted if all of their uses are being modified to use
    // the ~V.
    if (isa<CmpInst>(V))
      return WillInvertAllUses;

    // If `V` is of the form `A + Constant` then `-1 - V` can be folded into
    // `(-1 - Constant) - A` if we are willing to invert all of the uses.
    if (match(V, m_Add(PatternMatch::m_Value(), PatternMatch::m_ImmConstant())))
      return WillInvertAllUses;

    // If `V` is of the form `Constant - A` then `-1 - V` can be folded into
    // `A + (-1 - Constant)` if we are willing to invert all of the uses.
    if (match(V, m_Sub(PatternMatch::m_ImmConstant(), PatternMatch::m_Value())))
      return WillInvertAllUses;

    // Selects with invertible operands are freely invertible
    if (match(V,
              m_Select(PatternMatch::m_Value(), m_Not(PatternMatch::m_Value()),
                       m_Not(PatternMatch::m_Value()))))
      return WillInvertAllUses;

    // Min/max may be in the form of intrinsics, so handle those identically
    // to select patterns.
    if (match(V, m_MaxOrMin(m_Not(PatternMatch::m_Value()),
                            m_Not(PatternMatch::m_Value()))))
      return WillInvertAllUses;

    return false;
  }

  /// Given i1 V, can every user of V be freely adapted if V is changed to !V ?
  /// InstCombine's freelyInvertAllUsersOf() must be kept in sync with this fn.
  /// NOTE: for Instructions only!
  ///
  /// See also: isFreeToInvert()
  static bool canFreelyInvertAllUsersOf(Instruction *V, Value *IgnoredUser) {
    // Look at every user of V.
    for (Use &U : V->uses()) {
      if (U.getUser() == IgnoredUser)
        continue; // Don't consider this user.

      auto *I = cast<Instruction>(U.getUser());
      switch (I->getOpcode()) {
      case Instruction::Select:
        if (U.getOperandNo() != 0) // Only if the value is used as select cond.
          return false;
        if (shouldAvoidAbsorbingNotIntoSelect(*cast<SelectInst>(I)))
          return false;
        break;
      case Instruction::Br:
        assert(U.getOperandNo() == 0 && "Must be branching on that value.");
        break; // Free to invert by swapping true/false values/destinations.
      case Instruction::Xor: // Can invert 'xor' if it's a 'not', by ignoring
                             // it.
        if (!match(I, m_Not(PatternMatch::m_Value())))
          return false; // Not a 'not'.
        break;
      default:
        return false; // Don't know, likely not freely invertible.
      }
      // So far all users were free to invert...
    }
    return true; // Can freely invert all users!
  }

  /// Some binary operators require special handling to avoid poison and
  /// undefined behavior. If a constant vector has undef elements, replace those
  /// undefs with identity constants if possible because those are always safe
  /// to execute. If no identity constant exists, replace undef with some other
  /// safe constant.
  static Constant *
  getSafeVectorConstantForBinop(BinaryOperator::BinaryOps Opcode, Constant *In,
                                bool IsRHSConstant) {
    auto *InVTy = cast<FixedVectorType>(In->getType());

    Type *EltTy = InVTy->getElementType();
    auto *SafeC = ConstantExpr::getBinOpIdentity(Opcode, EltTy, IsRHSConstant);
    if (!SafeC) {
      // TODO: Should this be available as a constant utility function? It is
      // similar to getBinOpAbsorber().
      if (IsRHSConstant) {
        switch (Opcode) {
        case Instruction::SRem: // X % 1 = 0
        case Instruction::URem: // X %u 1 = 0
          SafeC = ConstantInt::get(EltTy, 1);
          break;
        case Instruction::FRem: // X % 1.0 (doesn't simplify, but it is safe)
          SafeC = ConstantFP::get(EltTy, 1.0);
          break;
        default:
          llvm_unreachable(
              "Only rem opcodes have no identity constant for RHS");
        }
      } else {
        switch (Opcode) {
        case Instruction::Shl:  // 0 << X = 0
        case Instruction::LShr: // 0 >>u X = 0
        case Instruction::AShr: // 0 >> X = 0
        case Instruction::SDiv: // 0 / X = 0
        case Instruction::UDiv: // 0 /u X = 0
        case Instruction::SRem: // 0 % X = 0
        case Instruction::URem: // 0 %u X = 0
        case Instruction::Sub:  // 0 - X (doesn't simplify, but it is safe)
        case Instruction::FSub: // 0.0 - X (doesn't simplify, but it is safe)
        case Instruction::FDiv: // 0.0 / X (doesn't simplify, but it is safe)
        case Instruction::FRem: // 0.0 % X = 0
          SafeC = Constant::getNullValue(EltTy);
          break;
        default:
          llvm_unreachable("Expected to find identity constant for opcode");
        }
      }
    }
    assert(SafeC && "Must have safe constant for binop");
    unsigned NumElts = InVTy->getNumElements();
    SmallVector<Constant *, 16> Out(NumElts);
    for (unsigned i = 0; i != NumElts; ++i) {
      // NOTE(review): getAggregateElement is assumed to be non-null for every
      // lane of the constant kinds reaching here -- confirm for all callers.
      Constant *C = In->getAggregateElement(i);
      Out[i] = isa<UndefValue>(C) ? SafeC : C;
    }
    return ConstantVector::get(Out);
  }

  /// Push \p I onto the worklist so it will be revisited.
  void addToWorklist(Instruction *I) { Worklist.push(I); }

  // Accessors for the analyses supplied at construction time.
  AssumptionCache &getAssumptionCache() const { return AC; }
  TargetLibraryInfo &getTargetLibraryInfo() const { return TLI; }
  DominatorTree &getDominatorTree() const { return DT; }
  const DataLayout &getDataLayout() const { return DL; }
  const SimplifyQuery &getSimplifyQuery() const { return SQ; }
  OptimizationRemarkEmitter &getOptimizationRemarkEmitter() const {
    return ORE;
  }
  BlockFrequencyInfo *getBlockFrequencyInfo() const { return BFI; }
  ProfileSummaryInfo *getProfileSummaryInfo() const { return PSI; }
  LoopInfo *getLoopInfo() const { return LI; }

  // Call target specific combiners. Declared here and defined out of line;
  // these are the only members that may consult TTI (see the note on the TTI
  // field above). A std::nullopt result presumably means the target provides
  // no combine for the intrinsic -- TODO(review): confirm.
  std::optional<Instruction *> targetInstCombineIntrinsic(IntrinsicInst &II);
  std::optional<Value *>
  targetSimplifyDemandedUseBitsIntrinsic(IntrinsicInst &II, APInt DemandedMask,
                                         KnownBits &Known,
                                         bool &KnownBitsComputed);
  std::optional<Value *> targetSimplifyDemandedVectorEltsIntrinsic(
      IntrinsicInst &II, APInt DemandedElts, APInt &UndefElts,
      APInt &UndefElts2, APInt &UndefElts3,
      std::function<void(Instruction *, unsigned, APInt, APInt &)>
          SimplifyAndSetOp);

  /// Inserts an instruction \p New before instruction \p Old
  ///
  /// Also adds the new instruction to the worklist and returns \p New so that
  /// it is suitable for use as the return from the visitation patterns.
  Instruction *InsertNewInstBefore(Instruction *New, Instruction &Old) {
    assert(New && !New->getParent() &&
           "New instruction already inserted into a basic block!");
    BasicBlock *BB = Old.getParent();
    New->insertInto(BB, Old.getIterator()); // Insert inst
    Worklist.add(New);
    return New;
  }

  /// Same as InsertNewInstBefore, but also sets the debug loc.
  Instruction *InsertNewInstWith(Instruction *New, Instruction &Old) {
    New->setDebugLoc(Old.getDebugLoc());
    return InsertNewInstBefore(New, Old);
  }

  /// A combiner-aware RAUW-like routine.
  ///
  /// This method is to be used when an instruction is found to be dead,
  /// replaceable with another preexisting expression. Here we add all uses of
  /// I to the worklist, replace all uses of I with the new value, then return
  /// I, so that the inst combiner will know that I was modified.
  Instruction *replaceInstUsesWith(Instruction &I, Value *V) {
    // If there are no uses to replace, then we return nullptr to indicate that
    // no changes were made to the program.
    if (I.use_empty()) return nullptr;

    Worklist.pushUsersToWorkList(I); // Add all modified instrs to worklist.

    // If we are replacing the instruction with itself, this must be in a
    // segment of unreachable code, so just clobber the instruction.
    if (&I == V)
      V = PoisonValue::get(I.getType());

    LLVM_DEBUG(dbgs() << "IC: Replacing " << I << "\n"
                      << "    with " << *V << '\n');

    // If V is a new unnamed instruction, take the name from the old one.
    if (V->use_empty() && isa<Instruction>(V) && !V->hasName() && I.hasName())
      V->takeName(&I);

    I.replaceAllUsesWith(V);
    return &I;
  }

  /// Replace operand of instruction and add old operand to the worklist.
  /// Returns \p I so the result can be used as a visitor return value.
  Instruction *replaceOperand(Instruction &I, unsigned OpNum, Value *V) {
    Value *OldOp = I.getOperand(OpNum);
    I.setOperand(OpNum, V);
    Worklist.handleUseCountDecrement(OldOp);
    return &I;
  }

  /// Replace use and add the previously used value to the worklist.
  void replaceUse(Use &U, Value *NewValue) {
    Value *OldOp = U;
    U = NewValue;
    Worklist.handleUseCountDecrement(OldOp);
  }

  /// Combiner aware instruction erasure.
  ///
  /// When dealing with an instruction that has side effects or produces a void
  /// value, we can't rely on DCE to delete the instruction. Instead, visit
  /// methods should return the value returned by this function.
  virtual Instruction *eraseInstFromFunction(Instruction &I) = 0;

  // Convenience wrappers around the corresponding llvm:: ValueTracking
  // queries. Each forwards the combiner's cached DataLayout, AssumptionCache,
  // and DominatorTree so callers only supply the value, depth, and context
  // instruction.
  void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth,
                        const Instruction *CxtI) const {
    llvm::computeKnownBits(V, Known, DL, Depth, &AC, CxtI, &DT);
  }

  KnownBits computeKnownBits(const Value *V, unsigned Depth,
                             const Instruction *CxtI) const {
    return llvm::computeKnownBits(V, DL, Depth, &AC, CxtI, &DT);
  }

  bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero = false,
                              unsigned Depth = 0,
                              const Instruction *CxtI = nullptr) {
    return llvm::isKnownToBeAPowerOfTwo(V, DL, OrZero, Depth, &AC, CxtI, &DT);
  }

  bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth = 0,
                         const Instruction *CxtI = nullptr) const {
    return llvm::MaskedValueIsZero(V, Mask, DL, Depth, &AC, CxtI, &DT);
  }

  unsigned ComputeNumSignBits(const Value *Op, unsigned Depth = 0,
                              const Instruction *CxtI = nullptr) const {
    return llvm::ComputeNumSignBits(Op, DL, Depth, &AC, CxtI, &DT);
  }

  unsigned ComputeMaxSignificantBits(const Value *Op, unsigned Depth = 0,
                                     const Instruction *CxtI = nullptr) const {
    return llvm::ComputeMaxSignificantBits(Op, DL, Depth, &AC, CxtI, &DT);
  }

  // Overflow-analysis wrappers; same forwarding pattern as the KnownBits
  // helpers above.
  OverflowResult computeOverflowForUnsignedMul(const Value *LHS,
                                               const Value *RHS,
                                               const Instruction *CxtI) const {
    return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, &AC, CxtI, &DT);
  }

  OverflowResult computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
                                             const Instruction *CxtI) const {
    return llvm::computeOverflowForSignedMul(LHS, RHS, DL, &AC, CxtI, &DT);
  }

  OverflowResult computeOverflowForUnsignedAdd(const Value *LHS,
                                               const Value *RHS,
                                               const Instruction *CxtI) const {
    return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, &AC, CxtI, &DT);
  }

  OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS,
                                             const Instruction *CxtI) const {
    return llvm::computeOverflowForSignedAdd(LHS, RHS, DL, &AC, CxtI, &DT);
  }

  OverflowResult computeOverflowForUnsignedSub(const Value *LHS,
                                               const Value *RHS,
                                               const Instruction *CxtI) const {
    return llvm::computeOverflowForUnsignedSub(LHS, RHS, DL, &AC, CxtI, &DT);
  }

  OverflowResult computeOverflowForSignedSub(const Value *LHS, const Value *RHS,
                                             const Instruction *CxtI) const {
    return llvm::computeOverflowForSignedSub(LHS, RHS, DL, &AC, CxtI, &DT);
  }

  // Demanded-bits/elements simplification hooks; implemented by the visitor
  // (\c InstCombinerImpl in InstCombineInternal.h, per the file comment).
  virtual bool SimplifyDemandedBits(Instruction *I, unsigned OpNo,
                                    const APInt &DemandedMask, KnownBits &Known,
                                    unsigned Depth = 0) = 0;
  virtual Value *
  SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts,
                             unsigned Depth = 0,
                             bool AllowMultipleUsers = false) = 0;

  /// Defined out of line; presumably consults the target to decide whether an
  /// addrspacecast from \p FromAS to \p ToAS is valid -- TODO(review):
  /// confirm against the implementation.
  bool isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const;
};
538 
539 } // namespace llvm
540 
541 #undef DEBUG_TYPE
542 
543 #endif
544