1 //===- DemandedBits.cpp - Determine demanded bits -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass implements a demanded bits analysis. A demanded bit is one that
10 // contributes to a result; bits that are not demanded can be either zero or
11 // one without affecting control or data flow. For example in this sequence:
12 //
13 //   %1 = add i32 %x, %y
14 //   %2 = trunc i32 %1 to i16
15 //
16 // Only the lowest 16 bits of %1 are demanded; the rest are removed by the
17 // trunc.
18 //
19 //===----------------------------------------------------------------------===//
20 
#include "llvm/Analysis/DemandedBits.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cstdint>
51 
52 using namespace llvm;
53 using namespace llvm::PatternMatch;
54 
55 #define DEBUG_TYPE "demanded-bits"
56 
// Identifier of the legacy wrapper pass; it is handed to FunctionPass in the
// constructor below.
char DemandedBitsWrapperPass::ID = 0;

// Register the legacy wrapper pass under the name "demanded-bits" and declare
// the analyses it depends on (the same two it requires in getAnalysisUsage).
// NOTE(review): the trailing `false, false` arguments are presumably the
// "CFG only" and "is analysis" flags of the INITIALIZE_PASS macros — confirm
// against PassSupport.h before relying on them.
INITIALIZE_PASS_BEGIN(DemandedBitsWrapperPass, "demanded-bits",
                      "Demanded bits analysis", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(DemandedBitsWrapperPass, "demanded-bits",
                    "Demanded bits analysis", false, false)
65 
66 DemandedBitsWrapperPass::DemandedBitsWrapperPass() : FunctionPass(ID) {
67   initializeDemandedBitsWrapperPassPass(*PassRegistry::getPassRegistry());
68 }
69 
70 void DemandedBitsWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
71   AU.setPreservesCFG();
72   AU.addRequired<AssumptionCacheTracker>();
73   AU.addRequired<DominatorTreeWrapperPass>();
74   AU.setPreservesAll();
75 }
76 
77 void DemandedBitsWrapperPass::print(raw_ostream &OS, const Module *M) const {
78   DB->print(OS);
79 }
80 
81 static bool isAlwaysLive(Instruction *I) {
82   return I->isTerminator() || isa<DbgInfoIntrinsic>(I) || I->isEHPad() ||
83          I->mayHaveSideEffects() || !I->willReturn();
84 }
85 
// Compute AB, the bits of operand \p Val (operand number \p OperandNo of
// \p UserI) that are needed to produce the alive output bits \p AOut of
// UserI. BitWidth below is the operand's scalar width (the width of AB).
//
// AB is in/out: the caller (performAnalysis) initializes it to all ones, and
// each case below only overwrites it when fewer bits can be shown to matter.
// Any opcode, intrinsic or operand position not handled here therefore keeps
// every bit of the operand alive.
//
// Known, Known2 and KnownBitsComputed are a per-user cache owned by the
// caller and reset once per visited user instruction (see ComputeKnownBits).
void DemandedBits::determineLiveOperandBits(
    const Instruction *UserI, const Value *Val, unsigned OperandNo,
    const APInt &AOut, APInt &AB, KnownBits &Known, KnownBits &Known2,
    bool &KnownBitsComputed) {
  unsigned BitWidth = AB.getBitWidth();

  // We're called once per operand, but for some instructions we need the
  // known bits of both operands in order to determine the live bits of
  // either. We don't want to compute them twice, so the result is cached in
  // KnownBits objects that live in the caller: Known receives V1's known
  // bits and Known2 receives V2's (Known2 is left untouched when V2 is null).
  // Only the KnownBitsComputed flag guards the cache, which is consistent
  // because each case below always asks for the same operand(s) of UserI
  // regardless of OperandNo.
  auto ComputeKnownBits =
      [&](unsigned BitWidth, const Value *V1, const Value *V2) {
        if (KnownBitsComputed)
          return;
        KnownBitsComputed = true;

        const DataLayout &DL = UserI->getModule()->getDataLayout();
        Known = KnownBits(BitWidth);
        computeKnownBits(V1, Known, DL, 0, &AC, UserI, &DT);

        if (V2) {
          Known2 = KnownBits(BitWidth);
          computeKnownBits(V2, Known2, DL, 0, &AC, UserI, &DT);
        }
      };

  switch (UserI->getOpcode()) {
  default: break;
  case Instruction::Call:
  case Instruction::Invoke:
    // Only the intrinsics listed below are refined; every other call keeps
    // all bits of all its integer operands alive.
    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(UserI)) {
      switch (II->getIntrinsicID()) {
      default: break;
      case Intrinsic::bswap:
        // The alive bits of the input are the swapped alive bits of
        // the output.
        AB = AOut.byteSwap();
        break;
      case Intrinsic::bitreverse:
        // The alive bits of the input are the reversed alive bits of
        // the output.
        AB = AOut.reverseBits();
        break;
      case Intrinsic::ctlz:
        if (OperandNo == 0) {
          // We need some output bits, so we need all bits of the
          // input to the left of, and including, the leftmost bit
          // known to be one.
          ComputeKnownBits(BitWidth, Val, nullptr);
          AB = APInt::getHighBitsSet(BitWidth,
                 std::min(BitWidth, Known.countMaxLeadingZeros()+1));
        }
        break;
      case Intrinsic::cttz:
        if (OperandNo == 0) {
          // We need some output bits, so we need all bits of the
          // input to the right of, and including, the rightmost bit
          // known to be one.
          ComputeKnownBits(BitWidth, Val, nullptr);
          AB = APInt::getLowBitsSet(BitWidth,
                 std::min(BitWidth, Known.countMaxTrailingZeros()+1));
        }
        break;
      case Intrinsic::fshl:
      case Intrinsic::fshr: {
        const APInt *SA;
        if (OperandNo == 2) {
          // Shift amount is modulo the bitwidth. For powers of two we have
          // SA % BW == SA & (BW - 1). For other widths the modulo cannot be
          // expressed as a mask, so AB keeps all bits.
          if (isPowerOf2_32(BitWidth))
            AB = BitWidth - 1;
        } else if (match(II->getOperand(2), m_APInt(SA))) {
          // Normalize to funnel shift left. APInt shifts of BitWidth are well-
          // defined, so no need to special-case zero shifts here.
          // Operand 0 supplies the high part of the result and operand 1
          // the low part.
          uint64_t ShiftAmt = SA->urem(BitWidth);
          if (II->getIntrinsicID() == Intrinsic::fshr)
            ShiftAmt = BitWidth - ShiftAmt;

          if (OperandNo == 0)
            AB = AOut.lshr(ShiftAmt);
          else if (OperandNo == 1)
            AB = AOut.shl(BitWidth - ShiftAmt);
        }
        break;
      }
      case Intrinsic::umax:
      case Intrinsic::umin:
      case Intrinsic::smax:
      case Intrinsic::smin:
        // If low bits of result are not demanded, they are also not demanded
        // for the min/max operands.
        AB = APInt::getBitsSetFrom(BitWidth, AOut.countTrailingZeros());
        break;
      }
    }
    break;
  case Instruction::Add:
    // A low-bit mask of alive outputs needs exactly those input bits, since
    // carries only ripple towards higher bits. Otherwise use known bits of
    // both operands to bound how far demand can propagate through carries.
    if (AOut.isMask()) {
      AB = AOut;
    } else {
      ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
      AB = determineLiveOperandBitsAdd(OperandNo, AOut, Known, Known2);
    }
    break;
  case Instruction::Sub:
    // Same reasoning as Add (borrows also only ripple towards higher bits).
    if (AOut.isMask()) {
      AB = AOut;
    } else {
      ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
      AB = determineLiveOperandBitsSub(OperandNo, AOut, Known, Known2);
    }
    break;
  case Instruction::Mul:
    // Find the highest live output bit. We don't need any more input
    // bits than that (adds, and thus subtracts, ripple only to the
    // left).
    AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits());
    break;
  case Instruction::Shl:
    // Only the shifted value (operand 0) with a constant shift amount is
    // refined; the shift amount operand keeps all bits.
    if (OperandNo == 0) {
      const APInt *ShiftAmtC;
      if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
        uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
        AB = AOut.lshr(ShiftAmt);

        // If the shift is nuw/nsw, then the high bits are not dead
        // (because we've promised that they *must* be zero).
        const ShlOperator *S = cast<ShlOperator>(UserI);
        if (S->hasNoSignedWrap())
          AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1);
        else if (S->hasNoUnsignedWrap())
          AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
      }
    }
    break;
  case Instruction::LShr:
    if (OperandNo == 0) {
      const APInt *ShiftAmtC;
      if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
        uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
        AB = AOut.shl(ShiftAmt);

        // If the shift is exact, then the low bits are not dead
        // (they must be zero).
        if (cast<LShrOperator>(UserI)->isExact())
          AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
      }
    }
    break;
  case Instruction::AShr:
    if (OperandNo == 0) {
      const APInt *ShiftAmtC;
      if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
        uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
        AB = AOut.shl(ShiftAmt);
        // Because the high input bit is replicated into the
        // high-order bits of the result, if we need any of those
        // bits, then we must keep the highest input bit.
        if ((AOut & APInt::getHighBitsSet(BitWidth, ShiftAmt))
            .getBoolValue())
          AB.setSignBit();

        // If the shift is exact, then the low bits are not dead
        // (they must be zero).
        if (cast<AShrOperator>(UserI)->isExact())
          AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
      }
    }
    break;
  case Instruction::And:
    AB = AOut;

    // For bits that are known zero, the corresponding bits in the
    // other operand are dead (unless they're both zero, in which
    // case they can't both be dead, so just mark the LHS bits as
    // dead).
    ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
    if (OperandNo == 0)
      AB &= ~Known2.Zero;
    else
      AB &= ~(Known.Zero & ~Known2.Zero);
    break;
  case Instruction::Or:
    AB = AOut;

    // For bits that are known one, the corresponding bits in the
    // other operand are dead (unless they're both one, in which
    // case they can't both be dead, so just mark the LHS bits as
    // dead).
    ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
    if (OperandNo == 0)
      AB &= ~Known2.One;
    else
      AB &= ~(Known.One & ~Known2.One);
    break;
  case Instruction::Xor:
  case Instruction::PHI:
    // Each output bit depends only on the same bit of each input.
    AB = AOut;
    break;
  case Instruction::Trunc:
    // AB has the wider operand width: the surviving output bits are the
    // operand's low bits; the truncated-away high bits are dead.
    AB = AOut.zext(BitWidth);
    break;
  case Instruction::ZExt:
    // AB has the narrower operand width; the zero-filled high output bits
    // need no input bits.
    AB = AOut.trunc(BitWidth);
    break;
  case Instruction::SExt:
    AB = AOut.trunc(BitWidth);
    // Because the high input bit is replicated into the
    // high-order bits of the result, if we need any of those
    // bits, then we must keep the highest input bit.
    if ((AOut & APInt::getHighBitsSet(AOut.getBitWidth(),
                                      AOut.getBitWidth() - BitWidth))
        .getBoolValue())
      AB.setSignBit();
    break;
  case Instruction::Select:
    // Operand 0 is the condition and keeps all its bits; the two selected
    // values feed output bits directly.
    if (OperandNo != 0)
      AB = AOut;
    break;
  case Instruction::ExtractElement:
    // Operand 0 is the source vector; the index operand keeps all bits.
    if (OperandNo == 0)
      AB = AOut;
    break;
  case Instruction::InsertElement:
  case Instruction::ShuffleVector:
    // Operands 0 and 1 supply result lanes directly; any further operand
    // (e.g. the insertelement index) keeps all bits.
    if (OperandNo == 0 || OperandNo == 1)
      AB = AOut;
    break;
  }
}
318 
319 bool DemandedBitsWrapperPass::runOnFunction(Function &F) {
320   auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
321   auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
322   DB.emplace(F, AC, DT);
323   return false;
324 }
325 
326 void DemandedBitsWrapperPass::releaseMemory() {
327   DB.reset();
328 }
329 
// Lazily run the demanded-bits analysis over F. The first call fills in:
//  - AliveBits: for each integer-typed instruction reached, the bits of its
//               value that are demanded;
//  - Visited:   non-integer-typed instructions found to be live;
//  - DeadUses:  integer-typed uses for which no bits are demanded.
// Once Analyzed is set, later calls return immediately.
//
// Liveness starts from the "root" instructions (isAlwaysLive) and is
// propagated backwards from users to their operands with a worklist. An
// operand is re-queued whenever the union of its demanded bits grows, so the
// loop ends at a fixed point.
void DemandedBits::performAnalysis() {
  if (Analyzed)
    // Analysis already completed for this function.
    return;
  Analyzed = true;

  Visited.clear();
  AliveBits.clear();
  DeadUses.clear();

  SmallSetVector<Instruction*, 16> Worklist;

  // Collect the set of "root" instructions that are known live.
  for (Instruction &I : instructions(F)) {
    if (!isAlwaysLive(&I))
      continue;

    LLVM_DEBUG(dbgs() << "DemandedBits: Root: " << I << "\n");
    // For integer-valued instructions, set up an initial empty set of alive
    // bits and add the instruction to the work list. For other instructions
    // add their operands to the work list (for integer values operands, mark
    // all bits as live).
    Type *T = I.getType();
    if (T->isIntOrIntVectorTy()) {
      if (AliveBits.try_emplace(&I, T->getScalarSizeInBits(), 0).second)
        Worklist.insert(&I);

      continue;
    }

    // Non-integer-typed instructions...
    for (Use &OI : I.operands()) {
      if (Instruction *J = dyn_cast<Instruction>(OI)) {
        Type *T = J->getType();
        if (T->isIntOrIntVectorTy())
          AliveBits[J] = APInt::getAllOnesValue(T->getScalarSizeInBits());
        else
          Visited.insert(J);
        Worklist.insert(J);
      }
    }
    // To save memory, we don't add I to the Visited set here. Instead, we
    // check isAlwaysLive on every instruction when searching for dead
    // instructions later (we need to check isAlwaysLive for the
    // integer-typed instructions anyway).
  }

  // Propagate liveness backwards to operands.
  while (!Worklist.empty()) {
    Instruction *UserI = Worklist.pop_back_val();

    LLVM_DEBUG(dbgs() << "DemandedBits: Visiting: " << *UserI);
    APInt AOut;
    bool InputIsKnownDead = false;
    if (UserI->getType()->isIntOrIntVectorTy()) {
      // Integer-typed instructions are only ever queued after an AliveBits
      // entry was created for them, so this lookup finds an existing entry.
      AOut = AliveBits[UserI];
      // NOTE(review): getLimitedValue() saturates, so this debug output is
      // misleading for values wider than 64 bits.
      LLVM_DEBUG(dbgs() << " Alive Out: 0x"
                        << Twine::utohexstr(AOut.getLimitedValue()));

      // If all bits of the output are dead, then all bits of the input
      // are also dead.
      InputIsKnownDead = !AOut && !isAlwaysLive(UserI);
    }
    LLVM_DEBUG(dbgs() << "\n");

    // Per-user known-bits cache shared by all operands of UserI; see
    // determineLiveOperandBits.
    KnownBits Known, Known2;
    bool KnownBitsComputed = false;
    // Compute the set of alive bits for each operand. These are or'ed
    // (unioned) into the existing set, if any, and if that changes the set of
    // alive bits, the operand is added to the work-list.
    for (Use &OI : UserI->operands()) {
      // We also want to detect dead uses of arguments, but will only store
      // demanded bits for instructions.
      Instruction *I = dyn_cast<Instruction>(OI);
      if (!I && !isa<Argument>(OI))
        continue;

      Type *T = OI->getType();
      if (T->isIntOrIntVectorTy()) {
        unsigned BitWidth = T->getScalarSizeInBits();
        APInt AB = APInt::getAllOnesValue(BitWidth);
        if (InputIsKnownDead) {
          // No DeadUses entry is recorded here; isUseDead detects this case
          // from the user's all-zero AliveBits entry instead.
          AB = APInt(BitWidth, 0);
        } else {
          // Bits of each operand that are used to compute alive bits of the
          // output are alive, all others are dead.
          determineLiveOperandBits(UserI, OI, OI.getOperandNo(), AOut, AB,
                                   Known, Known2, KnownBitsComputed);

          // Keep track of uses which have no demanded bits. A use may be
          // revisited with a larger AOut, so a previously dead use is
          // removed again once some of its bits become demanded.
          if (AB.isNullValue())
            DeadUses.insert(&OI);
          else
            DeadUses.erase(&OI);
        }

        if (I) {
          // If we've added to the set of alive bits (or the operand has not
          // been previously visited), then re-queue the operand to be visited
          // again.
          auto Res = AliveBits.try_emplace(I);
          if (Res.second || (AB |= Res.first->second) != Res.first->second) {
            Res.first->second = std::move(AB);
            Worklist.insert(I);
          }
        }
      } else if (I && Visited.insert(I).second) {
        // Non-integer operand: liveness is all-or-nothing, so queue it only
        // the first time it is seen.
        Worklist.insert(I);
      }
    }
  }
}
442 
443 APInt DemandedBits::getDemandedBits(Instruction *I) {
444   performAnalysis();
445 
446   auto Found = AliveBits.find(I);
447   if (Found != AliveBits.end())
448     return Found->second;
449 
450   const DataLayout &DL = I->getModule()->getDataLayout();
451   return APInt::getAllOnesValue(
452       DL.getTypeSizeInBits(I->getType()->getScalarType()));
453 }
454 
455 bool DemandedBits::isInstructionDead(Instruction *I) {
456   performAnalysis();
457 
458   return !Visited.count(I) && AliveBits.find(I) == AliveBits.end() &&
459     !isAlwaysLive(I);
460 }
461 
462 bool DemandedBits::isUseDead(Use *U) {
463   // We only track integer uses, everything else is assumed live.
464   if (!(*U)->getType()->isIntOrIntVectorTy())
465     return false;
466 
467   // Uses by always-live instructions are never dead.
468   Instruction *UserI = cast<Instruction>(U->getUser());
469   if (isAlwaysLive(UserI))
470     return false;
471 
472   performAnalysis();
473   if (DeadUses.count(U))
474     return true;
475 
476   // If no output bits are demanded, no input bits are demanded and the use
477   // is dead. These uses might not be explicitly present in the DeadUses map.
478   if (UserI->getType()->isIntOrIntVectorTy()) {
479     auto Found = AliveBits.find(UserI);
480     if (Found != AliveBits.end() && Found->second.isNullValue())
481       return true;
482   }
483 
484   return false;
485 }
486 
487 void DemandedBits::print(raw_ostream &OS) {
488   performAnalysis();
489   for (auto &KV : AliveBits) {
490     OS << "DemandedBits: 0x" << Twine::utohexstr(KV.second.getLimitedValue())
491        << " for " << *KV.first << '\n';
492   }
493 }
494 
495 static APInt determineLiveOperandBitsAddCarry(unsigned OperandNo,
496                                               const APInt &AOut,
497                                               const KnownBits &LHS,
498                                               const KnownBits &RHS,
499                                               bool CarryZero, bool CarryOne) {
500   assert(!(CarryZero && CarryOne) &&
501          "Carry can't be zero and one at the same time");
502 
503   // The following check should be done by the caller, as it also indicates
504   // that LHS and RHS don't need to be computed.
505   //
506   // if (AOut.isMask())
507   //   return AOut;
508 
509   // Boundary bits' carry out is unaffected by their carry in.
510   APInt Bound = (LHS.Zero & RHS.Zero) | (LHS.One & RHS.One);
511 
512   // First, the alive carry bits are determined from the alive output bits:
513   // Let demand ripple to the right but only up to any set bit in Bound.
514   //   AOut         = -1----
515   //   Bound        = ----1-
516   //   ACarry&~AOut = --111-
517   APInt RBound = Bound.reverseBits();
518   APInt RAOut = AOut.reverseBits();
519   APInt RProp = RAOut + (RAOut | ~RBound);
520   APInt RACarry = RProp ^ ~RBound;
521   APInt ACarry = RACarry.reverseBits();
522 
523   // Then, the alive input bits are determined from the alive carry bits:
524   APInt NeededToMaintainCarryZero;
525   APInt NeededToMaintainCarryOne;
526   if (OperandNo == 0) {
527     NeededToMaintainCarryZero = LHS.Zero | ~RHS.Zero;
528     NeededToMaintainCarryOne = LHS.One | ~RHS.One;
529   } else {
530     NeededToMaintainCarryZero = RHS.Zero | ~LHS.Zero;
531     NeededToMaintainCarryOne = RHS.One | ~LHS.One;
532   }
533 
534   // As in computeForAddCarry
535   APInt PossibleSumZero = ~LHS.Zero + ~RHS.Zero + !CarryZero;
536   APInt PossibleSumOne = LHS.One + RHS.One + CarryOne;
537 
538   // The below is simplified from
539   //
540   // APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
541   // APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
542   // APInt CarryUnknown = ~(CarryKnownZero | CarryKnownOne);
543   //
544   // APInt NeededToMaintainCarry =
545   //   (CarryKnownZero & NeededToMaintainCarryZero) |
546   //   (CarryKnownOne  & NeededToMaintainCarryOne) |
547   //   CarryUnknown;
548 
549   APInt NeededToMaintainCarry = (~PossibleSumZero | NeededToMaintainCarryZero) &
550                                 (PossibleSumOne | NeededToMaintainCarryOne);
551 
552   APInt AB = AOut | (ACarry & NeededToMaintainCarry);
553   return AB;
554 }
555 
556 APInt DemandedBits::determineLiveOperandBitsAdd(unsigned OperandNo,
557                                                 const APInt &AOut,
558                                                 const KnownBits &LHS,
559                                                 const KnownBits &RHS) {
560   return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, RHS, true,
561                                           false);
562 }
563 
564 APInt DemandedBits::determineLiveOperandBitsSub(unsigned OperandNo,
565                                                 const APInt &AOut,
566                                                 const KnownBits &LHS,
567                                                 const KnownBits &RHS) {
568   KnownBits NRHS;
569   NRHS.Zero = RHS.One;
570   NRHS.One = RHS.Zero;
571   return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, NRHS, false,
572                                           true);
573 }
574 
575 FunctionPass *llvm::createDemandedBitsWrapperPass() {
576   return new DemandedBitsWrapperPass();
577 }
578 
579 AnalysisKey DemandedBitsAnalysis::Key;
580 
581 DemandedBits DemandedBitsAnalysis::run(Function &F,
582                                              FunctionAnalysisManager &AM) {
583   auto &AC = AM.getResult<AssumptionAnalysis>(F);
584   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
585   return DemandedBits(F, AC, DT);
586 }
587 
588 PreservedAnalyses DemandedBitsPrinterPass::run(Function &F,
589                                                FunctionAnalysisManager &AM) {
590   AM.getResult<DemandedBitsAnalysis>(F).print(OS);
591   return PreservedAnalyses::all();
592 }
593