1 //===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a class for representing known zeros and ones used by
10 // computeKnownBits.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Support/KnownBits.h"
15 #include "llvm/Support/Debug.h"
16 #include "llvm/Support/raw_ostream.h"
17 #include <cassert>
18 
19 using namespace llvm;
20 
/// Compute the known bits of LHS + RHS + CarryIn.
///
/// \p CarryZero / \p CarryOne say whether the incoming carry is known to be
/// zero / known to be one; when neither is set the carry is unknown. They must
/// not both be set.
///
/// The sum is evaluated twice: once with every unknown input bit (and an
/// unknown carry) taken as one, and once with them taken as zero. A result bit
/// is reported as known only where both operand bits and the carry into that
/// position are known, and there the two evaluations must agree.
static KnownBits computeForAddCarry(const KnownBits &LHS, const KnownBits &RHS,
                                    bool CarryZero, bool CarryOne) {
  assert(!(CarryZero && CarryOne) &&
         "Carry can't be zero and one at the same time");

  // Extreme sums: all unknowns maximised vs. all unknowns minimised.
  APInt MaxSum = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero;
  APInt MinSum = LHS.getMinValue() + RHS.getMinValue() + CarryOne;

  // sum_i = lhs_i ^ rhs_i ^ carry_i, so xor-ing the (known) operand bits back
  // out of each extreme sum isolates the carry into each position: the first
  // term marks carries known to be zero, the second carries known to be one.
  APInt CarryIsKnown =
      ~(MaxSum ^ LHS.Zero ^ RHS.Zero) | (MinSum ^ LHS.One ^ RHS.One);

  // A result bit is known only when all three of its inputs are known.
  APInt Known = (LHS.Zero | LHS.One) & (RHS.Zero | RHS.One) & CarryIsKnown;

  assert((MaxSum & Known) == (MinSum & Known) &&
         "known bits of sum differ");

  KnownBits KnownOut;
  KnownOut.Zero = ~std::move(MaxSum) & Known;
  KnownOut.One = std::move(MinSum) & Known;
  return KnownOut;
}
49 
/// Compute the known bits of LHS + RHS + Carry, where \p Carry is one bit wide.
KnownBits KnownBits::computeForAddCarry(const KnownBits &LHS,
                                        const KnownBits &RHS,
                                        const KnownBits &Carry) {
  assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit");
  // Translate the single carry bit into the known-zero / known-one flags
  // expected by the file-local helper.
  const bool CarryIsZero = Carry.Zero.getBoolValue();
  const bool CarryIsOne = Carry.One.getBoolValue();
  return ::computeForAddCarry(LHS, RHS, CarryIsZero, CarryIsOne);
}
56 
/// Compute the known bits of LHS + RHS (\p Add) or LHS - RHS (!\p Add).
///
/// When \p NSW is set and the result's sign bit is still unknown, the
/// no-signed-wrap guarantee is used to derive it from the operand signs.
KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
                                      const KnownBits &LHS, KnownBits RHS) {
  // Subtraction is rewritten as LHS + ~RHS + 1. RHS is taken by value so it
  // can be complemented in place by exchanging its Zero and One masks.
  if (!Add)
    std::swap(RHS.Zero, RHS.One);

  // Carry-in is a known 0 for addition and a known 1 (the "+1") for
  // subtraction.
  KnownBits KnownOut =
      ::computeForAddCarry(LHS, RHS, /*CarryZero=*/Add, /*CarryOne=*/!Add);

  // Only try to refine when NSW holds and the sign bit is still undetermined.
  const bool SignBitKnown = KnownOut.isNegative() || KnownOut.isNonNegative();
  if (!NSW || SignBitKnown)
    return KnownOut;

  // RHS has already been complemented for subtraction, so both cases reduce
  // to adding two values: if they share a sign, a non-wrapping sum keeps it.
  // (Non-negative + non-negative covers "non-negative minus negative", and
  // negative + negative covers "negative minus non-negative".)
  if (LHS.isNonNegative() && RHS.isNonNegative())
    KnownOut.makeNonNegative();
  else if (LHS.isNegative() && RHS.isNegative())
    KnownOut.makeNegative();
  return KnownOut;
}
87 
/// Compute the known bits of LHS - RHS - Borrow, where \p Borrow is one bit
/// wide.
KnownBits KnownBits::computeForSubBorrow(const KnownBits &LHS, KnownBits RHS,
                                         const KnownBits &Borrow) {
  assert(Borrow.getBitWidth() == 1 && "Borrow must be 1-bit");

  // LHS - RHS - Borrow == LHS + ~RHS + (1 - Borrow). Complement RHS by
  // exchanging its masks. The carry-in is the inverted borrow, so a known-one
  // borrow means a known-zero carry and vice versa.
  std::swap(RHS.Zero, RHS.One);
  const bool CarryKnownZero = Borrow.One.getBoolValue();
  const bool CarryKnownOne = Borrow.Zero.getBoolValue();
  return ::computeForAddCarry(LHS, RHS, CarryKnownZero, CarryKnownOne);
}
99 
sextInReg(unsigned SrcBitWidth) const100 KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const {
101   unsigned BitWidth = getBitWidth();
102   assert(0 < SrcBitWidth && SrcBitWidth <= BitWidth &&
103          "Illegal sext-in-register");
104 
105   if (SrcBitWidth == BitWidth)
106     return *this;
107 
108   unsigned ExtBits = BitWidth - SrcBitWidth;
109   KnownBits Result;
110   Result.One = One << ExtBits;
111   Result.Zero = Zero << ExtBits;
112   Result.One.ashrInPlace(ExtBits);
113   Result.Zero.ashrInPlace(ExtBits);
114   return Result;
115 }
116 
makeGE(const APInt & Val) const117 KnownBits KnownBits::makeGE(const APInt &Val) const {
118   // Count the number of leading bit positions where our underlying value is
119   // known to be less than or equal to Val.
120   unsigned N = (Zero | Val).countl_one();
121 
122   // For each of those bit positions, if Val has a 1 in that bit then our
123   // underlying value must also have a 1.
124   APInt MaskedVal(Val);
125   MaskedVal.clearLowBits(getBitWidth() - N);
126   return KnownBits(Zero, One | MaskedVal);
127 }
128 
/// Compute the known bits of the unsigned maximum of LHS and RHS.
KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) {
  // If one side is provably >=u the other, that side is the result. Callers
  // usually fold these cases already; they are handled for completeness.
  if (LHS.getMinValue().uge(RHS.getMaxValue()))
    return LHS;
  if (RHS.getMinValue().uge(LHS.getMaxValue()))
    return RHS;

  // Otherwise the result is either LHS (which then must be >=u min(RHS)) or
  // RHS (which then must be >=u min(LHS)). Only the bits both refined
  // candidates agree on are known in the result.
  KnownBits LHSIfChosen = LHS.makeGE(RHS.getMinValue());
  KnownBits RHSIfChosen = RHS.makeGE(LHS.getMinValue());
  return LHSIfChosen.intersectWith(RHSIfChosen);
}
146 
/// Compute the known bits of the unsigned minimum of LHS and RHS.
KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
  // Complementing every bit reverses unsigned order ([0, 0xFFFFFFFF] <->
  // [0xFFFFFFFF, 0]), so umin(a, b) == ~umax(~a, ~b). For known bits,
  // complementing just exchanges the Zero and One masks.
  auto Complement = [](const KnownBits &Val) {
    return KnownBits(Val.One, Val.Zero);
  };
  KnownBits MaxOfComplements = umax(Complement(LHS), Complement(RHS));
  return Complement(MaxOfComplements);
}
152 
/// Compute the known bits of the signed maximum of LHS and RHS.
KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
  // Toggling only the sign bit maps signed order onto unsigned order
  // ([-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]), so
  // smax(a, b) == toggle(umax(toggle(a), toggle(b))). For known bits, toggling
  // the sign bit exchanges its known-zero and known-one states.
  auto ToggleSignBit = [](const KnownBits &Val) {
    const unsigned SignBit = Val.getBitWidth() - 1;
    KnownBits Toggled = Val;
    Toggled.Zero.setBitVal(SignBit, Val.One[SignBit]);
    Toggled.One.setBitVal(SignBit, Val.Zero[SignBit]);
    return Toggled;
  };
  return ToggleSignBit(umax(ToggleSignBit(LHS), ToggleSignBit(RHS)));
}
165 
/// Compute the known bits of the signed minimum of LHS and RHS.
KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
  // Complementing every bit except the sign bit maps signed order onto
  // reversed unsigned order ([-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0]),
  // so smin can be computed through umax. For known bits this exchanges the
  // Zero/One masks everywhere except at the sign position.
  auto FlipAllButSign = [](const KnownBits &Val) {
    const unsigned SignBit = Val.getBitWidth() - 1;
    KnownBits Flipped(Val.One, Val.Zero);
    Flipped.Zero.setBitVal(SignBit, Val.Zero[SignBit]);
    Flipped.One.setBitVal(SignBit, Val.One[SignBit]);
    return Flipped;
  };
  return FlipAllButSign(umax(FlipAllButSign(LHS), FlipAllButSign(RHS)));
}
178 
getMaxShiftAmount(const APInt & MaxValue,unsigned BitWidth)179 static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) {
180   if (isPowerOf2_32(BitWidth))
181     return MaxValue.extractBitsAsZExtValue(Log2_32(BitWidth), 0);
182   // This is only an approximate upper bound.
183   return MaxValue.getLimitedValue(BitWidth - 1);
184 }
185 
/// Compute the known bits of LHS << RHS.
///
/// \p NUW / \p NSW are the shift's no-unsigned-wrap / no-signed-wrap flags.
/// They limit which shift amounts are considered below and can determine the
/// result's sign bit. \p ShAmtNonZero states that the shift amount is known to
/// be non-zero. If no candidate shift amount survives, the result is reported
/// as all-zero rather than as a conflict.
KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
                         bool NSW, bool ShAmtNonZero) {
  unsigned BitWidth = LHS.getBitWidth();
  // Known bits of LHS shifted left by the constant ShiftAmt.
  auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
    KnownBits Known;
    // Overflow flags from ushl_ov on the Zero / One masks: set when a
    // known-zero / known-one bit is shifted out of the top.
    bool ShiftedOutZero, ShiftedOutOne;
    Known.Zero = LHS.Zero.ushl_ov(ShiftAmt, ShiftedOutZero);
    // The bits shifted in at the bottom are zero.
    Known.Zero.setLowBits(ShiftAmt);
    Known.One = LHS.One.ushl_ov(ShiftAmt, ShiftedOutOne);

    // All cases returning poison have been handled by MaxShiftAmount already.
    // Under NSW, a shifted-out known zero makes the result non-negative and a
    // shifted-out known one makes it negative.
    if (NSW) {
      if (NUW && ShiftAmt != 0)
        // NUW means we can assume anything shifted out was a zero.
        ShiftedOutZero = true;

      if (ShiftedOutZero)
        Known.makeNonNegative();
      else if (ShiftedOutOne)
        Known.makeNegative();
    }
    return Known;
  };

  // Fast path for a common case when LHS is completely unknown.
  KnownBits Known(BitWidth);
  // Smallest possible shift amount, clamped to BitWidth.
  unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
  if (MinShiftAmount == 0 && ShAmtNonZero)
    MinShiftAmount = 1;
  if (LHS.isUnknown()) {
    // Only the zeros shifted in at the bottom are known.
    Known.Zero.setLowBits(MinShiftAmount);
    if (NUW && NSW && MinShiftAmount != 0)
      Known.makeNonNegative();
    return Known;
  }

  // Determine maximum shift amount, taking NUW/NSW flags into account.
  // If LHS.countMaxLeadingZeros() is 0, the "- 1" in the NUW && NSW clamp
  // wraps to UINT_MAX and imposes no limit; the NUW clamp right after it then
  // limits the amount to 0.
  APInt MaxValue = RHS.getMaxValue();
  unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
  if (NUW && NSW)
    MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros() - 1);
  if (NUW)
    MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros());
  // NSW: limit the amount to one less than the longest possible run of
  // leading zeros or leading ones in LHS.
  if (NSW)
    MaxShiftAmount = std::min(
        MaxShiftAmount,
        std::max(LHS.countMaxLeadingZeros(), LHS.countMaxLeadingOnes()) - 1);

  // Fast path for common case where the shift amount is unknown.
  // Every amount in [0, BitWidth-1] is possible, so only LHS's known trailing
  // zeros survive. An all-ones LHS keeps its sign bit set for any such shift.
  if (MinShiftAmount == 0 && MaxShiftAmount == BitWidth - 1 &&
      isPowerOf2_32(BitWidth)) {
    Known.Zero.setLowBits(LHS.countMinTrailingZeros());
    if (LHS.isAllOnes())
      Known.One.setSignBit();
    if (NSW) {
      if (LHS.isNonNegative())
        Known.makeNonNegative();
      if (LHS.isNegative())
        Known.makeNegative();
    }
    return Known;
  }

  // Find the common bits from all possible shifts.
  // Low 32 bits of RHS's known masks, used to reject amounts that contradict
  // what is known about RHS.
  unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
  unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
  // Start from the all-conflict state and intersect in each feasible shift.
  Known.Zero.setAllBits();
  Known.One.setAllBits();
  for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
       ++ShiftAmt) {
    // Skip if the shift amount is impossible.
    if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
        (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
      continue;
    Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
    if (Known.isUnknown())
      break;
  }

  // All shift amounts may result in poison.
  // (No feasible amount was intersected in, so the all-conflict start value
  // remains; report zero instead.)
  if (Known.hasConflict())
    Known.setAllZero();
  return Known;
}
270 
/// Compute the known bits of LHS >> RHS (logical shift).
///
/// \p ShAmtNonZero states that the shift amount is known to be non-zero. If no
/// candidate shift amount is feasible (every amount may be poison), the result
/// is reported as all-zero rather than as a conflict.
KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
                          bool ShAmtNonZero) {
  const unsigned BitWidth = LHS.getBitWidth();

  unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
  if (ShAmtNonZero && MinShiftAmount == 0)
    MinShiftAmount = 1;

  // Fast path: with nothing known about LHS, the only knowledge is that at
  // least MinShiftAmount zero bits are shifted in at the top.
  if (LHS.isUnknown()) {
    KnownBits Result(BitWidth);
    Result.Zero.setHighBits(MinShiftAmount);
    return Result;
  }

  const unsigned MaxShiftAmount =
      getMaxShiftAmount(RHS.getMaxValue(), BitWidth);

  // Low 32 bits of RHS's known masks. A candidate amount is impossible if it
  // has a one where RHS is known zero, or a zero where RHS is known one.
  const unsigned AmtKnownZero = RHS.Zero.zextOrTrunc(32).getZExtValue();
  const unsigned AmtKnownOne = RHS.One.zextOrTrunc(32).getZExtValue();
  auto IsPossibleAmount = [&](unsigned Amt) {
    return (AmtKnownZero & Amt) == 0 && (AmtKnownOne & ~Amt) == 0;
  };

  // Start from the all-conflict state and intersect in the result of every
  // feasible constant shift.
  KnownBits Result(BitWidth);
  Result.Zero.setAllBits();
  Result.One.setAllBits();
  for (unsigned Amt = MinShiftAmount; Amt <= MaxShiftAmount; ++Amt) {
    if (!IsPossibleAmount(Amt))
      continue;
    KnownBits Shifted = LHS;
    Shifted.Zero.lshrInPlace(Amt);
    Shifted.One.lshrInPlace(Amt);
    // Vacated high bits are zero.
    Shifted.Zero.setHighBits(Amt);
    Result = Result.intersectWith(Shifted);
    // Nothing left to lose; stop early.
    if (Result.isUnknown())
      break;
  }

  // A remaining conflict means no feasible amount was visited (every shift may
  // be poison); report zero instead.
  if (Result.hasConflict())
    Result.setAllZero();
  return Result;
}
316 
/// Compute the known bits of LHS >> RHS (arithmetic shift).
///
/// \p ShAmtNonZero states that the shift amount is known to be non-zero. If no
/// candidate shift amount is feasible (every amount may be poison), the result
/// is reported as all-zero rather than as a conflict.
KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
                          bool ShAmtNonZero) {
  const unsigned BitWidth = LHS.getBitWidth();

  unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
  if (ShAmtNonZero && MinShiftAmount == 0)
    MinShiftAmount = 1;

  // Fast path for a completely unknown LHS: the copied-in sign bits are
  // unknown too, so nothing is learned. The exception is when every shift
  // amount is out of range.
  if (LHS.isUnknown()) {
    KnownBits Result(BitWidth);
    // Always poison. Prefer returning zero over returning a conflict.
    if (MinShiftAmount == BitWidth)
      Result.setAllZero();
    return Result;
  }

  const unsigned MaxShiftAmount =
      getMaxShiftAmount(RHS.getMaxValue(), BitWidth);

  // Low 32 bits of RHS's known masks. A candidate amount is impossible if it
  // has a one where RHS is known zero, or a zero where RHS is known one.
  const unsigned AmtKnownZero = RHS.Zero.zextOrTrunc(32).getZExtValue();
  const unsigned AmtKnownOne = RHS.One.zextOrTrunc(32).getZExtValue();
  auto IsPossibleAmount = [&](unsigned Amt) {
    return (AmtKnownZero & Amt) == 0 && (AmtKnownOne & ~Amt) == 0;
  };

  // Start from the all-conflict state and intersect in the result of every
  // feasible constant shift.
  KnownBits Result(BitWidth);
  Result.Zero.setAllBits();
  Result.One.setAllBits();
  for (unsigned Amt = MinShiftAmount; Amt <= MaxShiftAmount; ++Amt) {
    if (!IsPossibleAmount(Amt))
      continue;
    // Arithmetic shifts replicate whatever is known about the sign bit.
    KnownBits Shifted = LHS;
    Shifted.Zero.ashrInPlace(Amt);
    Shifted.One.ashrInPlace(Amt);
    Result = Result.intersectWith(Shifted);
    // Nothing left to lose; stop early.
    if (Result.isUnknown())
      break;
  }

  // A remaining conflict means no feasible amount was visited (every shift may
  // be poison); report zero instead.
  if (Result.hasConflict())
    Result.setAllZero();
  return Result;
}
364 
eq(const KnownBits & LHS,const KnownBits & RHS)365 std::optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) {
366   if (LHS.isConstant() && RHS.isConstant())
367     return std::optional<bool>(LHS.getConstant() == RHS.getConstant());
368   if (LHS.One.intersects(RHS.Zero) || RHS.One.intersects(LHS.Zero))
369     return std::optional<bool>(false);
370   return std::nullopt;
371 }
372 
ne(const KnownBits & LHS,const KnownBits & RHS)373 std::optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) {
374   if (std::optional<bool> KnownEQ = eq(LHS, RHS))
375     return std::optional<bool>(!*KnownEQ);
376   return std::nullopt;
377 }
378 
ugt(const KnownBits & LHS,const KnownBits & RHS)379 std::optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) {
380   // LHS >u RHS -> false if umax(LHS) <= umax(RHS)
381   if (LHS.getMaxValue().ule(RHS.getMinValue()))
382     return std::optional<bool>(false);
383   // LHS >u RHS -> true if umin(LHS) > umax(RHS)
384   if (LHS.getMinValue().ugt(RHS.getMaxValue()))
385     return std::optional<bool>(true);
386   return std::nullopt;
387 }
388 
uge(const KnownBits & LHS,const KnownBits & RHS)389 std::optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) {
390   if (std::optional<bool> IsUGT = ugt(RHS, LHS))
391     return std::optional<bool>(!*IsUGT);
392   return std::nullopt;
393 }
394 
ult(const KnownBits & LHS,const KnownBits & RHS)395 std::optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) {
396   return ugt(RHS, LHS);
397 }
398 
ule(const KnownBits & LHS,const KnownBits & RHS)399 std::optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) {
400   return uge(RHS, LHS);
401 }
402 
sgt(const KnownBits & LHS,const KnownBits & RHS)403 std::optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) {
404   // LHS >s RHS -> false if smax(LHS) <= smax(RHS)
405   if (LHS.getSignedMaxValue().sle(RHS.getSignedMinValue()))
406     return std::optional<bool>(false);
407   // LHS >s RHS -> true if smin(LHS) > smax(RHS)
408   if (LHS.getSignedMinValue().sgt(RHS.getSignedMaxValue()))
409     return std::optional<bool>(true);
410   return std::nullopt;
411 }
412 
sge(const KnownBits & LHS,const KnownBits & RHS)413 std::optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) {
414   if (std::optional<bool> KnownSGT = sgt(RHS, LHS))
415     return std::optional<bool>(!*KnownSGT);
416   return std::nullopt;
417 }
418 
slt(const KnownBits & LHS,const KnownBits & RHS)419 std::optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) {
420   return sgt(RHS, LHS);
421 }
422 
sle(const KnownBits & LHS,const KnownBits & RHS)423 std::optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) {
424   return sge(RHS, LHS);
425 }
426 
/// Compute the known bits of the absolute value of this value.
///
/// \p IntMinIsPoison states that an INT_MIN input may be treated as poison.
/// That assumption enables the extra refinements below.
KnownBits KnownBits::abs(bool IntMinIsPoison) const {
  // If the source's MSB is zero then we know the rest of the bits already.
  if (isNonNegative())
    return *this;

  // Absolute value preserves trailing zero count.
  KnownBits KnownAbs(getBitWidth());

  // If the input is negative, then abs(x) == -x.
  if (isNegative()) {
    KnownBits Tmp = *this;
    // Special case for IntMinIsPoison. We know the sign bit is set and we know
    // all the rest of the bits except one to be zero. Since we have
    // IntMinIsPoison, that final bit MUST be a one, as otherwise the input is
    // INT_MIN.
    if (IntMinIsPoison && (Zero.popcount() + 2) == getBitWidth())
      Tmp.One.setBit(countMinTrailingZeros());

    // -x is computed as 0 - x. NSW is only claimed when INT_MIN is poison.
    KnownAbs = computeForAddSub(
        /*Add*/ false, IntMinIsPoison,
        KnownBits::makeConstant(APInt(getBitWidth(), 0)), Tmp);

    // One more special case for IntMinIsPoison. If we don't know any ones other
    // than the signbit, we know for certain that all the unknowns can't be
    // zero. So if we know high zero bits, but have unknown low bits, we know
    // for certain those high-zero bits will end up as one. This is because,
    // the low bits can't be all zeros, so the +1 in (~x + 1) cannot carry up
    // to the high bits. If we know a known INT_MIN input skip this. The result
    // is poison anyways.
    if (IntMinIsPoison && Tmp.countMinPopulation() == 1 &&
        Tmp.countMaxPopulation() != 1) {
      // Treat the sign bit as zero so countMinLeadingZeros() covers the known
      // zeros just below it; those positions are set in the result.
      Tmp.One.clearSignBit();
      Tmp.Zero.setSignBit();
      KnownAbs.One.setBits(getBitWidth() - Tmp.countMinLeadingZeros(),
                           getBitWidth() - 1);
    }

  } else {
    // Sign unknown: the result is either x or -x. Both share x's trailing
    // zeros and its lowest set bit.
    unsigned MaxTZ = countMaxTrailingZeros();
    unsigned MinTZ = countMinTrailingZeros();

    KnownAbs.Zero.setLowBits(MinTZ);
    // If we know the lowest set 1, then preserve it.
    if (MaxTZ == MinTZ && MaxTZ < getBitWidth())
      KnownAbs.One.setBit(MaxTZ);

    // We only know that the absolute value's MSB will be zero if INT_MIN is
    // poison, or there is a set bit that isn't the sign bit (otherwise it could
    // be INT_MIN).
    if (IntMinIsPoison || (!One.isZero() && !One.isMinSignedValue())) {
      KnownAbs.One.clearSignBit();
      KnownAbs.Zero.setSignBit();
    }
  }

  assert(!KnownAbs.hasConflict() && "Bad Output");
  return KnownAbs;
}
485 
/// Shared implementation of sadd_sat, ssub_sat, uadd_sat and usub_sat.
///
/// \p Add selects addition vs. subtraction and \p Signed selects signed vs.
/// unsigned saturation. First the plain wrapping add/sub is computed. Then we
/// try to decide whether the operation overflows:
///   - known not to overflow: return the wrapping result;
///   - known to overflow: return the saturation constant;
///   - unknown: keep only the bits valid for both outcomes.
static KnownBits computeForSatAddSub(bool Add, bool Signed,
                                     const KnownBits &LHS,
                                     const KnownBits &RHS) {
  assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
  // We don't set NSW even for sadd/ssub as we want to check if the result has
  // signed overflow.
  KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW*/ false, LHS, RHS);
  unsigned BitWidth = Res.getBitWidth();
  auto SignBitKnown = [&](const KnownBits &K) {
    return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
  };
  // nullopt: overflow may or may not happen.
  std::optional<bool> Overflow;

  if (Signed) {
    // If we can actually detect overflow do so. Otherwise leave Overflow as
    // nullopt (we assume it may have happened).
    if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) {
      if (Add) {
        // sadd.sat
        Overflow = (LHS.isNonNegative() == RHS.isNonNegative() &&
                    Res.isNonNegative() != LHS.isNonNegative());
      } else {
        // ssub.sat
        Overflow = (LHS.isNonNegative() != RHS.isNonNegative() &&
                    Res.isNonNegative() != LHS.isNonNegative());
      }
    }
  } else if (Add) {
    // uadd.sat
    // No overflow if even max(LHS) + max(RHS) fits; certain overflow if even
    // min(LHS) + min(RHS) does not.
    bool Of;
    (void)LHS.getMaxValue().uadd_ov(RHS.getMaxValue(), Of);
    if (!Of) {
      Overflow = false;
    } else {
      (void)LHS.getMinValue().uadd_ov(RHS.getMinValue(), Of);
      if (Of)
        Overflow = true;
    }
  } else {
    // usub.sat
    // No overflow if even min(LHS) - max(RHS) does not wrap; certain overflow
    // if even max(LHS) - min(RHS) wraps.
    bool Of;
    (void)LHS.getMinValue().usub_ov(RHS.getMaxValue(), Of);
    if (!Of) {
      Overflow = false;
    } else {
      (void)LHS.getMaxValue().usub_ov(RHS.getMinValue(), Of);
      if (Of)
        Overflow = true;
    }
  }

  // Facts that hold whether or not the operation saturates.
  if (Signed) {
    if (Add) {
      if (LHS.isNonNegative() && RHS.isNonNegative()) {
        // Pos + Pos -> Pos
        Res.One.clearSignBit();
        Res.Zero.setSignBit();
      }
      if (LHS.isNegative() && RHS.isNegative()) {
        // Neg + Neg -> Neg
        Res.One.setSignBit();
        Res.Zero.clearSignBit();
      }
    } else {
      if (LHS.isNegative() && RHS.isNonNegative()) {
        // Neg - Pos -> Neg
        Res.One.setSignBit();
        Res.Zero.clearSignBit();
      } else if (LHS.isNonNegative() && RHS.isNegative()) {
        // Pos - Neg -> Pos
        Res.One.clearSignBit();
        Res.Zero.setSignBit();
      }
    }
  } else {
    // Add: Leading ones of either operand are preserved.
    // Sub: Leading zeros of LHS and leading ones of RHS are preserved
    // as leading zeros in the result.
    unsigned LeadingKnown;
    if (Add)
      LeadingKnown =
          std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes());
    else
      LeadingKnown =
          std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());

    // We select between the operation result and all-ones/zero
    // respectively, so we can preserve known ones/zeros.
    APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
    if (Add) {
      Res.One |= Mask;
      Res.Zero &= ~Mask;
    } else {
      Res.Zero |= Mask;
      Res.One &= ~Mask;
    }
  }

  if (Overflow) {
    // We know whether or not we overflowed.
    if (!(*Overflow)) {
      // No overflow.
      assert(!Res.hasConflict() && "Bad Output");
      return Res;
    }

    // We overflowed, so the result is the saturation constant.
    APInt C;
    if (Signed) {
      // sadd.sat / ssub.sat
      assert(SignBitKnown(LHS) &&
             "We somehow know overflow without knowing input sign");
      C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
                           : APInt::getSignedMaxValue(BitWidth);
    } else if (Add) {
      // uadd.sat
      C = APInt::getMaxValue(BitWidth);
    } else {
      // usub.sat
      C = APInt::getMinValue(BitWidth);
    }

    Res.One = C;
    Res.Zero = ~C;
    assert(!Res.hasConflict() && "Bad Output");
    return Res;
  }

  // We don't know if we overflowed.
  if (Signed) {
    // sadd.sat/ssub.sat
    // We can keep our information about the sign bits.
    Res.Zero.clearLowBits(BitWidth - 1);
    Res.One.clearLowBits(BitWidth - 1);
  } else if (Add) {
    // uadd.sat
    // We need to clear all the known zeros as we can only use the leading ones.
    Res.Zero.clearAllBits();
  } else {
    // usub.sat
    // We need to clear all the known ones as we can only use the leading zeros.
    Res.One.clearAllBits();
  }

  assert(!Res.hasConflict() && "Bad Output");
  return Res;
}
633 
/// Known bits of the signed saturating sum LHS + RHS.
KnownBits KnownBits::sadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
  const bool IsAddition = true;
  const bool IsSigned = true;
  return computeForSatAddSub(IsAddition, IsSigned, LHS, RHS);
}
/// Known bits of the signed saturating difference LHS - RHS.
KnownBits KnownBits::ssub_sat(const KnownBits &LHS, const KnownBits &RHS) {
  const bool IsAddition = false;
  const bool IsSigned = true;
  return computeForSatAddSub(IsAddition, IsSigned, LHS, RHS);
}
/// Known bits of the unsigned saturating sum LHS + RHS.
KnownBits KnownBits::uadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
  const bool IsAddition = true;
  const bool IsSigned = false;
  return computeForSatAddSub(IsAddition, IsSigned, LHS, RHS);
}
/// Known bits of the unsigned saturating difference LHS - RHS.
KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
  const bool IsAddition = false;
  const bool IsSigned = false;
  return computeForSatAddSub(IsAddition, IsSigned, LHS, RHS);
}
646 
/// Compute the known bits of LHS * RHS, truncated to BitWidth bits.
///
/// \p NoUndefSelfMultiply marks a self-multiplication (x * x); the assert
/// below requires LHS == RHS in that case. It presumably also means the
/// operand is a single well-defined value (not undef) -- confirm with callers.
/// When set, bit 1 of the result is known zero.
KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
                         bool NoUndefSelfMultiply) {
  unsigned BitWidth = LHS.getBitWidth();
  assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
         !RHS.hasConflict() && "Operand mismatch");
  assert((!NoUndefSelfMultiply || LHS == RHS) &&
         "Self multiplication knownbits mismatch");

  // Compute the high known-0 bits by multiplying the unsigned max of each side.
  // Conservatively, M active bits * N active bits results in M + N bits in the
  // result. But if we know a value is a power-of-2 for example, then this
  // computes one more leading zero.
  // TODO: This could be generalized to number of sign bits (negative numbers).
  APInt UMaxLHS = LHS.getMaxValue();
  APInt UMaxRHS = RHS.getMaxValue();

  // For leading zeros in the result to be valid, the unsigned max product must
  // fit in the bitwidth (it must not overflow).
  bool HasOverflow;
  APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow);
  unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero();

  // The result of the bottom bits of an integer multiply can be
  // inferred by looking at the bottom bits of both operands and
  // multiplying them together.
  // We can infer at least the minimum number of known trailing bits
  // of both operands. Depending on number of trailing zeros, we can
  // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
  // a and b are divisible by m and n respectively.
  // We then calculate how many of those bits are inferrable and set
  // the output. For example, the i8 mul:
  //  a = XXXX1100 (12)
  //  b = XXXX1110 (14)
  // We know the bottom 3 bits are zero since the first can be divided by
  // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
  // Applying the multiplication to the trimmed arguments gets:
  //    XX11 (3)
  //    X111 (7)
  // -------
  //    XX11
  //   XX11
  //  XX11
  // XX11
  // -------
  // XXXXX01
  // Which allows us to infer the 2 LSBs. Since we're multiplying the result
  // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
  // The proof for this can be described as:
  // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
  //      (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
  //                    umin(countTrailingZeros(C2), C6) +
  //                    umin(C5 - umin(countTrailingZeros(C1), C5),
  //                         C6 - umin(countTrailingZeros(C2), C6)))) - 1)
  // %aa = shl i8 %a, C5
  // %bb = shl i8 %b, C6
  // %aaa = or i8 %aa, C1
  // %bbb = or i8 %bb, C2
  // %mul = mul i8 %aaa, %bbb
  // %mask = and i8 %mul, C7
  //   =>
  // %mask = i8 ((C1*C2)&C7)
  // Where C5, C6 describe the known bits of %a, %b
  // C1, C2 describe the known bottom bits of %a, %b.
  // C7 describes the mask of the known bits of the result.
  const APInt &Bottom0 = LHS.One;
  const APInt &Bottom1 = RHS.One;

  // Number of consecutive known (zero or one) low bits in each operand.
  unsigned TrailBitsKnown0 = (LHS.Zero | LHS.One).countr_one();
  unsigned TrailBitsKnown1 = (RHS.Zero | RHS.One).countr_one();
  // How many times we'd be able to divide each argument by 2 (shr by 1).
  // This gives us the number of trailing zeros on the multiplication result.
  unsigned TrailZero0 = LHS.countMinTrailingZeros();
  unsigned TrailZero1 = RHS.countMinTrailingZeros();
  unsigned TrailZ = TrailZero0 + TrailZero1;

  // Figure out the fewest known-bits operand.
  // (Known low bits above its trailing zeros, i.e. after dividing them out.)
  unsigned SmallestOperand =
      std::min(TrailBitsKnown0 - TrailZero0, TrailBitsKnown1 - TrailZero1);
  unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth);

  // Product of the known low parts; its low ResultBitsKnown bits are exact.
  APInt BottomKnown =
      Bottom0.getLoBits(TrailBitsKnown0) * Bottom1.getLoBits(TrailBitsKnown1);

  KnownBits Res(BitWidth);
  Res.Zero.setHighBits(LeadZ);
  Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
  Res.One = BottomKnown.getLoBits(ResultBitsKnown);

  // If we're self-multiplying then bit[1] is guaranteed to be zero.
  // (A square is congruent to 0 or 1 modulo 4.)
  if (NoUndefSelfMultiply && BitWidth > 1) {
    assert(Res.One[1] == 0 &&
           "Self-multiplication failed Quadratic Reciprocity!");
    Res.Zero.setBit(1);
  }

  return Res;
}
744 
/// Known bits of the high half of the signed product LHS * RHS. Both operands
/// are sign-extended to twice the width, multiplied, and the upper BitWidth
/// bits are kept.
KnownBits KnownBits::mulhs(const KnownBits &LHS, const KnownBits &RHS) {
  const unsigned BitWidth = LHS.getBitWidth();
  assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
         !RHS.hasConflict() && "Operand mismatch");
  const unsigned DoubleWidth = 2 * BitWidth;
  KnownBits FullProduct = mul(LHS.sext(DoubleWidth), RHS.sext(DoubleWidth));
  return FullProduct.extractBits(BitWidth, BitWidth);
}
753 
/// Known bits of the high half of the unsigned product LHS * RHS. Both
/// operands are zero-extended to twice the width, multiplied, and the upper
/// BitWidth bits are kept.
KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) {
  const unsigned BitWidth = LHS.getBitWidth();
  assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
         !RHS.hasConflict() && "Operand mismatch");
  const unsigned DoubleWidth = 2 * BitWidth;
  KnownBits FullProduct = mul(LHS.zext(DoubleWidth), RHS.zext(DoubleWidth));
  return FullProduct.extractBits(BitWidth, BitWidth);
}
762 
divComputeLowBit(KnownBits Known,const KnownBits & LHS,const KnownBits & RHS,bool Exact)763 static KnownBits divComputeLowBit(KnownBits Known, const KnownBits &LHS,
764                                   const KnownBits &RHS, bool Exact) {
765 
766   if (!Exact)
767     return Known;
768 
769   // If LHS is Odd, the result is Odd no matter what.
770   // Odd / Odd -> Odd
771   // Odd / Even -> Impossible (because its exact division)
772   if (LHS.One[0])
773     Known.One.setBit(0);
774 
775   int MinTZ =
776       (int)LHS.countMinTrailingZeros() - (int)RHS.countMaxTrailingZeros();
777   int MaxTZ =
778       (int)LHS.countMaxTrailingZeros() - (int)RHS.countMinTrailingZeros();
779   if (MinTZ >= 0) {
780     // Result has at least MinTZ trailing zeros.
781     Known.Zero.setLowBits(MinTZ);
782     if (MinTZ == MaxTZ) {
783       // Result has exactly MinTZ trailing zeros.
784       Known.One.setBit(MinTZ);
785     }
786   } else if (MaxTZ < 0) {
787     // Poison Result
788     Known.setAllZero();
789   }
790 
791   // In the KnownBits exhaustive tests, we have poison inputs for exact values
792   // a LOT. If we have a conflict, just return all zeros.
793   if (Known.hasConflict())
794     Known.setAllZero();
795 
796   return Known;
797 }
798 
/// Known bits of the signed quotient LHS / RHS.
///
/// When both operands are known non-negative this defers to udiv. Otherwise
/// an extreme-case quotient is computed for the known sign combination, and
/// its leading sign bits are copied into the result. When \p Exact is set,
/// the low bits are then refined by divComputeLowBit.
KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS,
                          bool Exact) {
  // Signed and unsigned division agree when neither operand can be negative.
  if (LHS.isNonNegative() && RHS.isNonNegative())
    return udiv(LHS, RHS, Exact);

  const unsigned BitWidth = LHS.getBitWidth();
  assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
  KnownBits Known(BitWidth);

  // A zero dividend yields zero and a zero divisor is UB; answer zero in both
  // cases so the sign analysis below needs no special handling for them.
  if (LHS.isZero() || RHS.isZero()) {
    Known.setAllZero();
    return Known;
  }

  // Most extreme quotient for the known operand signs; stays empty when the
  // sign of the result cannot be determined.
  std::optional<APInt> Res;
  if (LHS.isNegative() && RHS.isNegative()) {
    // neg / neg is non-negative. The largest quotient comes from the most
    // negative numerator divided by the divisor closest to zero.
    const APInt Num = LHS.getSignedMinValue();
    const APInt Denom = RHS.getSignedMaxValue();
    // INT_MIN / -1 would be poison; substitute signed max, which still has a
    // clear sign bit.
    if (Num.isMinSignedValue() && Denom.isAllOnes())
      Res = APInt::getSignedMaxValue(BitWidth);
    else
      Res = Num.sdiv(Denom);
  } else if (LHS.isNegative() && RHS.isNonNegative()) {
    // neg / non-neg is negative when the division is exact or -LHS u>= RHS
    // holds for every possible value.
    const bool ResultNegative =
        Exact || (-LHS.getSignedMaxValue()).uge(RHS.getSignedMaxValue());
    if (ResultNegative) {
      const APInt Num = LHS.getSignedMinValue();
      const APInt Denom = RHS.getSignedMinValue();
      // A zero lower bound on the divisor means the real divisor is at least
      // one, so the numerator itself bounds the quotient.
      Res = Denom.isZero() ? Num : Num.sdiv(Denom);
    }
  } else if (LHS.isStrictlyPositive() && RHS.isNegative()) {
    // pos / neg is negative when the division is exact or LHS u>= -RHS holds
    // for every possible value.
    const bool ResultNegative =
        Exact || LHS.getSignedMinValue().uge(-RHS.getSignedMinValue());
    if (ResultNegative) {
      const APInt Num = LHS.getSignedMaxValue();
      const APInt Denom = RHS.getSignedMaxValue();
      Res = Num.sdiv(Denom);
    }
  }

  // Every possible quotient shares the extreme quotient's leading sign bits.
  if (Res) {
    if (Res->isNegative())
      Known.One.setHighBits(Res->countLeadingOnes());
    else
      Known.Zero.setHighBits(Res->countLeadingZeros());
  }

  Known = divComputeLowBit(Known, LHS, RHS, Exact);

  assert(!Known.hasConflict() && "Bad Output");
  return Known;
}
857 
/// Known bits of the unsigned quotient LHS / RHS. The high bits come from an
/// upper bound on the quotient. When \p Exact is set, the low bits are refined
/// by divComputeLowBit.
KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS,
                          bool Exact) {
  const unsigned BitWidth = LHS.getBitWidth();
  assert(!LHS.hasConflict() && !RHS.hasConflict());
  KnownBits Known(BitWidth);

  // A zero dividend yields zero and a zero divisor is UB; answer zero in both
  // cases so no special-casing is needed further down.
  if (LHS.isZero() || RHS.isZero()) {
    Known.setAllZero();
    return Known;
  }

  // No quotient exceeds MaxNumerator / MinDenominator, so that bound's leading
  // zeros are leading zeros of every result. A zero minimum divisor really
  // means a divisor of at least one, so the bound is the numerator itself.
  const APInt MaxNum = LHS.getMaxValue();
  const APInt MinDenom = RHS.getMinValue();
  const APInt MaxRes = MinDenom.isZero() ? MaxNum : MaxNum.udiv(MinDenom);
  Known.Zero.setHighBits(MaxRes.countLeadingZeros());

  Known = divComputeLowBit(Known, LHS, RHS, Exact);

  assert(!Known.hasConflict() && "Bad Output");
  return Known;
}
886 
/// Low bits shared by the urem and srem computations. When RHS is known even
/// (and not known zero), its N known trailing zeros mean the remainder keeps
/// the dividend's low N bits. Otherwise nothing is known.
KnownBits KnownBits::remGetLowBits(const KnownBits &LHS, const KnownBits &RHS) {
  const unsigned BitWidth = LHS.getBitWidth();
  if (RHS.isZero() || !RHS.Zero[0])
    return KnownBits(BitWidth);

  // X rem Y, where Y's low N bits are zero, preserves X's low N bits.
  const APInt LowMask =
      APInt::getLowBitsSet(BitWidth, RHS.countMinTrailingZeros());
  return KnownBits(LHS.Zero & LowMask, LHS.One & LowMask);
}
899 
/// Known bits of the unsigned remainder LHS % RHS.
KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) {
  assert(!LHS.hasConflict() && !RHS.hasConflict());

  // Low bits, when derivable, come from remGetLowBits.
  KnownBits Known = remGetLowBits(LHS, RHS);

  // X urem 2^K is X & (2^K - 1), so every bit at or above K is zero.
  if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
    Known.Zero |= ~(RHS.getConstant() - 1);
    return Known;
  }

  // The remainder is no larger than the dividend and strictly below the
  // divisor, so it inherits the larger of their known leading-zero counts.
  Known.Zero.setHighBits(
      std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingZeros()));
  return Known;
}
918 
/// Known bits of the signed remainder LHS srem RHS.
KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) {
  assert(!LHS.hasConflict() && !RHS.hasConflict());

  // Low bits, when derivable, come from remGetLowBits.
  KnownBits Known = remGetLowBits(LHS, RHS);

  if (!(RHS.isConstant() && RHS.getConstant().isPowerOf2())) {
    // The remainder takes the dividend's sign (or is zero) and its magnitude
    // never exceeds the dividend's, so the dividend's known leading zeros
    // carry over to the result.
    Known.Zero.setHighBits(LHS.countMinLeadingZeros());
    return Known;
  }

  // RHS is a known power-of-two constant; remGetLowBits already filled in
  // the bits covered by LowBits.
  const APInt LowBits = RHS.getConstant() - 1;
  const APInt HighBits = ~LowBits;

  // Non-negative dividend, or one whose low bits are all zero: the remainder
  // is non-negative and fits in LowBits, so the high bits are zero.
  if (LHS.isNonNegative() || LowBits.isSubsetOf(LHS.Zero))
    Known.Zero |= HighBits;

  // Negative dividend with a nonzero low part: the remainder is negative and
  // small in magnitude, so the high bits are all one.
  if (LHS.isNegative() && LowBits.intersects(LHS.One))
    Known.One |= HighBits;

  return Known;
}
945 
/// Bitwise AND in the known-bits domain, applied in place.
KnownBits &KnownBits::operator&=(const KnownBits &RHS) {
  // One only where both inputs are one.
  One &= RHS.One;
  // Zero wherever either input is zero.
  Zero |= RHS.Zero;
  return *this;
}
953 
/// Bitwise OR in the known-bits domain, applied in place.
KnownBits &KnownBits::operator|=(const KnownBits &RHS) {
  // One wherever either input is one.
  One |= RHS.One;
  // Zero only where both inputs are zero.
  Zero &= RHS.Zero;
  return *this;
}
961 
/// Bitwise XOR in the known-bits domain, applied in place. Both new masks are
/// computed from the old ones, so `X ^= X` is handled correctly.
KnownBits &KnownBits::operator^=(const KnownBits &RHS) {
  // One where the inputs are known to differs: 0^1 or 1^0.
  APInt NewOne = (Zero & RHS.One) | (One & RHS.Zero);
  // Zero where the inputs are known to agree: 0^0 or 1^1.
  Zero = (Zero & RHS.Zero) | (One & RHS.One);
  One = std::move(NewOne);
  return *this;
}
970 
blsi() const971 KnownBits KnownBits::blsi() const {
972   unsigned BitWidth = getBitWidth();
973   KnownBits Known(Zero, APInt(BitWidth, 0));
974   unsigned Max = countMaxTrailingZeros();
975   Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth));
976   unsigned Min = countMinTrailingZeros();
977   if (Max == Min && Max < BitWidth)
978     Known.One.setBit(Max);
979   return Known;
980 }
981 
blsmsk() const982 KnownBits KnownBits::blsmsk() const {
983   unsigned BitWidth = getBitWidth();
984   KnownBits Known(BitWidth);
985   unsigned Max = countMaxTrailingZeros();
986   Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth));
987   unsigned Min = countMinTrailingZeros();
988   Known.One.setLowBits(std::min(Min + 1, BitWidth));
989   return Known;
990 }
991 
print(raw_ostream & OS) const992 void KnownBits::print(raw_ostream &OS) const {
993   unsigned BitWidth = getBitWidth();
994   for (unsigned I = 0; I < BitWidth; ++I) {
995     unsigned N = BitWidth - I - 1;
996     if (Zero[N] && One[N])
997       OS << "!";
998     else if (Zero[N])
999       OS << "0";
1000     else if (One[N])
1001       OS << "1";
1002     else
1003       OS << "?";
1004   }
1005 }
dump() const1006 void KnownBits::dump() const {
1007   print(dbgs());
1008   dbgs() << "\n";
1009 }
1010