1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the classes used to represent and build scalar expressions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
15 
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/FoldingSet.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Analysis/ScalarEvolution.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/Value.h"
24 #include "llvm/IR/ValueHandle.h"
25 #include "llvm/Support/Casting.h"
26 #include "llvm/Support/ErrorHandling.h"
27 #include <cassert>
28 #include <cstddef>
29 
30 namespace llvm {
31 
32 class APInt;
33 class Constant;
34 class ConstantRange;
35 class Loop;
36 class Type;
37 
38   enum SCEVTypes : unsigned short {
39     // These should be ordered in terms of increasing complexity to make the
40     // folders simpler.
41     scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr,
42     scUDivExpr, scAddRecExpr, scUMaxExpr, scSMaxExpr, scUMinExpr, scSMinExpr,
43     scPtrToInt, scUnknown, scCouldNotCompute
44   };
45 
46   /// This class represents a constant integer value.
47   class SCEVConstant : public SCEV {
48     friend class ScalarEvolution;
49 
50     ConstantInt *V;
51 
SCEVConstant(const FoldingSetNodeIDRef ID,ConstantInt * v)52     SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) :
53       SCEV(ID, scConstant, 1), V(v) {}
54 
55   public:
getValue()56     ConstantInt *getValue() const { return V; }
getAPInt()57     const APInt &getAPInt() const { return getValue()->getValue(); }
58 
getType()59     Type *getType() const { return V->getType(); }
60 
61     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)62     static bool classof(const SCEV *S) {
63       return S->getSCEVType() == scConstant;
64     }
65   };
66 
computeExpressionSize(ArrayRef<const SCEV * > Args)67   inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
68     APInt Size(16, 1);
69     for (auto *Arg : Args)
70       Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
71     return (unsigned short)Size.getZExtValue();
72   }
73 
74   /// This is the base class for unary cast operator classes.
75   class SCEVCastExpr : public SCEV {
76   protected:
77     std::array<const SCEV *, 1> Operands;
78     Type *Ty;
79 
80     SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op,
81                  Type *ty);
82 
83   public:
getOperand()84     const SCEV *getOperand() const { return Operands[0]; }
getOperand(unsigned i)85     const SCEV *getOperand(unsigned i) const {
86       assert(i == 0 && "Operand index out of range!");
87       return Operands[0];
88     }
89     using op_iterator = std::array<const SCEV *, 1>::const_iterator;
90     using op_range = iterator_range<op_iterator>;
91 
operands()92     op_range operands() const {
93       return make_range(Operands.begin(), Operands.end());
94     }
getNumOperands()95     size_t getNumOperands() const { return 1; }
getType()96     Type *getType() const { return Ty; }
97 
98     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)99     static bool classof(const SCEV *S) {
100       return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
101              S->getSCEVType() == scZeroExtend ||
102              S->getSCEVType() == scSignExtend;
103     }
104   };
105 
106   /// This class represents a cast from a pointer to a pointer-sized integer
107   /// value.
108   class SCEVPtrToIntExpr : public SCEVCastExpr {
109     friend class ScalarEvolution;
110 
111     SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
112 
113   public:
114     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)115     static bool classof(const SCEV *S) {
116       return S->getSCEVType() == scPtrToInt;
117     }
118   };
119 
120   /// This is the base class for unary integral cast operator classes.
121   class SCEVIntegralCastExpr : public SCEVCastExpr {
122   protected:
123     SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
124                          const SCEV *op, Type *ty);
125 
126   public:
127     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)128     static bool classof(const SCEV *S) {
129       return S->getSCEVType() == scTruncate ||
130              S->getSCEVType() == scZeroExtend ||
131              S->getSCEVType() == scSignExtend;
132     }
133   };
134 
135   /// This class represents a truncation of an integer value to a
136   /// smaller integer value.
137   class SCEVTruncateExpr : public SCEVIntegralCastExpr {
138     friend class ScalarEvolution;
139 
140     SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
141                      const SCEV *op, Type *ty);
142 
143   public:
144     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)145     static bool classof(const SCEV *S) {
146       return S->getSCEVType() == scTruncate;
147     }
148   };
149 
150   /// This class represents a zero extension of a small integer value
151   /// to a larger integer value.
152   class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
153     friend class ScalarEvolution;
154 
155     SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
156                        const SCEV *op, Type *ty);
157 
158   public:
159     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)160     static bool classof(const SCEV *S) {
161       return S->getSCEVType() == scZeroExtend;
162     }
163   };
164 
165   /// This class represents a sign extension of a small integer value
166   /// to a larger integer value.
167   class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
168     friend class ScalarEvolution;
169 
170     SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
171                        const SCEV *op, Type *ty);
172 
173   public:
174     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)175     static bool classof(const SCEV *S) {
176       return S->getSCEVType() == scSignExtend;
177     }
178   };
179 
180   /// This node is a base class providing common functionality for
181   /// n'ary operators.
182   class SCEVNAryExpr : public SCEV {
183   protected:
184     // Since SCEVs are immutable, ScalarEvolution allocates operand
185     // arrays with its SCEVAllocator, so this class just needs a simple
186     // pointer rather than a more elaborate vector-like data structure.
187     // This also avoids the need for a non-trivial destructor.
188     const SCEV *const *Operands;
189     size_t NumOperands;
190 
SCEVNAryExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)191     SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
192                  const SCEV *const *O, size_t N)
193         : SCEV(ID, T, computeExpressionSize(makeArrayRef(O, N))), Operands(O),
194           NumOperands(N) {}
195 
196   public:
getNumOperands()197     size_t getNumOperands() const { return NumOperands; }
198 
getOperand(unsigned i)199     const SCEV *getOperand(unsigned i) const {
200       assert(i < NumOperands && "Operand index out of range!");
201       return Operands[i];
202     }
203 
204     using op_iterator = const SCEV *const *;
205     using op_range = iterator_range<op_iterator>;
206 
op_begin()207     op_iterator op_begin() const { return Operands; }
op_end()208     op_iterator op_end() const { return Operands + NumOperands; }
operands()209     op_range operands() const {
210       return make_range(op_begin(), op_end());
211     }
212 
213     NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
214       return (NoWrapFlags)(SubclassData & Mask);
215     }
216 
hasNoUnsignedWrap()217     bool hasNoUnsignedWrap() const {
218       return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
219     }
220 
hasNoSignedWrap()221     bool hasNoSignedWrap() const {
222       return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
223     }
224 
hasNoSelfWrap()225     bool hasNoSelfWrap() const {
226       return getNoWrapFlags(FlagNW) != FlagAnyWrap;
227     }
228 
229     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)230     static bool classof(const SCEV *S) {
231       return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
232              S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
233              S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
234              S->getSCEVType() == scAddRecExpr;
235     }
236   };
237 
238   /// This node is the base class for n'ary commutative operators.
239   class SCEVCommutativeExpr : public SCEVNAryExpr {
240   protected:
SCEVCommutativeExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)241     SCEVCommutativeExpr(const FoldingSetNodeIDRef ID,
242                         enum SCEVTypes T, const SCEV *const *O, size_t N)
243       : SCEVNAryExpr(ID, T, O, N) {}
244 
245   public:
246     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)247     static bool classof(const SCEV *S) {
248       return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
249              S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
250              S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
251     }
252 
253     /// Set flags for a non-recurrence without clearing previously set flags.
setNoWrapFlags(NoWrapFlags Flags)254     void setNoWrapFlags(NoWrapFlags Flags) {
255       SubclassData |= Flags;
256     }
257   };
258 
259   /// This node represents an addition of some number of SCEVs.
260   class SCEVAddExpr : public SCEVCommutativeExpr {
261     friend class ScalarEvolution;
262 
263     Type *Ty;
264 
SCEVAddExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)265     SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
266         : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
267       auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
268         return Op->getType()->isPointerTy();
269       });
270       if (FirstPointerTypedOp != operands().end())
271         Ty = (*FirstPointerTypedOp)->getType();
272       else
273         Ty = getOperand(0)->getType();
274     }
275 
276   public:
getType()277     Type *getType() const { return Ty; }
278 
279     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)280     static bool classof(const SCEV *S) {
281       return S->getSCEVType() == scAddExpr;
282     }
283   };
284 
285   /// This node represents multiplication of some number of SCEVs.
286   class SCEVMulExpr : public SCEVCommutativeExpr {
287     friend class ScalarEvolution;
288 
SCEVMulExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)289     SCEVMulExpr(const FoldingSetNodeIDRef ID,
290                 const SCEV *const *O, size_t N)
291       : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
292 
293   public:
getType()294     Type *getType() const { return getOperand(0)->getType(); }
295 
296     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)297     static bool classof(const SCEV *S) {
298       return S->getSCEVType() == scMulExpr;
299     }
300   };
301 
302   /// This class represents a binary unsigned division operation.
303   class SCEVUDivExpr : public SCEV {
304     friend class ScalarEvolution;
305 
306     std::array<const SCEV *, 2> Operands;
307 
SCEVUDivExpr(const FoldingSetNodeIDRef ID,const SCEV * lhs,const SCEV * rhs)308     SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
309         : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) {
310         Operands[0] = lhs;
311         Operands[1] = rhs;
312       }
313 
314   public:
getLHS()315     const SCEV *getLHS() const { return Operands[0]; }
getRHS()316     const SCEV *getRHS() const { return Operands[1]; }
getNumOperands()317     size_t getNumOperands() const { return 2; }
getOperand(unsigned i)318     const SCEV *getOperand(unsigned i) const {
319       assert((i == 0 || i == 1) && "Operand index out of range!");
320       return i == 0 ? getLHS() : getRHS();
321     }
322 
323     using op_iterator = std::array<const SCEV *, 2>::const_iterator;
324     using op_range = iterator_range<op_iterator>;
operands()325     op_range operands() const {
326       return make_range(Operands.begin(), Operands.end());
327     }
328 
getType()329     Type *getType() const {
330       // In most cases the types of LHS and RHS will be the same, but in some
331       // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
332       // depend on the type for correctness, but handling types carefully can
333       // avoid extra casts in the SCEVExpander. The LHS is more likely to be
334       // a pointer type than the RHS, so use the RHS' type here.
335       return getRHS()->getType();
336     }
337 
338     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)339     static bool classof(const SCEV *S) {
340       return S->getSCEVType() == scUDivExpr;
341     }
342   };
343 
344   /// This node represents a polynomial recurrence on the trip count
345   /// of the specified loop.  This is the primary focus of the
346   /// ScalarEvolution framework; all the other SCEV subclasses are
347   /// mostly just supporting infrastructure to allow SCEVAddRecExpr
348   /// expressions to be created and analyzed.
349   ///
350   /// All operands of an AddRec are required to be loop invariant.
351   ///
352   class SCEVAddRecExpr : public SCEVNAryExpr {
353     friend class ScalarEvolution;
354 
355     const Loop *L;
356 
SCEVAddRecExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N,const Loop * l)357     SCEVAddRecExpr(const FoldingSetNodeIDRef ID,
358                    const SCEV *const *O, size_t N, const Loop *l)
359       : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
360 
361   public:
getType()362     Type *getType() const { return getStart()->getType(); }
getStart()363     const SCEV *getStart() const { return Operands[0]; }
getLoop()364     const Loop *getLoop() const { return L; }
365 
366     /// Constructs and returns the recurrence indicating how much this
367     /// expression steps by.  If this is a polynomial of degree N, it
368     /// returns a chrec of degree N-1.  We cannot determine whether
369     /// the step recurrence has self-wraparound.
getStepRecurrence(ScalarEvolution & SE)370     const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
371       if (isAffine()) return getOperand(1);
372       return SE.getAddRecExpr(SmallVector<const SCEV *, 3>(op_begin()+1,
373                                                            op_end()),
374                               getLoop(), FlagAnyWrap);
375     }
376 
377     /// Return true if this represents an expression A + B*x where A
378     /// and B are loop invariant values.
isAffine()379     bool isAffine() const {
380       // We know that the start value is invariant.  This expression is thus
381       // affine iff the step is also invariant.
382       return getNumOperands() == 2;
383     }
384 
385     /// Return true if this represents an expression A + B*x + C*x^2
386     /// where A, B and C are loop invariant values.  This corresponds
387     /// to an addrec of the form {L,+,M,+,N}
isQuadratic()388     bool isQuadratic() const {
389       return getNumOperands() == 3;
390     }
391 
392     /// Set flags for a recurrence without clearing any previously set flags.
393     /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
394     /// to make it easier to propagate flags.
setNoWrapFlags(NoWrapFlags Flags)395     void setNoWrapFlags(NoWrapFlags Flags) {
396       if (Flags & (FlagNUW | FlagNSW))
397         Flags = ScalarEvolution::setFlags(Flags, FlagNW);
398       SubclassData |= Flags;
399     }
400 
401     /// Return the value of this chain of recurrences at the specified
402     /// iteration number.
403     const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
404 
405     /// Return the value of this chain of recurrences at the specified iteration
406     /// number. Takes an explicit list of operands to represent an AddRec.
407     static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands,
408                                            const SCEV *It, ScalarEvolution &SE);
409 
410     /// Return the number of iterations of this loop that produce
411     /// values in the specified constant range.  Another way of
412     /// looking at this is that it returns the first iteration number
413     /// where the value is not in the condition, thus computing the
414     /// exit count.  If the iteration count can't be computed, an
415     /// instance of SCEVCouldNotCompute is returned.
416     const SCEV *getNumIterationsInRange(const ConstantRange &Range,
417                                         ScalarEvolution &SE) const;
418 
419     /// Return an expression representing the value of this expression
420     /// one iteration of the loop ahead.
421     const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
422 
423     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)424     static bool classof(const SCEV *S) {
425       return S->getSCEVType() == scAddRecExpr;
426     }
427   };
428 
429   /// This node is the base class min/max selections.
430   class SCEVMinMaxExpr : public SCEVCommutativeExpr {
431     friend class ScalarEvolution;
432 
isMinMaxType(enum SCEVTypes T)433     static bool isMinMaxType(enum SCEVTypes T) {
434       return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
435              T == scUMinExpr;
436     }
437 
438   protected:
439     /// Note: Constructing subclasses via this constructor is allowed
SCEVMinMaxExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)440     SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
441                    const SCEV *const *O, size_t N)
442         : SCEVCommutativeExpr(ID, T, O, N) {
443       assert(isMinMaxType(T));
444       // Min and max never overflow
445       setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
446     }
447 
448   public:
getType()449     Type *getType() const { return getOperand(0)->getType(); }
450 
classof(const SCEV * S)451     static bool classof(const SCEV *S) {
452       return isMinMaxType(S->getSCEVType());
453     }
454 
negate(enum SCEVTypes T)455     static enum SCEVTypes negate(enum SCEVTypes T) {
456       switch (T) {
457       case scSMaxExpr:
458         return scSMinExpr;
459       case scSMinExpr:
460         return scSMaxExpr;
461       case scUMaxExpr:
462         return scUMinExpr;
463       case scUMinExpr:
464         return scUMaxExpr;
465       default:
466         llvm_unreachable("Not a min or max SCEV type!");
467       }
468     }
469   };
470 
471   /// This class represents a signed maximum selection.
472   class SCEVSMaxExpr : public SCEVMinMaxExpr {
473     friend class ScalarEvolution;
474 
SCEVSMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)475     SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
476         : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
477 
478   public:
479     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)480     static bool classof(const SCEV *S) {
481       return S->getSCEVType() == scSMaxExpr;
482     }
483   };
484 
485   /// This class represents an unsigned maximum selection.
486   class SCEVUMaxExpr : public SCEVMinMaxExpr {
487     friend class ScalarEvolution;
488 
SCEVUMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)489     SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
490         : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
491 
492   public:
493     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)494     static bool classof(const SCEV *S) {
495       return S->getSCEVType() == scUMaxExpr;
496     }
497   };
498 
499   /// This class represents a signed minimum selection.
500   class SCEVSMinExpr : public SCEVMinMaxExpr {
501     friend class ScalarEvolution;
502 
SCEVSMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)503     SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
504         : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
505 
506   public:
507     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)508     static bool classof(const SCEV *S) {
509       return S->getSCEVType() == scSMinExpr;
510     }
511   };
512 
513   /// This class represents an unsigned minimum selection.
514   class SCEVUMinExpr : public SCEVMinMaxExpr {
515     friend class ScalarEvolution;
516 
SCEVUMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)517     SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
518         : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
519 
520   public:
521     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)522     static bool classof(const SCEV *S) {
523       return S->getSCEVType() == scUMinExpr;
524     }
525   };
526 
527   /// This means that we are dealing with an entirely unknown SCEV
528   /// value, and only represent it as its LLVM Value.  This is the
529   /// "bottom" value for the analysis.
530   class SCEVUnknown final : public SCEV, private CallbackVH {
531     friend class ScalarEvolution;
532 
533     /// The parent ScalarEvolution value. This is used to update the
534     /// parent's maps when the value associated with a SCEVUnknown is
535     /// deleted or RAUW'd.
536     ScalarEvolution *SE;
537 
538     /// The next pointer in the linked list of all SCEVUnknown
539     /// instances owned by a ScalarEvolution.
540     SCEVUnknown *Next;
541 
SCEVUnknown(const FoldingSetNodeIDRef ID,Value * V,ScalarEvolution * se,SCEVUnknown * next)542     SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V,
543                 ScalarEvolution *se, SCEVUnknown *next) :
544       SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
545 
546     // Implement CallbackVH.
547     void deleted() override;
548     void allUsesReplacedWith(Value *New) override;
549 
550   public:
getValue()551     Value *getValue() const { return getValPtr(); }
552 
553     /// @{
554     /// Test whether this is a special constant representing a type
555     /// size, alignment, or field offset in a target-independent
556     /// manner, and hasn't happened to have been folded with other
557     /// operations into something unrecognizable. This is mainly only
558     /// useful for pretty-printing and other situations where it isn't
559     /// absolutely required for these to succeed.
560     bool isSizeOf(Type *&AllocTy) const;
561     bool isAlignOf(Type *&AllocTy) const;
562     bool isOffsetOf(Type *&STy, Constant *&FieldNo) const;
563     /// @}
564 
getType()565     Type *getType() const { return getValPtr()->getType(); }
566 
567     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)568     static bool classof(const SCEV *S) {
569       return S->getSCEVType() == scUnknown;
570     }
571   };
572 
573   /// This class defines a simple visitor class that may be used for
574   /// various SCEV analysis purposes.
575   template<typename SC, typename RetVal=void>
576   struct SCEVVisitor {
visitSCEVVisitor577     RetVal visit(const SCEV *S) {
578       switch (S->getSCEVType()) {
579       case scConstant:
580         return ((SC*)this)->visitConstant((const SCEVConstant*)S);
581       case scPtrToInt:
582         return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
583       case scTruncate:
584         return ((SC*)this)->visitTruncateExpr((const SCEVTruncateExpr*)S);
585       case scZeroExtend:
586         return ((SC*)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr*)S);
587       case scSignExtend:
588         return ((SC*)this)->visitSignExtendExpr((const SCEVSignExtendExpr*)S);
589       case scAddExpr:
590         return ((SC*)this)->visitAddExpr((const SCEVAddExpr*)S);
591       case scMulExpr:
592         return ((SC*)this)->visitMulExpr((const SCEVMulExpr*)S);
593       case scUDivExpr:
594         return ((SC*)this)->visitUDivExpr((const SCEVUDivExpr*)S);
595       case scAddRecExpr:
596         return ((SC*)this)->visitAddRecExpr((const SCEVAddRecExpr*)S);
597       case scSMaxExpr:
598         return ((SC*)this)->visitSMaxExpr((const SCEVSMaxExpr*)S);
599       case scUMaxExpr:
600         return ((SC*)this)->visitUMaxExpr((const SCEVUMaxExpr*)S);
601       case scSMinExpr:
602         return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
603       case scUMinExpr:
604         return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
605       case scUnknown:
606         return ((SC*)this)->visitUnknown((const SCEVUnknown*)S);
607       case scCouldNotCompute:
608         return ((SC*)this)->visitCouldNotCompute((const SCEVCouldNotCompute*)S);
609       }
610       llvm_unreachable("Unknown SCEV kind!");
611     }
612 
visitCouldNotComputeSCEVVisitor613     RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
614       llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
615     }
616   };
617 
618   /// Visit all nodes in the expression tree using worklist traversal.
619   ///
620   /// Visitor implements:
621   ///   // return true to follow this node.
622   ///   bool follow(const SCEV *S);
623   ///   // return true to terminate the search.
624   ///   bool isDone();
625   template<typename SV>
626   class SCEVTraversal {
627     SV &Visitor;
628     SmallVector<const SCEV *, 8> Worklist;
629     SmallPtrSet<const SCEV *, 8> Visited;
630 
push(const SCEV * S)631     void push(const SCEV *S) {
632       if (Visited.insert(S).second && Visitor.follow(S))
633         Worklist.push_back(S);
634     }
635 
636   public:
SCEVTraversal(SV & V)637     SCEVTraversal(SV& V): Visitor(V) {}
638 
visitAll(const SCEV * Root)639     void visitAll(const SCEV *Root) {
640       push(Root);
641       while (!Worklist.empty() && !Visitor.isDone()) {
642         const SCEV *S = Worklist.pop_back_val();
643 
644         switch (S->getSCEVType()) {
645         case scConstant:
646         case scUnknown:
647           continue;
648         case scPtrToInt:
649         case scTruncate:
650         case scZeroExtend:
651         case scSignExtend:
652           push(cast<SCEVCastExpr>(S)->getOperand());
653           continue;
654         case scAddExpr:
655         case scMulExpr:
656         case scSMaxExpr:
657         case scUMaxExpr:
658         case scSMinExpr:
659         case scUMinExpr:
660         case scAddRecExpr:
661           for (const auto *Op : cast<SCEVNAryExpr>(S)->operands())
662             push(Op);
663           continue;
664         case scUDivExpr: {
665           const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
666           push(UDiv->getLHS());
667           push(UDiv->getRHS());
668           continue;
669         }
670         case scCouldNotCompute:
671           llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
672         }
673         llvm_unreachable("Unknown SCEV kind!");
674       }
675     }
676   };
677 
678   /// Use SCEVTraversal to visit all nodes in the given expression tree.
679   template<typename SV>
visitAll(const SCEV * Root,SV & Visitor)680   void visitAll(const SCEV *Root, SV& Visitor) {
681     SCEVTraversal<SV> T(Visitor);
682     T.visitAll(Root);
683   }
684 
685   /// Return true if any node in \p Root satisfies the predicate \p Pred.
686   template <typename PredTy>
SCEVExprContains(const SCEV * Root,PredTy Pred)687   bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
688     struct FindClosure {
689       bool Found = false;
690       PredTy Pred;
691 
692       FindClosure(PredTy Pred) : Pred(Pred) {}
693 
694       bool follow(const SCEV *S) {
695         if (!Pred(S))
696           return true;
697 
698         Found = true;
699         return false;
700       }
701 
702       bool isDone() const { return Found; }
703     };
704 
705     FindClosure FC(Pred);
706     visitAll(Root, FC);
707     return FC.Found;
708   }
709 
710   /// This visitor recursively visits a SCEV expression and re-writes it.
711   /// The result from each visit is cached, so it will return the same
712   /// SCEV for the same input.
713   template<typename SC>
714   class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
715   protected:
716     ScalarEvolution &SE;
717     // Memoize the result of each visit so that we only compute once for
718     // the same input SCEV. This is to avoid redundant computations when
719     // a SCEV is referenced by multiple SCEVs. Without memoization, this
720     // visit algorithm would have exponential time complexity in the worst
721     // case, causing the compiler to hang on certain tests.
722     DenseMap<const SCEV *, const SCEV *> RewriteResults;
723 
724   public:
SCEVRewriteVisitor(ScalarEvolution & SE)725     SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
726 
visit(const SCEV * S)727     const SCEV *visit(const SCEV *S) {
728       auto It = RewriteResults.find(S);
729       if (It != RewriteResults.end())
730         return It->second;
731       auto* Visited = SCEVVisitor<SC, const SCEV *>::visit(S);
732       auto Result = RewriteResults.try_emplace(S, Visited);
733       assert(Result.second && "Should insert a new entry");
734       return Result.first->second;
735     }
736 
visitConstant(const SCEVConstant * Constant)737     const SCEV *visitConstant(const SCEVConstant *Constant) {
738       return Constant;
739     }
740 
visitPtrToIntExpr(const SCEVPtrToIntExpr * Expr)741     const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
742       const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
743       return Operand == Expr->getOperand()
744                  ? Expr
745                  : SE.getPtrToIntExpr(Operand, Expr->getType());
746     }
747 
visitTruncateExpr(const SCEVTruncateExpr * Expr)748     const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
749       const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
750       return Operand == Expr->getOperand()
751                  ? Expr
752                  : SE.getTruncateExpr(Operand, Expr->getType());
753     }
754 
visitZeroExtendExpr(const SCEVZeroExtendExpr * Expr)755     const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
756       const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
757       return Operand == Expr->getOperand()
758                  ? Expr
759                  : SE.getZeroExtendExpr(Operand, Expr->getType());
760     }
761 
visitSignExtendExpr(const SCEVSignExtendExpr * Expr)762     const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
763       const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
764       return Operand == Expr->getOperand()
765                  ? Expr
766                  : SE.getSignExtendExpr(Operand, Expr->getType());
767     }
768 
visitAddExpr(const SCEVAddExpr * Expr)769     const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
770       SmallVector<const SCEV *, 2> Operands;
771       bool Changed = false;
772       for (auto *Op : Expr->operands()) {
773         Operands.push_back(((SC*)this)->visit(Op));
774         Changed |= Op != Operands.back();
775       }
776       return !Changed ? Expr : SE.getAddExpr(Operands);
777     }
778 
visitMulExpr(const SCEVMulExpr * Expr)779     const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
780       SmallVector<const SCEV *, 2> Operands;
781       bool Changed = false;
782       for (auto *Op : Expr->operands()) {
783         Operands.push_back(((SC*)this)->visit(Op));
784         Changed |= Op != Operands.back();
785       }
786       return !Changed ? Expr : SE.getMulExpr(Operands);
787     }
788 
visitUDivExpr(const SCEVUDivExpr * Expr)789     const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
790       auto *LHS = ((SC *)this)->visit(Expr->getLHS());
791       auto *RHS = ((SC *)this)->visit(Expr->getRHS());
792       bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
793       return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
794     }
795 
visitAddRecExpr(const SCEVAddRecExpr * Expr)796     const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
797       SmallVector<const SCEV *, 2> Operands;
798       bool Changed = false;
799       for (auto *Op : Expr->operands()) {
800         Operands.push_back(((SC*)this)->visit(Op));
801         Changed |= Op != Operands.back();
802       }
803       return !Changed ? Expr
804                       : SE.getAddRecExpr(Operands, Expr->getLoop(),
805                                          Expr->getNoWrapFlags());
806     }
807 
visitSMaxExpr(const SCEVSMaxExpr * Expr)808     const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
809       SmallVector<const SCEV *, 2> Operands;
810       bool Changed = false;
811       for (auto *Op : Expr->operands()) {
812         Operands.push_back(((SC *)this)->visit(Op));
813         Changed |= Op != Operands.back();
814       }
815       return !Changed ? Expr : SE.getSMaxExpr(Operands);
816     }
817 
visitUMaxExpr(const SCEVUMaxExpr * Expr)818     const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
819       SmallVector<const SCEV *, 2> Operands;
820       bool Changed = false;
821       for (auto *Op : Expr->operands()) {
822         Operands.push_back(((SC*)this)->visit(Op));
823         Changed |= Op != Operands.back();
824       }
825       return !Changed ? Expr : SE.getUMaxExpr(Operands);
826     }
827 
visitSMinExpr(const SCEVSMinExpr * Expr)828     const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
829       SmallVector<const SCEV *, 2> Operands;
830       bool Changed = false;
831       for (auto *Op : Expr->operands()) {
832         Operands.push_back(((SC *)this)->visit(Op));
833         Changed |= Op != Operands.back();
834       }
835       return !Changed ? Expr : SE.getSMinExpr(Operands);
836     }
837 
visitUMinExpr(const SCEVUMinExpr * Expr)838     const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
839       SmallVector<const SCEV *, 2> Operands;
840       bool Changed = false;
841       for (auto *Op : Expr->operands()) {
842         Operands.push_back(((SC *)this)->visit(Op));
843         Changed |= Op != Operands.back();
844       }
845       return !Changed ? Expr : SE.getUMinExpr(Operands);
846     }
847 
visitUnknown(const SCEVUnknown * Expr)848     const SCEV *visitUnknown(const SCEVUnknown *Expr) {
849       return Expr;
850     }
851 
visitCouldNotCompute(const SCEVCouldNotCompute * Expr)852     const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
853       return Expr;
854     }
855   };
856 
857   using ValueToValueMap = DenseMap<const Value *, Value *>;
858   using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
859 
860   /// The SCEVParameterRewriter takes a scalar evolution expression and updates
861   /// the SCEVUnknown components following the Map (Value -> SCEV).
862   class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
863   public:
rewrite(const SCEV * Scev,ScalarEvolution & SE,ValueToSCEVMapTy & Map)864     static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
865                                ValueToSCEVMapTy &Map) {
866       SCEVParameterRewriter Rewriter(SE, Map);
867       return Rewriter.visit(Scev);
868     }
869 
SCEVParameterRewriter(ScalarEvolution & SE,ValueToSCEVMapTy & M)870     SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
871         : SCEVRewriteVisitor(SE), Map(M) {}
872 
visitUnknown(const SCEVUnknown * Expr)873     const SCEV *visitUnknown(const SCEVUnknown *Expr) {
874       auto I = Map.find(Expr->getValue());
875       if (I == Map.end())
876         return Expr;
877       return I->second;
878     }
879 
880   private:
881     ValueToSCEVMapTy &Map;
882   };
883 
884   using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
885 
886   /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
887   /// the Map (Loop -> SCEV) to all AddRecExprs.
888   class SCEVLoopAddRecRewriter
889       : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
890   public:
SCEVLoopAddRecRewriter(ScalarEvolution & SE,LoopToScevMapT & M)891     SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
892         : SCEVRewriteVisitor(SE), Map(M) {}
893 
rewrite(const SCEV * Scev,LoopToScevMapT & Map,ScalarEvolution & SE)894     static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
895                                ScalarEvolution &SE) {
896       SCEVLoopAddRecRewriter Rewriter(SE, Map);
897       return Rewriter.visit(Scev);
898     }
899 
visitAddRecExpr(const SCEVAddRecExpr * Expr)900     const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
901       SmallVector<const SCEV *, 2> Operands;
902       for (const SCEV *Op : Expr->operands())
903         Operands.push_back(visit(Op));
904 
905       const Loop *L = Expr->getLoop();
906       if (0 == Map.count(L))
907         return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
908 
909       return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE);
910     }
911 
912   private:
913     LoopToScevMapT &Map;
914   };
915 
916 } // end namespace llvm
917 
918 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
919