1 //===- llvm/Support/KnownBits.h - Stores known zeros/ones -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a class for representing known zeros and ones used by
10 // computeKnownBits.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_SUPPORT_KNOWNBITS_H
15 #define LLVM_SUPPORT_KNOWNBITS_H
16 
17 #include "llvm/ADT/APInt.h"
18 #include <optional>
19 
20 namespace llvm {
21 
22 // Struct for tracking the known zeros and ones of a value.
23 struct KnownBits {
24   APInt Zero;
25   APInt One;
26 
27 private:
28   // Internal constructor for creating a KnownBits from two APInts.
KnownBitsKnownBits29   KnownBits(APInt Zero, APInt One)
30       : Zero(std::move(Zero)), One(std::move(One)) {}
31 
32 public:
33   // Default construct Zero and One.
34   KnownBits() = default;
35 
36   /// Create a known bits object of BitWidth bits initialized to unknown.
KnownBitsKnownBits37   KnownBits(unsigned BitWidth) : Zero(BitWidth, 0), One(BitWidth, 0) {}
38 
39   /// Get the bit width of this value.
getBitWidthKnownBits40   unsigned getBitWidth() const {
41     assert(Zero.getBitWidth() == One.getBitWidth() &&
42            "Zero and One should have the same width!");
43     return Zero.getBitWidth();
44   }
45 
46   /// Returns true if there is conflicting information.
hasConflictKnownBits47   bool hasConflict() const { return Zero.intersects(One); }
48 
49   /// Returns true if we know the value of all bits.
isConstantKnownBits50   bool isConstant() const {
51     assert(!hasConflict() && "KnownBits conflict!");
52     return Zero.popcount() + One.popcount() == getBitWidth();
53   }
54 
55   /// Returns the value when all bits have a known value. This just returns One
56   /// with a protective assertion.
getConstantKnownBits57   const APInt &getConstant() const {
58     assert(isConstant() && "Can only get value when all bits are known");
59     return One;
60   }
61 
62   /// Returns true if we don't know any bits.
isUnknownKnownBits63   bool isUnknown() const { return Zero.isZero() && One.isZero(); }
64 
65   /// Resets the known state of all bits.
resetAllKnownBits66   void resetAll() {
67     Zero.clearAllBits();
68     One.clearAllBits();
69   }
70 
71   /// Returns true if value is all zero.
isZeroKnownBits72   bool isZero() const {
73     assert(!hasConflict() && "KnownBits conflict!");
74     return Zero.isAllOnes();
75   }
76 
77   /// Returns true if value is all one bits.
isAllOnesKnownBits78   bool isAllOnes() const {
79     assert(!hasConflict() && "KnownBits conflict!");
80     return One.isAllOnes();
81   }
82 
83   /// Make all bits known to be zero and discard any previous information.
setAllZeroKnownBits84   void setAllZero() {
85     Zero.setAllBits();
86     One.clearAllBits();
87   }
88 
89   /// Make all bits known to be one and discard any previous information.
setAllOnesKnownBits90   void setAllOnes() {
91     Zero.clearAllBits();
92     One.setAllBits();
93   }
94 
95   /// Returns true if this value is known to be negative.
isNegativeKnownBits96   bool isNegative() const { return One.isSignBitSet(); }
97 
98   /// Returns true if this value is known to be non-negative.
isNonNegativeKnownBits99   bool isNonNegative() const { return Zero.isSignBitSet(); }
100 
101   /// Returns true if this value is known to be non-zero.
isNonZeroKnownBits102   bool isNonZero() const { return !One.isZero(); }
103 
104   /// Returns true if this value is known to be positive.
isStrictlyPositiveKnownBits105   bool isStrictlyPositive() const {
106     return Zero.isSignBitSet() && !One.isZero();
107   }
108 
109   /// Make this value negative.
makeNegativeKnownBits110   void makeNegative() {
111     One.setSignBit();
112   }
113 
114   /// Make this value non-negative.
makeNonNegativeKnownBits115   void makeNonNegative() {
116     Zero.setSignBit();
117   }
118 
119   /// Return the minimal unsigned value possible given these KnownBits.
getMinValueKnownBits120   APInt getMinValue() const {
121     // Assume that all bits that aren't known-ones are zeros.
122     return One;
123   }
124 
125   /// Return the minimal signed value possible given these KnownBits.
getSignedMinValueKnownBits126   APInt getSignedMinValue() const {
127     // Assume that all bits that aren't known-ones are zeros.
128     APInt Min = One;
129     // Sign bit is unknown.
130     if (Zero.isSignBitClear())
131       Min.setSignBit();
132     return Min;
133   }
134 
135   /// Return the maximal unsigned value possible given these KnownBits.
getMaxValueKnownBits136   APInt getMaxValue() const {
137     // Assume that all bits that aren't known-zeros are ones.
138     return ~Zero;
139   }
140 
141   /// Return the maximal signed value possible given these KnownBits.
getSignedMaxValueKnownBits142   APInt getSignedMaxValue() const {
143     // Assume that all bits that aren't known-zeros are ones.
144     APInt Max = ~Zero;
145     // Sign bit is unknown.
146     if (One.isSignBitClear())
147       Max.clearSignBit();
148     return Max;
149   }
150 
151   /// Return known bits for a truncation of the value we're tracking.
truncKnownBits152   KnownBits trunc(unsigned BitWidth) const {
153     return KnownBits(Zero.trunc(BitWidth), One.trunc(BitWidth));
154   }
155 
156   /// Return known bits for an "any" extension of the value we're tracking,
157   /// where we don't know anything about the extended bits.
anyextKnownBits158   KnownBits anyext(unsigned BitWidth) const {
159     return KnownBits(Zero.zext(BitWidth), One.zext(BitWidth));
160   }
161 
162   /// Return known bits for a zero extension of the value we're tracking.
zextKnownBits163   KnownBits zext(unsigned BitWidth) const {
164     unsigned OldBitWidth = getBitWidth();
165     APInt NewZero = Zero.zext(BitWidth);
166     NewZero.setBitsFrom(OldBitWidth);
167     return KnownBits(NewZero, One.zext(BitWidth));
168   }
169 
170   /// Return known bits for a sign extension of the value we're tracking.
sextKnownBits171   KnownBits sext(unsigned BitWidth) const {
172     return KnownBits(Zero.sext(BitWidth), One.sext(BitWidth));
173   }
174 
175   /// Return known bits for an "any" extension or truncation of the value we're
176   /// tracking.
anyextOrTruncKnownBits177   KnownBits anyextOrTrunc(unsigned BitWidth) const {
178     if (BitWidth > getBitWidth())
179       return anyext(BitWidth);
180     if (BitWidth < getBitWidth())
181       return trunc(BitWidth);
182     return *this;
183   }
184 
185   /// Return known bits for a zero extension or truncation of the value we're
186   /// tracking.
zextOrTruncKnownBits187   KnownBits zextOrTrunc(unsigned BitWidth) const {
188     if (BitWidth > getBitWidth())
189       return zext(BitWidth);
190     if (BitWidth < getBitWidth())
191       return trunc(BitWidth);
192     return *this;
193   }
194 
195   /// Return known bits for a sign extension or truncation of the value we're
196   /// tracking.
sextOrTruncKnownBits197   KnownBits sextOrTrunc(unsigned BitWidth) const {
198     if (BitWidth > getBitWidth())
199       return sext(BitWidth);
200     if (BitWidth < getBitWidth())
201       return trunc(BitWidth);
202     return *this;
203   }
204 
205   /// Return known bits for a in-register sign extension of the value we're
206   /// tracking.
207   KnownBits sextInReg(unsigned SrcBitWidth) const;
208 
209   /// Insert the bits from a smaller known bits starting at bitPosition.
insertBitsKnownBits210   void insertBits(const KnownBits &SubBits, unsigned BitPosition) {
211     Zero.insertBits(SubBits.Zero, BitPosition);
212     One.insertBits(SubBits.One, BitPosition);
213   }
214 
215   /// Return a subset of the known bits from [bitPosition,bitPosition+numBits).
extractBitsKnownBits216   KnownBits extractBits(unsigned NumBits, unsigned BitPosition) const {
217     return KnownBits(Zero.extractBits(NumBits, BitPosition),
218                      One.extractBits(NumBits, BitPosition));
219   }
220 
221   /// Concatenate the bits from \p Lo onto the bottom of *this.  This is
222   /// equivalent to:
223   ///   (this->zext(NewWidth) << Lo.getBitWidth()) | Lo.zext(NewWidth)
concatKnownBits224   KnownBits concat(const KnownBits &Lo) const {
225     return KnownBits(Zero.concat(Lo.Zero), One.concat(Lo.One));
226   }
227 
228   /// Return KnownBits based on this, but updated given that the underlying
229   /// value is known to be greater than or equal to Val.
230   KnownBits makeGE(const APInt &Val) const;
231 
232   /// Returns the minimum number of trailing zero bits.
countMinTrailingZerosKnownBits233   unsigned countMinTrailingZeros() const { return Zero.countr_one(); }
234 
235   /// Returns the minimum number of trailing one bits.
countMinTrailingOnesKnownBits236   unsigned countMinTrailingOnes() const { return One.countr_one(); }
237 
238   /// Returns the minimum number of leading zero bits.
countMinLeadingZerosKnownBits239   unsigned countMinLeadingZeros() const { return Zero.countl_one(); }
240 
241   /// Returns the minimum number of leading one bits.
countMinLeadingOnesKnownBits242   unsigned countMinLeadingOnes() const { return One.countl_one(); }
243 
244   /// Returns the number of times the sign bit is replicated into the other
245   /// bits.
countMinSignBitsKnownBits246   unsigned countMinSignBits() const {
247     if (isNonNegative())
248       return countMinLeadingZeros();
249     if (isNegative())
250       return countMinLeadingOnes();
251     // Every value has at least 1 sign bit.
252     return 1;
253   }
254 
255   /// Returns the maximum number of bits needed to represent all possible
256   /// signed values with these known bits. This is the inverse of the minimum
257   /// number of known sign bits. Examples for bitwidth 5:
258   /// 110?? --> 4
259   /// 0000? --> 2
countMaxSignificantBitsKnownBits260   unsigned countMaxSignificantBits() const {
261     return getBitWidth() - countMinSignBits() + 1;
262   }
263 
264   /// Returns the maximum number of trailing zero bits possible.
countMaxTrailingZerosKnownBits265   unsigned countMaxTrailingZeros() const { return One.countr_zero(); }
266 
267   /// Returns the maximum number of trailing one bits possible.
countMaxTrailingOnesKnownBits268   unsigned countMaxTrailingOnes() const { return Zero.countr_zero(); }
269 
270   /// Returns the maximum number of leading zero bits possible.
countMaxLeadingZerosKnownBits271   unsigned countMaxLeadingZeros() const { return One.countl_zero(); }
272 
273   /// Returns the maximum number of leading one bits possible.
countMaxLeadingOnesKnownBits274   unsigned countMaxLeadingOnes() const { return Zero.countl_zero(); }
275 
276   /// Returns the number of bits known to be one.
countMinPopulationKnownBits277   unsigned countMinPopulation() const { return One.popcount(); }
278 
279   /// Returns the maximum number of bits that could be one.
countMaxPopulationKnownBits280   unsigned countMaxPopulation() const {
281     return getBitWidth() - Zero.popcount();
282   }
283 
284   /// Returns the maximum number of bits needed to represent all possible
285   /// unsigned values with these known bits. This is the inverse of the
286   /// minimum number of leading zeros.
countMaxActiveBitsKnownBits287   unsigned countMaxActiveBits() const {
288     return getBitWidth() - countMinLeadingZeros();
289   }
290 
291   /// Create known bits from a known constant.
makeConstantKnownBits292   static KnownBits makeConstant(const APInt &C) {
293     return KnownBits(~C, C);
294   }
295 
296   /// Returns KnownBits information that is known to be true for both this and
297   /// RHS.
298   ///
299   /// When an operation is known to return one of its operands, this can be used
300   /// to combine information about the known bits of the operands to get the
301   /// information that must be true about the result.
intersectWithKnownBits302   KnownBits intersectWith(const KnownBits &RHS) const {
303     return KnownBits(Zero & RHS.Zero, One & RHS.One);
304   }
305 
306   /// Returns KnownBits information that is known to be true for either this or
307   /// RHS or both.
308   ///
309   /// This can be used to combine different sources of information about the
310   /// known bits of a single value, e.g. information about the low bits and the
311   /// high bits of the result of a multiplication.
unionWithKnownBits312   KnownBits unionWith(const KnownBits &RHS) const {
313     return KnownBits(Zero | RHS.Zero, One | RHS.One);
314   }
315 
316   /// Compute known bits common to LHS and RHS.
317   LLVM_DEPRECATED("use intersectWith instead", "intersectWith")
commonBitsKnownBits318   static KnownBits commonBits(const KnownBits &LHS, const KnownBits &RHS) {
319     return LHS.intersectWith(RHS);
320   }
321 
322   /// Return true if LHS and RHS have no common bits set.
haveNoCommonBitsSetKnownBits323   static bool haveNoCommonBitsSet(const KnownBits &LHS, const KnownBits &RHS) {
324     return (LHS.Zero | RHS.Zero).isAllOnes();
325   }
326 
327   /// Compute known bits resulting from adding LHS, RHS and a 1-bit Carry.
328   static KnownBits computeForAddCarry(
329       const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry);
330 
331   /// Compute known bits resulting from adding LHS and RHS.
332   static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS,
333                                     KnownBits RHS);
334 
335   /// Compute known bits results from subtracting RHS from LHS with 1-bit
336   /// Borrow.
337   static KnownBits computeForSubBorrow(const KnownBits &LHS, KnownBits RHS,
338                                        const KnownBits &Borrow);
339 
340   /// Compute knownbits resulting from llvm.sadd.sat(LHS, RHS)
341   static KnownBits sadd_sat(const KnownBits &LHS, const KnownBits &RHS);
342 
343   /// Compute knownbits resulting from llvm.uadd.sat(LHS, RHS)
344   static KnownBits uadd_sat(const KnownBits &LHS, const KnownBits &RHS);
345 
346   /// Compute knownbits resulting from llvm.ssub.sat(LHS, RHS)
347   static KnownBits ssub_sat(const KnownBits &LHS, const KnownBits &RHS);
348 
349   /// Compute knownbits resulting from llvm.usub.sat(LHS, RHS)
350   static KnownBits usub_sat(const KnownBits &LHS, const KnownBits &RHS);
351 
352   /// Compute known bits resulting from multiplying LHS and RHS.
353   static KnownBits mul(const KnownBits &LHS, const KnownBits &RHS,
354                        bool NoUndefSelfMultiply = false);
355 
356   /// Compute known bits from sign-extended multiply-hi.
357   static KnownBits mulhs(const KnownBits &LHS, const KnownBits &RHS);
358 
359   /// Compute known bits from zero-extended multiply-hi.
360   static KnownBits mulhu(const KnownBits &LHS, const KnownBits &RHS);
361 
362   /// Compute known bits for sdiv(LHS, RHS).
363   static KnownBits sdiv(const KnownBits &LHS, const KnownBits &RHS,
364                         bool Exact = false);
365 
366   /// Compute known bits for udiv(LHS, RHS).
367   static KnownBits udiv(const KnownBits &LHS, const KnownBits &RHS,
368                         bool Exact = false);
369 
370   /// Compute known bits for urem(LHS, RHS).
371   static KnownBits urem(const KnownBits &LHS, const KnownBits &RHS);
372 
373   /// Compute known bits for srem(LHS, RHS).
374   static KnownBits srem(const KnownBits &LHS, const KnownBits &RHS);
375 
376   /// Compute known bits for umax(LHS, RHS).
377   static KnownBits umax(const KnownBits &LHS, const KnownBits &RHS);
378 
379   /// Compute known bits for umin(LHS, RHS).
380   static KnownBits umin(const KnownBits &LHS, const KnownBits &RHS);
381 
382   /// Compute known bits for smax(LHS, RHS).
383   static KnownBits smax(const KnownBits &LHS, const KnownBits &RHS);
384 
385   /// Compute known bits for smin(LHS, RHS).
386   static KnownBits smin(const KnownBits &LHS, const KnownBits &RHS);
387 
388   /// Compute known bits for shl(LHS, RHS).
389   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
390   static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS,
391                        bool NUW = false, bool NSW = false,
392                        bool ShAmtNonZero = false);
393 
394   /// Compute known bits for lshr(LHS, RHS).
395   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
396   static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS,
397                         bool ShAmtNonZero = false);
398 
399   /// Compute known bits for ashr(LHS, RHS).
400   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
401   static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS,
402                         bool ShAmtNonZero = false);
403 
404   /// Determine if these known bits always give the same ICMP_EQ result.
405   static std::optional<bool> eq(const KnownBits &LHS, const KnownBits &RHS);
406 
407   /// Determine if these known bits always give the same ICMP_NE result.
408   static std::optional<bool> ne(const KnownBits &LHS, const KnownBits &RHS);
409 
410   /// Determine if these known bits always give the same ICMP_UGT result.
411   static std::optional<bool> ugt(const KnownBits &LHS, const KnownBits &RHS);
412 
413   /// Determine if these known bits always give the same ICMP_UGE result.
414   static std::optional<bool> uge(const KnownBits &LHS, const KnownBits &RHS);
415 
416   /// Determine if these known bits always give the same ICMP_ULT result.
417   static std::optional<bool> ult(const KnownBits &LHS, const KnownBits &RHS);
418 
419   /// Determine if these known bits always give the same ICMP_ULE result.
420   static std::optional<bool> ule(const KnownBits &LHS, const KnownBits &RHS);
421 
422   /// Determine if these known bits always give the same ICMP_SGT result.
423   static std::optional<bool> sgt(const KnownBits &LHS, const KnownBits &RHS);
424 
425   /// Determine if these known bits always give the same ICMP_SGE result.
426   static std::optional<bool> sge(const KnownBits &LHS, const KnownBits &RHS);
427 
428   /// Determine if these known bits always give the same ICMP_SLT result.
429   static std::optional<bool> slt(const KnownBits &LHS, const KnownBits &RHS);
430 
431   /// Determine if these known bits always give the same ICMP_SLE result.
432   static std::optional<bool> sle(const KnownBits &LHS, const KnownBits &RHS);
433 
434   /// Update known bits based on ANDing with RHS.
435   KnownBits &operator&=(const KnownBits &RHS);
436 
437   /// Update known bits based on ORing with RHS.
438   KnownBits &operator|=(const KnownBits &RHS);
439 
440   /// Update known bits based on XORing with RHS.
441   KnownBits &operator^=(const KnownBits &RHS);
442 
443   /// Compute known bits for the absolute value.
444   KnownBits abs(bool IntMinIsPoison = false) const;
445 
byteSwapKnownBits446   KnownBits byteSwap() const {
447     return KnownBits(Zero.byteSwap(), One.byteSwap());
448   }
449 
reverseBitsKnownBits450   KnownBits reverseBits() const {
451     return KnownBits(Zero.reverseBits(), One.reverseBits());
452   }
453 
454   /// Compute known bits for X & -X, which has only the lowest bit set of X set.
455   /// The name comes from the X86 BMI instruction
456   KnownBits blsi() const;
457 
458   /// Compute known bits for X ^ (X - 1), which has all bits up to and including
459   /// the lowest set bit of X set. The name comes from the X86 BMI instruction.
460   KnownBits blsmsk() const;
461 
462   bool operator==(const KnownBits &Other) const {
463     return Zero == Other.Zero && One == Other.One;
464   }
465 
466   bool operator!=(const KnownBits &Other) const { return !(*this == Other); }
467 
468   void print(raw_ostream &OS) const;
469   void dump() const;
470 
471 private:
472   // Internal helper for getting the initial KnownBits for an `srem` or `urem`
473   // operation with the low-bits set.
474   static KnownBits remGetLowBits(const KnownBits &LHS, const KnownBits &RHS);
475 };
476 
477 inline KnownBits operator&(KnownBits LHS, const KnownBits &RHS) {
478   LHS &= RHS;
479   return LHS;
480 }
481 
482 inline KnownBits operator&(const KnownBits &LHS, KnownBits &&RHS) {
483   RHS &= LHS;
484   return std::move(RHS);
485 }
486 
487 inline KnownBits operator|(KnownBits LHS, const KnownBits &RHS) {
488   LHS |= RHS;
489   return LHS;
490 }
491 
492 inline KnownBits operator|(const KnownBits &LHS, KnownBits &&RHS) {
493   RHS |= LHS;
494   return std::move(RHS);
495 }
496 
497 inline KnownBits operator^(KnownBits LHS, const KnownBits &RHS) {
498   LHS ^= RHS;
499   return LHS;
500 }
501 
502 inline KnownBits operator^(const KnownBits &LHS, KnownBits &&RHS) {
503   RHS ^= LHS;
504   return std::move(RHS);
505 }
506 
507 inline raw_ostream &operator<<(raw_ostream &OS, const KnownBits &Known) {
508   Known.print(OS);
509   return OS;
510 }
511 
512 } // end namespace llvm
513 
514 #endif
515