1 //===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a class for representing known zeros and ones used by
10 // computeKnownBits.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Support/KnownBits.h"
15 #include "llvm/Support/Debug.h"
16 #include "llvm/Support/raw_ostream.h"
17 #include <cassert>
18 
19 using namespace llvm;
20 
21 static KnownBits computeForAddCarry(
22     const KnownBits &LHS, const KnownBits &RHS,
23     bool CarryZero, bool CarryOne) {
24   assert(!(CarryZero && CarryOne) &&
25          "Carry can't be zero and one at the same time");
26 
27   APInt PossibleSumZero = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero;
28   APInt PossibleSumOne = LHS.getMinValue() + RHS.getMinValue() + CarryOne;
29 
30   // Compute known bits of the carry.
31   APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
32   APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
33 
34   // Compute set of known bits (where all three relevant bits are known).
35   APInt LHSKnownUnion = LHS.Zero | LHS.One;
36   APInt RHSKnownUnion = RHS.Zero | RHS.One;
37   APInt CarryKnownUnion = std::move(CarryKnownZero) | CarryKnownOne;
38   APInt Known = std::move(LHSKnownUnion) & RHSKnownUnion & CarryKnownUnion;
39 
40   assert((PossibleSumZero & Known) == (PossibleSumOne & Known) &&
41          "known bits of sum differ");
42 
43   // Compute known bits of the result.
44   KnownBits KnownOut;
45   KnownOut.Zero = ~std::move(PossibleSumZero) & Known;
46   KnownOut.One = std::move(PossibleSumOne) & Known;
47   return KnownOut;
48 }
49 
50 KnownBits KnownBits::computeForAddCarry(
51     const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) {
52   assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit");
53   return ::computeForAddCarry(
54       LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
55 }
56 
57 KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
58                                       const KnownBits &LHS, KnownBits RHS) {
59   KnownBits KnownOut;
60   if (Add) {
61     // Sum = LHS + RHS + 0
62     KnownOut = ::computeForAddCarry(
63         LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
64   } else {
65     // Sum = LHS + ~RHS + 1
66     std::swap(RHS.Zero, RHS.One);
67     KnownOut = ::computeForAddCarry(
68         LHS, RHS, /*CarryZero*/false, /*CarryOne*/true);
69   }
70 
71   // Are we still trying to solve for the sign bit?
72   if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) {
73     if (NSW) {
74       // Adding two non-negative numbers, or subtracting a negative number from
75       // a non-negative one, can't wrap into negative.
76       if (LHS.isNonNegative() && RHS.isNonNegative())
77         KnownOut.makeNonNegative();
78       // Adding two negative numbers, or subtracting a non-negative number from
79       // a negative one, can't wrap into non-negative.
80       else if (LHS.isNegative() && RHS.isNegative())
81         KnownOut.makeNegative();
82     }
83   }
84 
85   return KnownOut;
86 }
87 
88 KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const {
89   unsigned BitWidth = getBitWidth();
90   assert(0 < SrcBitWidth && SrcBitWidth <= BitWidth &&
91          "Illegal sext-in-register");
92 
93   if (SrcBitWidth == BitWidth)
94     return *this;
95 
96   unsigned ExtBits = BitWidth - SrcBitWidth;
97   KnownBits Result;
98   Result.One = One << ExtBits;
99   Result.Zero = Zero << ExtBits;
100   Result.One.ashrInPlace(ExtBits);
101   Result.Zero.ashrInPlace(ExtBits);
102   return Result;
103 }
104 
105 KnownBits KnownBits::makeGE(const APInt &Val) const {
106   // Count the number of leading bit positions where our underlying value is
107   // known to be less than or equal to Val.
108   unsigned N = (Zero | Val).countl_one();
109 
110   // For each of those bit positions, if Val has a 1 in that bit then our
111   // underlying value must also have a 1.
112   APInt MaskedVal(Val);
113   MaskedVal.clearLowBits(getBitWidth() - N);
114   return KnownBits(Zero, One | MaskedVal);
115 }
116 
117 KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) {
118   // If we can prove that LHS >= RHS then use LHS as the result. Likewise for
119   // RHS. Ideally our caller would already have spotted these cases and
120   // optimized away the umax operation, but we handle them here for
121   // completeness.
122   if (LHS.getMinValue().uge(RHS.getMaxValue()))
123     return LHS;
124   if (RHS.getMinValue().uge(LHS.getMaxValue()))
125     return RHS;
126 
127   // If the result of the umax is LHS then it must be greater than or equal to
128   // the minimum possible value of RHS. Likewise for RHS. Any known bits that
129   // are common to these two values are also known in the result.
130   KnownBits L = LHS.makeGE(RHS.getMinValue());
131   KnownBits R = RHS.makeGE(LHS.getMinValue());
132   return L.intersectWith(R);
133 }
134 
135 KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
136   // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0]
137   auto Flip = [](const KnownBits &Val) { return KnownBits(Val.One, Val.Zero); };
138   return Flip(umax(Flip(LHS), Flip(RHS)));
139 }
140 
141 KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
142   // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
143   auto Flip = [](const KnownBits &Val) {
144     unsigned SignBitPosition = Val.getBitWidth() - 1;
145     APInt Zero = Val.Zero;
146     APInt One = Val.One;
147     Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
148     One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
149     return KnownBits(Zero, One);
150   };
151   return Flip(umax(Flip(LHS), Flip(RHS)));
152 }
153 
154 KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
155   // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0]
156   auto Flip = [](const KnownBits &Val) {
157     unsigned SignBitPosition = Val.getBitWidth() - 1;
158     APInt Zero = Val.One;
159     APInt One = Val.Zero;
160     Zero.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
161     One.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
162     return KnownBits(Zero, One);
163   };
164   return Flip(umax(Flip(LHS), Flip(RHS)));
165 }
166 
167 static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) {
168   if (isPowerOf2_32(BitWidth))
169     return MaxValue.extractBitsAsZExtValue(Log2_32(BitWidth), 0);
170   // This is only an approximate upper bound.
171   return MaxValue.getLimitedValue(BitWidth - 1);
172 }
173 
/// Known bits of LHS << RHS.
///
/// NUW / NSW mean the shift is assumed not to wrap in the unsigned / signed
/// sense. ShAmtNonZero means the shift amount is assumed to be non-zero.
/// Shift amounts that would yield poison (>= BitWidth, or excluded by the
/// NUW/NSW flags) are skipped when combining the per-amount results. If every
/// possible amount is poison, an all-zero result is returned.
KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
                         bool NSW, bool ShAmtNonZero) {
  unsigned BitWidth = LHS.getBitWidth();
  // Known bits of LHS shifted left by the constant ShiftAmt.
  auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
    KnownBits Known;
    // Set by ushl_ov when a set bit of the respective mask is shifted out,
    // i.e. a known-zero / known-one bit of LHS was lost off the top.
    bool ShiftedOutZero, ShiftedOutOne;
    Known.Zero = LHS.Zero.ushl_ov(ShiftAmt, ShiftedOutZero);
    // Bits shifted in at the bottom are zero.
    Known.Zero.setLowBits(ShiftAmt);
    Known.One = LHS.One.ushl_ov(ShiftAmt, ShiftedOutOne);

    // All cases returning poison have been handled by MaxShiftAmount already.
    // Under NSW every shifted-out bit must equal the result's sign bit, so a
    // shifted-out known 0 (or known 1) fixes the sign of the result.
    if (NSW) {
      if (NUW && ShiftAmt != 0)
        // NUW means we can assume anything shifted out was a zero.
        ShiftedOutZero = true;

      if (ShiftedOutZero)
        Known.makeNonNegative();
      else if (ShiftedOutOne)
        Known.makeNegative();
    }
    return Known;
  };

  // Fast path for a common case when LHS is completely unknown.
  KnownBits Known(BitWidth);
  // Smallest possible shift amount, clamped to BitWidth.
  unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
  if (MinShiftAmount == 0 && ShAmtNonZero)
    MinShiftAmount = 1;
  if (LHS.isUnknown()) {
    // Only the zeros shifted in by the minimum amount are known.
    Known.Zero.setLowBits(MinShiftAmount);
    if (NUW && NSW && MinShiftAmount != 0)
      Known.makeNonNegative();
    return Known;
  }

  // Determine maximum shift amount, taking NUW/NSW flags into account.
  // NUW: no known-or-possible one may be shifted out, so the amount is bounded
  // by the maximum number of leading zeros. NSW: the amount is bounded by the
  // sign-bit run length minus one. NUW+NSW additionally requires a zero sign
  // bit in the result, hence leading zeros minus one.
  // NOTE(review): countMaxLeadingZeros() - 1 wraps to UINT_MAX when LHS is
  // known negative; the subsequent NUW clamp still bounds MaxShiftAmount.
  APInt MaxValue = RHS.getMaxValue();
  unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
  if (NUW && NSW)
    MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros() - 1);
  if (NUW)
    MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros());
  if (NSW)
    MaxShiftAmount = std::min(
        MaxShiftAmount,
        std::max(LHS.countMaxLeadingZeros(), LHS.countMaxLeadingOnes()) - 1);

  // Fast path for common case where the shift amount is unknown.
  // With every amount in [0, BitWidth-1] possible, only LHS's trailing zeros
  // survive all shifts; an all-ones LHS keeps its sign bit set for every
  // amount.
  if (MinShiftAmount == 0 && MaxShiftAmount == BitWidth - 1 &&
      isPowerOf2_32(BitWidth)) {
    Known.Zero.setLowBits(LHS.countMinTrailingZeros());
    if (LHS.isAllOnes())
      Known.One.setSignBit();
    if (NSW) {
      if (LHS.isNonNegative())
        Known.makeNonNegative();
      if (LHS.isNegative())
        Known.makeNegative();
    }
    return Known;
  }

  // Find the common bits from all possible shifts.
  // The low 32 bits of RHS's masks are used to rule out impossible amounts.
  unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
  unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
  // Start from the all-conflict state, which intersectWith treats as "no
  // information yet", and intersect in each feasible per-amount result.
  Known.Zero.setAllBits();
  Known.One.setAllBits();
  for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
       ++ShiftAmt) {
    // Skip if the shift amount is impossible.
    // (It sets a bit RHS has known-zero, or clears a bit RHS has known-one.)
    if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
        (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
      continue;
    Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
    // Nothing left to learn; stop early.
    if (Known.isUnknown())
      break;
  }

  // All shift amounts may result in poison.
  // (No amount contributed, so the initial conflict is still present.)
  if (Known.hasConflict())
    Known.setAllZero();
  return Known;
}
258 
259 KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
260                           bool ShAmtNonZero) {
261   unsigned BitWidth = LHS.getBitWidth();
262   auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
263     KnownBits Known = LHS;
264     Known.Zero.lshrInPlace(ShiftAmt);
265     Known.One.lshrInPlace(ShiftAmt);
266     // High bits are known zero.
267     Known.Zero.setHighBits(ShiftAmt);
268     return Known;
269   };
270 
271   // Fast path for a common case when LHS is completely unknown.
272   KnownBits Known(BitWidth);
273   unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
274   if (MinShiftAmount == 0 && ShAmtNonZero)
275     MinShiftAmount = 1;
276   if (LHS.isUnknown()) {
277     Known.Zero.setHighBits(MinShiftAmount);
278     return Known;
279   }
280 
281   // Find the common bits from all possible shifts.
282   APInt MaxValue = RHS.getMaxValue();
283   unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
284   unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
285   unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
286   Known.Zero.setAllBits();
287   Known.One.setAllBits();
288   for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
289        ++ShiftAmt) {
290     // Skip if the shift amount is impossible.
291     if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
292         (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
293       continue;
294     Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
295     if (Known.isUnknown())
296       break;
297   }
298 
299   // All shift amounts may result in poison.
300   if (Known.hasConflict())
301     Known.setAllZero();
302   return Known;
303 }
304 
305 KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
306                           bool ShAmtNonZero) {
307   unsigned BitWidth = LHS.getBitWidth();
308   auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
309     KnownBits Known = LHS;
310     Known.Zero.ashrInPlace(ShiftAmt);
311     Known.One.ashrInPlace(ShiftAmt);
312     return Known;
313   };
314 
315   // Fast path for a common case when LHS is completely unknown.
316   KnownBits Known(BitWidth);
317   unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
318   if (MinShiftAmount == 0 && ShAmtNonZero)
319     MinShiftAmount = 1;
320   if (LHS.isUnknown()) {
321     if (MinShiftAmount == BitWidth) {
322       // Always poison. Return zero because we don't like returning conflict.
323       Known.setAllZero();
324       return Known;
325     }
326     return Known;
327   }
328 
329   // Find the common bits from all possible shifts.
330   APInt MaxValue = RHS.getMaxValue();
331   unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
332   unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
333   unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
334   Known.Zero.setAllBits();
335   Known.One.setAllBits();
336   for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
337       ++ShiftAmt) {
338     // Skip if the shift amount is impossible.
339     if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
340         (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
341       continue;
342     Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
343     if (Known.isUnknown())
344       break;
345   }
346 
347   // All shift amounts may result in poison.
348   if (Known.hasConflict())
349     Known.setAllZero();
350   return Known;
351 }
352 
353 std::optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) {
354   if (LHS.isConstant() && RHS.isConstant())
355     return std::optional<bool>(LHS.getConstant() == RHS.getConstant());
356   if (LHS.One.intersects(RHS.Zero) || RHS.One.intersects(LHS.Zero))
357     return std::optional<bool>(false);
358   return std::nullopt;
359 }
360 
361 std::optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) {
362   if (std::optional<bool> KnownEQ = eq(LHS, RHS))
363     return std::optional<bool>(!*KnownEQ);
364   return std::nullopt;
365 }
366 
367 std::optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) {
368   // LHS >u RHS -> false if umax(LHS) <= umax(RHS)
369   if (LHS.getMaxValue().ule(RHS.getMinValue()))
370     return std::optional<bool>(false);
371   // LHS >u RHS -> true if umin(LHS) > umax(RHS)
372   if (LHS.getMinValue().ugt(RHS.getMaxValue()))
373     return std::optional<bool>(true);
374   return std::nullopt;
375 }
376 
377 std::optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) {
378   if (std::optional<bool> IsUGT = ugt(RHS, LHS))
379     return std::optional<bool>(!*IsUGT);
380   return std::nullopt;
381 }
382 
383 std::optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) {
384   return ugt(RHS, LHS);
385 }
386 
387 std::optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) {
388   return uge(RHS, LHS);
389 }
390 
391 std::optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) {
392   // LHS >s RHS -> false if smax(LHS) <= smax(RHS)
393   if (LHS.getSignedMaxValue().sle(RHS.getSignedMinValue()))
394     return std::optional<bool>(false);
395   // LHS >s RHS -> true if smin(LHS) > smax(RHS)
396   if (LHS.getSignedMinValue().sgt(RHS.getSignedMaxValue()))
397     return std::optional<bool>(true);
398   return std::nullopt;
399 }
400 
401 std::optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) {
402   if (std::optional<bool> KnownSGT = sgt(RHS, LHS))
403     return std::optional<bool>(!*KnownSGT);
404   return std::nullopt;
405 }
406 
407 std::optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) {
408   return sgt(RHS, LHS);
409 }
410 
411 std::optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) {
412   return sge(RHS, LHS);
413 }
414 
/// Known bits of the absolute value of this value.
///
/// If IntMinIsPoison is set, an INT_MIN input is assumed not to occur (its
/// absolute value would be poison), which allows more bits to be inferred.
KnownBits KnownBits::abs(bool IntMinIsPoison) const {
  // If the source's MSB is zero then we know the rest of the bits already.
  if (isNonNegative())
    return *this;

  // Result starts with nothing known; each branch below fills it in.
  KnownBits KnownAbs(getBitWidth());

  // If the input is negative, then abs(x) == -x.
  if (isNegative()) {
    KnownBits Tmp = *this;
    // Special case for IntMinIsPoison. We know the sign bit is set and we know
    // all the rest of the bits except one to be zero. Since we have
    // IntMinIsPoison, that final bit MUST be a one, as otherwise the input is
    // INT_MIN.
    if (IntMinIsPoison && (Zero.popcount() + 2) == getBitWidth())
      Tmp.One.setBit(countMinTrailingZeros());

    // -x computed as 0 - x; IntMinIsPoison doubles as the NSW flag.
    KnownAbs = computeForAddSub(
        /*Add*/ false, IntMinIsPoison,
        KnownBits::makeConstant(APInt(getBitWidth(), 0)), Tmp);

    // One more special case for IntMinIsPoison. If we don't know any ones other
    // than the signbit, we know for certain that all the unknowns can't be
    // zero. So if we know high zero bits, but have unknown low bits, we know
    // for certain those high-zero bits will end up as one. This is because,
    // the low bits can't be all zeros, so the +1 in (~x + 1) cannot carry up
    // to the high bits. If we know a known INT_MIN input skip this. The result
    // is poison anyways.
    if (IntMinIsPoison && Tmp.countMinPopulation() == 1 &&
        Tmp.countMaxPopulation() != 1) {
      // Treat the sign bit as zero so that countMinLeadingZeros() counts only
      // the known-zero run below it.
      Tmp.One.clearSignBit();
      Tmp.Zero.setSignBit();
      KnownAbs.One.setBits(getBitWidth() - Tmp.countMinLeadingZeros(),
                           getBitWidth() - 1);
    }

  } else {
    // Sign unknown. Absolute value preserves the trailing zero count.
    unsigned MaxTZ = countMaxTrailingZeros();
    unsigned MinTZ = countMinTrailingZeros();

    KnownAbs.Zero.setLowBits(MinTZ);
    // If we know the lowest set 1, then preserve it.
    if (MaxTZ == MinTZ && MaxTZ < getBitWidth())
      KnownAbs.One.setBit(MaxTZ);

    // We only know that the absolute values's MSB will be zero if INT_MIN is
    // poison, or there is a set bit that isn't the sign bit (otherwise it could
    // be INT_MIN).
    if (IntMinIsPoison || (!One.isZero() && !One.isMinSignedValue())) {
      KnownAbs.One.clearSignBit();
      KnownAbs.Zero.setSignBit();
    }
  }

  assert(!KnownAbs.hasConflict() && "Bad Output");
  return KnownAbs;
}
473 
/// Shared implementation of the saturating add/sub known-bits transfer
/// functions: sadd.sat (Add, Signed), ssub.sat (!Add, Signed), uadd.sat
/// (Add, !Signed) and usub.sat (!Add, !Signed).
///
/// Starts from the known bits of the plain wrapping add/sub, then:
///  * if overflow is proven impossible, returns those bits;
///  * if overflow is proven certain, returns the saturation constant;
///  * otherwise keeps only the bits that hold whether or not the result
///    saturates.
static KnownBits computeForSatAddSub(bool Add, bool Signed,
                                     const KnownBits &LHS,
                                     const KnownBits &RHS) {
  assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
  // We don't set NSW even for sadd/ssub as we want to check if the result has
  // signed overflow.
  KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW*/ false, LHS, RHS);
  unsigned BitWidth = Res.getBitWidth();
  auto SignBitKnown = [&](const KnownBits &K) {
    return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
  };
  // true = definitely overflows, false = definitely does not,
  // nullopt = undecided.
  std::optional<bool> Overflow;

  if (Signed) {
    // If we can actually detect overflow do so. Otherwise leave Overflow as
    // nullopt (we assume it may have happened).
    if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) {
      if (Add) {
        // sadd.sat
        Overflow = (LHS.isNonNegative() == RHS.isNonNegative() &&
                    Res.isNonNegative() != LHS.isNonNegative());
      } else {
        // ssub.sat
        Overflow = (LHS.isNonNegative() != RHS.isNonNegative() &&
                    Res.isNonNegative() != LHS.isNonNegative());
      }
    }
  } else if (Add) {
    // uadd.sat
    // No overflow if even the two maxima fit; certain overflow if even the
    // two minima overflow.
    bool Of;
    (void)LHS.getMaxValue().uadd_ov(RHS.getMaxValue(), Of);
    if (!Of) {
      Overflow = false;
    } else {
      (void)LHS.getMinValue().uadd_ov(RHS.getMinValue(), Of);
      if (Of)
        Overflow = true;
    }
  } else {
    // usub.sat
    // No overflow if even min(LHS) - max(RHS) does not wrap; certain overflow
    // if even max(LHS) - min(RHS) wraps.
    bool Of;
    (void)LHS.getMinValue().usub_ov(RHS.getMaxValue(), Of);
    if (!Of) {
      Overflow = false;
    } else {
      (void)LHS.getMaxValue().usub_ov(RHS.getMinValue(), Of);
      if (Of)
        Overflow = true;
    }
  }

  // Bits that hold for both the wrapped and the saturated result.
  if (Signed) {
    if (Add) {
      if (LHS.isNonNegative() && RHS.isNonNegative()) {
        // Pos + Pos -> Pos
        Res.One.clearSignBit();
        Res.Zero.setSignBit();
      }
      if (LHS.isNegative() && RHS.isNegative()) {
        // Neg + Neg -> Neg
        Res.One.setSignBit();
        Res.Zero.clearSignBit();
      }
    } else {
      if (LHS.isNegative() && RHS.isNonNegative()) {
        // Neg - Pos -> Neg
        Res.One.setSignBit();
        Res.Zero.clearSignBit();
      } else if (LHS.isNonNegative() && RHS.isNegative()) {
        // Pos - Neg -> Pos
        Res.One.clearSignBit();
        Res.Zero.setSignBit();
      }
    }
  } else {
    // Add: Leading ones of either operand are preserved.
    // Sub: Leading zeros of LHS and leading ones of RHS are preserved
    // as leading zeros in the result.
    unsigned LeadingKnown;
    if (Add)
      LeadingKnown =
          std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes());
    else
      LeadingKnown =
          std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());

    // We select between the operation result and all-ones/zero
    // respectively, so we can preserve known ones/zeros.
    APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
    if (Add) {
      Res.One |= Mask;
      Res.Zero &= ~Mask;
    } else {
      Res.Zero |= Mask;
      Res.One &= ~Mask;
    }
  }

  if (Overflow) {
    // We know whether or not we overflowed.
    if (!(*Overflow)) {
      // No overflow.
      assert(!Res.hasConflict() && "Bad Output");
      return Res;
    }

    // We overflowed
    // The result is exactly the saturation constant C.
    APInt C;
    if (Signed) {
      // sadd.sat / ssub.sat
      assert(SignBitKnown(LHS) &&
             "We somehow know overflow without knowing input sign");
      C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
                           : APInt::getSignedMaxValue(BitWidth);
    } else if (Add) {
      // uadd.sat
      C = APInt::getMaxValue(BitWidth);
    } else {
      // usub.sat
      C = APInt::getMinValue(BitWidth);
    }

    Res.One = C;
    Res.Zero = ~C;
    assert(!Res.hasConflict() && "Bad Output");
    return Res;
  }

  // We don't know if we overflowed.
  if (Signed) {
    // sadd.sat/ssub.sat
    // We can keep our information about the sign bits.
    Res.Zero.clearLowBits(BitWidth - 1);
    Res.One.clearLowBits(BitWidth - 1);
  } else if (Add) {
    // uadd.sat
    // We need to clear all the known zeros as we can only use the leading ones.
    Res.Zero.clearAllBits();
  } else {
    // usub.sat
    // We need to clear all the known ones as we can only use the leading zero.
    Res.One.clearAllBits();
  }

  assert(!Res.hasConflict() && "Bad Output");
  return Res;
}
621 
622 KnownBits KnownBits::sadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
623   return computeForSatAddSub(/*Add*/ true, /*Signed*/ true, LHS, RHS);
624 }
625 KnownBits KnownBits::ssub_sat(const KnownBits &LHS, const KnownBits &RHS) {
626   return computeForSatAddSub(/*Add*/ false, /*Signed*/ true, LHS, RHS);
627 }
628 KnownBits KnownBits::uadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
629   return computeForSatAddSub(/*Add*/ true, /*Signed*/ false, LHS, RHS);
630 }
631 KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
632   return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS);
633 }
634 
/// Known bits of the (wrapping) product LHS * RHS.
///
/// High known-zero bits come from the unsigned maximum product. Low known
/// bits come from the known low bits of both operands. If NoUndefSelfMultiply
/// is set, LHS and RHS are the same non-undef value (x * x), and bit 1 of a
/// square is always zero.
KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
                         bool NoUndefSelfMultiply) {
  unsigned BitWidth = LHS.getBitWidth();
  assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
         !RHS.hasConflict() && "Operand mismatch");
  assert((!NoUndefSelfMultiply || LHS == RHS) &&
         "Self multiplication knownbits mismatch");

  // Compute the high known-0 bits by multiplying the unsigned max of each side.
  // Conservatively, M active bits * N active bits results in M + N bits in the
  // result. But if we know a value is a power-of-2 for example, then this
  // computes one more leading zero.
  // TODO: This could be generalized to number of sign bits (negative numbers).
  APInt UMaxLHS = LHS.getMaxValue();
  APInt UMaxRHS = RHS.getMaxValue();

  // For leading zeros in the result to be valid, the unsigned max product must
  // fit in the bitwidth (it must not overflow).
  bool HasOverflow;
  APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow);
  unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero();

  // The result of the bottom bits of an integer multiply can be
  // inferred by looking at the bottom bits of both operands and
  // multiplying them together.
  // We can infer at least the minimum number of known trailing bits
  // of both operands. Depending on number of trailing zeros, we can
  // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
  // a and b are divisible by m and n respectively.
  // We then calculate how many of those bits are inferrable and set
  // the output. For example, the i8 mul:
  //  a = XXXX1100 (12)
  //  b = XXXX1110 (14)
  // We know the bottom 3 bits are zero since the first can be divided by
  // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
  // Applying the multiplication to the trimmed arguments gets:
  //    XX11 (3)
  //    X111 (7)
  // -------
  //    XX11
  //   XX11
  //  XX11
  // XX11
  // -------
  // XXXXX01
  // Which allows us to infer the 2 LSBs. Since we're multiplying the result
  // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
  // The proof for this can be described as:
  // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
  //      (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
  //                    umin(countTrailingZeros(C2), C6) +
  //                    umin(C5 - umin(countTrailingZeros(C1), C5),
  //                         C6 - umin(countTrailingZeros(C2), C6)))) - 1)
  // %aa = shl i8 %a, C5
  // %bb = shl i8 %b, C6
  // %aaa = or i8 %aa, C1
  // %bbb = or i8 %bb, C2
  // %mul = mul i8 %aaa, %bbb
  // %mask = and i8 %mul, C7
  //   =>
  // %mask = i8 ((C1*C2)&C7)
  // Where C5, C6 describe the known bits of %a, %b
  // C1, C2 describe the known bottom bits of %a, %b.
  // C7 describes the mask of the known bits of the result.
  const APInt &Bottom0 = LHS.One;
  const APInt &Bottom1 = RHS.One;

  // How many times we'd be able to divide each argument by 2 (shr by 1).
  // This gives us the number of trailing zeros on the multiplication result.
  // TrailBitsKnown* = length of the fully-known low run of each operand.
  unsigned TrailBitsKnown0 = (LHS.Zero | LHS.One).countr_one();
  unsigned TrailBitsKnown1 = (RHS.Zero | RHS.One).countr_one();
  unsigned TrailZero0 = LHS.countMinTrailingZeros();
  unsigned TrailZero1 = RHS.countMinTrailingZeros();
  unsigned TrailZ = TrailZero0 + TrailZero1;

  // Figure out the fewest known-bits operand.
  unsigned SmallestOperand =
      std::min(TrailBitsKnown0 - TrailZero0, TrailBitsKnown1 - TrailZero1);
  unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth);

  // Product of the known low runs; its low ResultBitsKnown bits are exact.
  APInt BottomKnown =
      Bottom0.getLoBits(TrailBitsKnown0) * Bottom1.getLoBits(TrailBitsKnown1);

  KnownBits Res(BitWidth);
  Res.Zero.setHighBits(LeadZ);
  Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
  Res.One = BottomKnown.getLoBits(ResultBitsKnown);

  // If we're self-multiplying then bit[1] is guaranteed to be zero.
  if (NoUndefSelfMultiply && BitWidth > 1) {
    assert(Res.One[1] == 0 &&
           "Self-multiplication failed Quadratic Reciprocity!");
    Res.Zero.setBit(1);
  }

  return Res;
}
732 
733 KnownBits KnownBits::mulhs(const KnownBits &LHS, const KnownBits &RHS) {
734   unsigned BitWidth = LHS.getBitWidth();
735   assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
736          !RHS.hasConflict() && "Operand mismatch");
737   KnownBits WideLHS = LHS.sext(2 * BitWidth);
738   KnownBits WideRHS = RHS.sext(2 * BitWidth);
739   return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
740 }
741 
742 KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) {
743   unsigned BitWidth = LHS.getBitWidth();
744   assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
745          !RHS.hasConflict() && "Operand mismatch");
746   KnownBits WideLHS = LHS.zext(2 * BitWidth);
747   KnownBits WideRHS = RHS.zext(2 * BitWidth);
748   return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
749 }
750 
751 static KnownBits divComputeLowBit(KnownBits Known, const KnownBits &LHS,
752                                   const KnownBits &RHS, bool Exact) {
753 
754   if (!Exact)
755     return Known;
756 
757   // If LHS is Odd, the result is Odd no matter what.
758   // Odd / Odd -> Odd
759   // Odd / Even -> Impossible (because its exact division)
760   if (LHS.One[0])
761     Known.One.setBit(0);
762 
763   int MinTZ =
764       (int)LHS.countMinTrailingZeros() - (int)RHS.countMaxTrailingZeros();
765   int MaxTZ =
766       (int)LHS.countMaxTrailingZeros() - (int)RHS.countMinTrailingZeros();
767   if (MinTZ >= 0) {
768     // Result has at least MinTZ trailing zeros.
769     Known.Zero.setLowBits(MinTZ);
770     if (MinTZ == MaxTZ) {
771       // Result has exactly MinTZ trailing zeros.
772       Known.One.setBit(MinTZ);
773     }
774   } else if (MaxTZ < 0) {
775     // Poison Result
776     Known.setAllZero();
777   }
778 
779   // In the KnownBits exhaustive tests, we have poison inputs for exact values
780   // a LOT. If we have a conflict, just return all zeros.
781   if (Known.hasConflict())
782     Known.setAllZero();
783 
784   return Known;
785 }
786 
787 KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS,
788                           bool Exact) {
789   // Equivalent of `udiv`. We must have caught this before it was folded.
790   if (LHS.isNonNegative() && RHS.isNonNegative())
791     return udiv(LHS, RHS, Exact);
792 
793   unsigned BitWidth = LHS.getBitWidth();
794   assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
795   KnownBits Known(BitWidth);
796 
797   if (LHS.isZero() || RHS.isZero()) {
798     // Result is either known Zero or UB. Return Zero either way.
799     // Checking this earlier saves us a lot of special cases later on.
800     Known.setAllZero();
801     return Known;
802   }
803 
804   std::optional<APInt> Res;
805   if (LHS.isNegative() && RHS.isNegative()) {
806     // Result non-negative.
807     APInt Denom = RHS.getSignedMaxValue();
808     APInt Num = LHS.getSignedMinValue();
809     // INT_MIN/-1 would be a poison result (impossible). Estimate the division
810     // as signed max (we will only set sign bit in the result).
811     Res = (Num.isMinSignedValue() && Denom.isAllOnes())
812               ? APInt::getSignedMaxValue(BitWidth)
813               : Num.sdiv(Denom);
814   } else if (LHS.isNegative() && RHS.isNonNegative()) {
815     // Result is negative if Exact OR -LHS u>= RHS.
816     if (Exact || (-LHS.getSignedMaxValue()).uge(RHS.getSignedMaxValue())) {
817       APInt Denom = RHS.getSignedMinValue();
818       APInt Num = LHS.getSignedMinValue();
819       Res = Denom.isZero() ? Num : Num.sdiv(Denom);
820     }
821   } else if (LHS.isStrictlyPositive() && RHS.isNegative()) {
822     // Result is negative if Exact OR LHS u>= -RHS.
823     if (Exact || LHS.getSignedMinValue().uge(-RHS.getSignedMinValue())) {
824       APInt Denom = RHS.getSignedMaxValue();
825       APInt Num = LHS.getSignedMaxValue();
826       Res = Num.sdiv(Denom);
827     }
828   }
829 
830   if (Res) {
831     if (Res->isNonNegative()) {
832       unsigned LeadZ = Res->countLeadingZeros();
833       Known.Zero.setHighBits(LeadZ);
834     } else {
835       unsigned LeadO = Res->countLeadingOnes();
836       Known.One.setHighBits(LeadO);
837     }
838   }
839 
840   Known = divComputeLowBit(Known, LHS, RHS, Exact);
841 
842   assert(!Known.hasConflict() && "Bad Output");
843   return Known;
844 }
845 
846 KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS,
847                           bool Exact) {
848   unsigned BitWidth = LHS.getBitWidth();
849   assert(!LHS.hasConflict() && !RHS.hasConflict());
850   KnownBits Known(BitWidth);
851 
852   if (LHS.isZero() || RHS.isZero()) {
853     // Result is either known Zero or UB. Return Zero either way.
854     // Checking this earlier saves us a lot of special cases later on.
855     Known.setAllZero();
856     return Known;
857   }
858 
859   // We can figure out the minimum number of upper zero bits by doing
860   // MaxNumerator / MinDenominator. If the Numerator gets smaller or Denominator
861   // gets larger, the number of upper zero bits increases.
862   APInt MinDenom = RHS.getMinValue();
863   APInt MaxNum = LHS.getMaxValue();
864   APInt MaxRes = MinDenom.isZero() ? MaxNum : MaxNum.udiv(MinDenom);
865 
866   unsigned LeadZ = MaxRes.countLeadingZeros();
867 
868   Known.Zero.setHighBits(LeadZ);
869   Known = divComputeLowBit(Known, LHS, RHS, Exact);
870 
871   assert(!Known.hasConflict() && "Bad Output");
872   return Known;
873 }
874 
875 KnownBits KnownBits::remGetLowBits(const KnownBits &LHS, const KnownBits &RHS) {
876   unsigned BitWidth = LHS.getBitWidth();
877   if (!RHS.isZero() && RHS.Zero[0]) {
878     // rem X, Y where Y[0:N] is zero will preserve X[0:N] in the result.
879     unsigned RHSZeros = RHS.countMinTrailingZeros();
880     APInt Mask = APInt::getLowBitsSet(BitWidth, RHSZeros);
881     APInt OnesMask = LHS.One & Mask;
882     APInt ZerosMask = LHS.Zero & Mask;
883     return KnownBits(ZerosMask, OnesMask);
884   }
885   return KnownBits(BitWidth);
886 }
887 
888 KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) {
889   assert(!LHS.hasConflict() && !RHS.hasConflict());
890 
891   KnownBits Known = remGetLowBits(LHS, RHS);
892   if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
893     // NB: Low bits set in `remGetLowBits`.
894     APInt HighBits = ~(RHS.getConstant() - 1);
895     Known.Zero |= HighBits;
896     return Known;
897   }
898 
899   // Since the result is less than or equal to either operand, any leading
900   // zero bits in either operand must also exist in the result.
901   uint32_t Leaders =
902       std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingZeros());
903   Known.Zero.setHighBits(Leaders);
904   return Known;
905 }
906 
907 KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) {
908   assert(!LHS.hasConflict() && !RHS.hasConflict());
909 
910   KnownBits Known = remGetLowBits(LHS, RHS);
911   if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
912     // NB: Low bits are set in `remGetLowBits`.
913     APInt LowBits = RHS.getConstant() - 1;
914     // If the first operand is non-negative or has all low bits zero, then
915     // the upper bits are all zero.
916     if (LHS.isNonNegative() || LowBits.isSubsetOf(LHS.Zero))
917       Known.Zero |= ~LowBits;
918 
919     // If the first operand is negative and not all low bits are zero, then
920     // the upper bits are all one.
921     if (LHS.isNegative() && LowBits.intersects(LHS.One))
922       Known.One |= ~LowBits;
923     return Known;
924   }
925 
926   // The sign bit is the LHS's sign bit, except when the result of the
927   // remainder is zero. The magnitude of the result should be less than or
928   // equal to the magnitude of the LHS. Therefore any leading zeros that exist
929   // in the left hand side must also exist in the result.
930   Known.Zero.setHighBits(LHS.countMinLeadingZeros());
931   return Known;
932 }
933 
934 KnownBits &KnownBits::operator&=(const KnownBits &RHS) {
935   // Result bit is 0 if either operand bit is 0.
936   Zero |= RHS.Zero;
937   // Result bit is 1 if both operand bits are 1.
938   One &= RHS.One;
939   return *this;
940 }
941 
942 KnownBits &KnownBits::operator|=(const KnownBits &RHS) {
943   // Result bit is 0 if both operand bits are 0.
944   Zero &= RHS.Zero;
945   // Result bit is 1 if either operand bit is 1.
946   One |= RHS.One;
947   return *this;
948 }
949 
950 KnownBits &KnownBits::operator^=(const KnownBits &RHS) {
951   // Result bit is 0 if both operand bits are 0 or both are 1.
952   APInt Z = (Zero & RHS.Zero) | (One & RHS.One);
953   // Result bit is 1 if one operand bit is 0 and the other is 1.
954   One = (Zero & RHS.One) | (One & RHS.Zero);
955   Zero = std::move(Z);
956   return *this;
957 }
958 
959 KnownBits KnownBits::blsi() const {
960   unsigned BitWidth = getBitWidth();
961   KnownBits Known(Zero, APInt(BitWidth, 0));
962   unsigned Max = countMaxTrailingZeros();
963   Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth));
964   unsigned Min = countMinTrailingZeros();
965   if (Max == Min && Max < BitWidth)
966     Known.One.setBit(Max);
967   return Known;
968 }
969 
970 KnownBits KnownBits::blsmsk() const {
971   unsigned BitWidth = getBitWidth();
972   KnownBits Known(BitWidth);
973   unsigned Max = countMaxTrailingZeros();
974   Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth));
975   unsigned Min = countMinTrailingZeros();
976   Known.One.setLowBits(std::min(Min + 1, BitWidth));
977   return Known;
978 }
979 
980 void KnownBits::print(raw_ostream &OS) const {
981   unsigned BitWidth = getBitWidth();
982   for (unsigned I = 0; I < BitWidth; ++I) {
983     unsigned N = BitWidth - I - 1;
984     if (Zero[N] && One[N])
985       OS << "!";
986     else if (Zero[N])
987       OS << "0";
988     else if (One[N])
989       OS << "1";
990     else
991       OS << "?";
992   }
993 }
994 void KnownBits::dump() const {
995   print(dbgs());
996   dbgs() << "\n";
997 }
998