//===- ScalarEvolutionDivision.cpp - See below ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the class that knows how to divide SCEVs.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Analysis/ScalarEvolutionDivision.h"
14 #include "llvm/ADT/APInt.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/Analysis/ScalarEvolution.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/ErrorHandling.h"
21 #include <cassert>
22 #include <cstdint>
23 
24 namespace llvm {
25 class Type;
26 }
27 
28 using namespace llvm;
29 
30 namespace {
31 
sizeOfSCEV(const SCEV * S)32 static inline int sizeOfSCEV(const SCEV *S) {
33   struct FindSCEVSize {
34     int Size = 0;
35 
36     FindSCEVSize() = default;
37 
38     bool follow(const SCEV *S) {
39       ++Size;
40       // Keep looking at all operands of S.
41       return true;
42     }
43 
44     bool isDone() const { return false; }
45   };
46 
47   FindSCEVSize F;
48   SCEVTraversal<FindSCEVSize> ST(F);
49   ST.visitAll(S);
50   return F.Size;
51 }
52 
53 } // namespace
54 
55 // Computes the Quotient and Remainder of the division of Numerator by
56 // Denominator.
divide(ScalarEvolution & SE,const SCEV * Numerator,const SCEV * Denominator,const SCEV ** Quotient,const SCEV ** Remainder)57 void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator,
58                           const SCEV *Denominator, const SCEV **Quotient,
59                           const SCEV **Remainder) {
60   assert(Numerator && Denominator && "Uninitialized SCEV");
61 
62   SCEVDivision D(SE, Numerator, Denominator);
63 
64   // Check for the trivial case here to avoid having to check for it in the
65   // rest of the code.
66   if (Numerator == Denominator) {
67     *Quotient = D.One;
68     *Remainder = D.Zero;
69     return;
70   }
71 
72   if (Numerator->isZero()) {
73     *Quotient = D.Zero;
74     *Remainder = D.Zero;
75     return;
76   }
77 
78   // A simple case when N/1. The quotient is N.
79   if (Denominator->isOne()) {
80     *Quotient = Numerator;
81     *Remainder = D.Zero;
82     return;
83   }
84 
85   // Split the Denominator when it is a product.
86   if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
87     const SCEV *Q, *R;
88     *Quotient = Numerator;
89     for (const SCEV *Op : T->operands()) {
90       divide(SE, *Quotient, Op, &Q, &R);
91       *Quotient = Q;
92 
93       // Bail out when the Numerator is not divisible by one of the terms of
94       // the Denominator.
95       if (!R->isZero()) {
96         *Quotient = D.Zero;
97         *Remainder = Numerator;
98         return;
99       }
100     }
101     *Remainder = D.Zero;
102     return;
103   }
104 
105   D.visit(Numerator);
106   *Quotient = D.Quotient;
107   *Remainder = D.Remainder;
108 }
109 
visitConstant(const SCEVConstant * Numerator)110 void SCEVDivision::visitConstant(const SCEVConstant *Numerator) {
111   if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
112     APInt NumeratorVal = Numerator->getAPInt();
113     APInt DenominatorVal = D->getAPInt();
114     uint32_t NumeratorBW = NumeratorVal.getBitWidth();
115     uint32_t DenominatorBW = DenominatorVal.getBitWidth();
116 
117     if (NumeratorBW > DenominatorBW)
118       DenominatorVal = DenominatorVal.sext(NumeratorBW);
119     else if (NumeratorBW < DenominatorBW)
120       NumeratorVal = NumeratorVal.sext(DenominatorBW);
121 
122     APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
123     APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
124     APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
125     Quotient = SE.getConstant(QuotientVal);
126     Remainder = SE.getConstant(RemainderVal);
127     return;
128   }
129 }
130 
visitAddRecExpr(const SCEVAddRecExpr * Numerator)131 void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
132   const SCEV *StartQ, *StartR, *StepQ, *StepR;
133   if (!Numerator->isAffine())
134     return cannotDivide(Numerator);
135   divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
136   divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
137   // Bail out if the types do not match.
138   Type *Ty = Denominator->getType();
139   if (Ty != StartQ->getType() || Ty != StartR->getType() ||
140       Ty != StepQ->getType() || Ty != StepR->getType())
141     return cannotDivide(Numerator);
142   Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
143                               Numerator->getNoWrapFlags());
144   Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
145                                Numerator->getNoWrapFlags());
146 }
147 
visitAddExpr(const SCEVAddExpr * Numerator)148 void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) {
149   SmallVector<const SCEV *, 2> Qs, Rs;
150   Type *Ty = Denominator->getType();
151 
152   for (const SCEV *Op : Numerator->operands()) {
153     const SCEV *Q, *R;
154     divide(SE, Op, Denominator, &Q, &R);
155 
156     // Bail out if types do not match.
157     if (Ty != Q->getType() || Ty != R->getType())
158       return cannotDivide(Numerator);
159 
160     Qs.push_back(Q);
161     Rs.push_back(R);
162   }
163 
164   if (Qs.size() == 1) {
165     Quotient = Qs[0];
166     Remainder = Rs[0];
167     return;
168   }
169 
170   Quotient = SE.getAddExpr(Qs);
171   Remainder = SE.getAddExpr(Rs);
172 }
173 
// Divides a product Numerator = Op0 * Op1 * ... by Denominator.
//
// Strategy 1: find the first operand that Denominator divides with a zero
// remainder. The quotient is then the product with that operand replaced by
// its quotient, and the remainder is zero.
//
// Strategy 2 (only when Denominator is a SCEVUnknown): substitute 0 for the
// denominator's value to obtain the remainder. If that remainder is zero, the
// quotient is Numerator with 1 substituted instead. Otherwise the quotient is
// (Numerator - Remainder) / Denominator, computed recursively.
//
// Any failure falls back to cannotDivide:
//   - an operand type mismatch,
//   - a non-unknown denominator,
//   - a difference that grows in size,
//   - a non-zero recursive remainder.
void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) {
  // Factors of the resulting quotient, in operand order.
  SmallVector<const SCEV *, 2> Qs;
  Type *Ty = Denominator->getType();

  bool FoundDenominatorTerm = false;
  for (const SCEV *Op : Numerator->operands()) {
    // Bail out if types do not match.
    if (Ty != Op->getType())
      return cannotDivide(Numerator);

    // Only one operand gets divided; the remaining ones are copied through.
    if (FoundDenominatorTerm) {
      Qs.push_back(Op);
      continue;
    }

    // Check whether Denominator divides one of the product operands.
    const SCEV *Q, *R;
    divide(SE, Op, Denominator, &Q, &R);
    if (!R->isZero()) {
      Qs.push_back(Op);
      continue;
    }

    // Bail out if types do not match.
    if (Ty != Q->getType())
      return cannotDivide(Numerator);

    FoundDenominatorTerm = true;
    Qs.push_back(Q);
  }

  if (FoundDenominatorTerm) {
    Remainder = Zero;
    if (Qs.size() == 1)
      Quotient = Qs[0];
    else
      Quotient = SE.getMulExpr(Qs);
    return;
  }

  // The substitution strategy below needs the denominator to be a single
  // opaque value that can be rewritten.
  if (!isa<SCEVUnknown>(Denominator))
    return cannotDivide(Numerator);

  // The Remainder is obtained by replacing Denominator by 0 in Numerator.
  ValueToSCEVMapTy RewriteMap;
  RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero;
  Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);

  if (Remainder->isZero()) {
    // The Quotient is obtained by replacing Denominator by 1 in Numerator.
    // NOTE(review): this substitution is only exact when Numerator depends
    // linearly on the denominator's value. For other shapes (e.g. something
    // like a * smax(d, 0), if no earlier case handles it) the result may not
    // satisfy Numerator == Quotient * Denominator. Confirm callers only
    // reach this with linear inputs.
    RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One;
    Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
    return;
  }

  // Quotient is (Numerator - Remainder) divided by Denominator.
  const SCEV *Q, *R;
  const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
  // This SCEV does not seem to simplify: fail the division here.
  if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
    return cannotDivide(Numerator);
  divide(SE, Diff, Denominator, &Q, &R);
  // Pointer comparison against the cached zero of Denominator's type.
  if (R != Zero)
    return cannotDivide(Numerator);
  // Remainder keeps the value produced by the 0-substitution above.
  Quotient = Q;
}
240 
SCEVDivision(ScalarEvolution & S,const SCEV * Numerator,const SCEV * Denominator)241 SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
242                            const SCEV *Denominator)
243     : SE(S), Denominator(Denominator) {
244   Zero = SE.getZero(Denominator->getType());
245   One = SE.getOne(Denominator->getType());
246 
247   // We generally do not know how to divide Expr by Denominator. We initialize
248   // the division to a "cannot divide" state to simplify the rest of the code.
249   cannotDivide(Numerator);
250 }
251 
252 // Convenience function for giving up on the division. We set the quotient to
253 // be equal to zero and the remainder to be equal to the numerator.
cannotDivide(const SCEV * Numerator)254 void SCEVDivision::cannotDivide(const SCEV *Numerator) {
255   Quotient = Zero;
256   Remainder = Numerator;
257 }
258