1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the classes used to represent and build scalar expressions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
15 
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <array>
#include <cassert>
#include <cstddef>
29 
30 namespace llvm {
31 
32 class APInt;
33 class Constant;
34 class ConstantRange;
35 class Loop;
36 class Type;
37 
38   enum SCEVTypes : unsigned short {
39     // These should be ordered in terms of increasing complexity to make the
40     // folders simpler.
41     scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr,
42     scUDivExpr, scAddRecExpr, scUMaxExpr, scSMaxExpr, scUMinExpr, scSMinExpr,
43     scPtrToInt, scUnknown, scCouldNotCompute
44   };
45 
46   /// This class represents a constant integer value.
47   class SCEVConstant : public SCEV {
48     friend class ScalarEvolution;
49 
50     ConstantInt *V;
51 
SCEVConstant(const FoldingSetNodeIDRef ID,ConstantInt * v)52     SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) :
53       SCEV(ID, scConstant, 1), V(v) {}
54 
55   public:
getValue()56     ConstantInt *getValue() const { return V; }
getAPInt()57     const APInt &getAPInt() const { return getValue()->getValue(); }
58 
getType()59     Type *getType() const { return V->getType(); }
60 
61     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)62     static bool classof(const SCEV *S) {
63       return S->getSCEVType() == scConstant;
64     }
65   };
66 
computeExpressionSize(ArrayRef<const SCEV * > Args)67   inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
68     APInt Size(16, 1);
69     for (auto *Arg : Args)
70       Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
71     return (unsigned short)Size.getZExtValue();
72   }
73 
74   /// This is the base class for unary cast operator classes.
75   class SCEVCastExpr : public SCEV {
76   protected:
77     std::array<const SCEV *, 1> Operands;
78     Type *Ty;
79 
80     SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op,
81                  Type *ty);
82 
83   public:
getOperand()84     const SCEV *getOperand() const { return Operands[0]; }
getOperand(unsigned i)85     const SCEV *getOperand(unsigned i) const {
86       assert(i == 0 && "Operand index out of range!");
87       return Operands[0];
88     }
89     using op_iterator = std::array<const SCEV *, 1>::const_iterator;
90     using op_range = iterator_range<op_iterator>;
91 
operands()92     op_range operands() const {
93       return make_range(Operands.begin(), Operands.end());
94     }
getNumOperands()95     size_t getNumOperands() const { return 1; }
getType()96     Type *getType() const { return Ty; }
97 
98     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)99     static bool classof(const SCEV *S) {
100       return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
101              S->getSCEVType() == scZeroExtend ||
102              S->getSCEVType() == scSignExtend;
103     }
104   };
105 
106   /// This class represents a cast from a pointer to a pointer-sized integer
107   /// value.
108   class SCEVPtrToIntExpr : public SCEVCastExpr {
109     friend class ScalarEvolution;
110 
111     SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
112 
113   public:
114     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)115     static bool classof(const SCEV *S) {
116       return S->getSCEVType() == scPtrToInt;
117     }
118   };
119 
120   /// This is the base class for unary integral cast operator classes.
121   class SCEVIntegralCastExpr : public SCEVCastExpr {
122   protected:
123     SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
124                          const SCEV *op, Type *ty);
125 
126   public:
127     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)128     static bool classof(const SCEV *S) {
129       return S->getSCEVType() == scTruncate ||
130              S->getSCEVType() == scZeroExtend ||
131              S->getSCEVType() == scSignExtend;
132     }
133   };
134 
135   /// This class represents a truncation of an integer value to a
136   /// smaller integer value.
137   class SCEVTruncateExpr : public SCEVIntegralCastExpr {
138     friend class ScalarEvolution;
139 
140     SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
141                      const SCEV *op, Type *ty);
142 
143   public:
144     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)145     static bool classof(const SCEV *S) {
146       return S->getSCEVType() == scTruncate;
147     }
148   };
149 
150   /// This class represents a zero extension of a small integer value
151   /// to a larger integer value.
152   class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
153     friend class ScalarEvolution;
154 
155     SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
156                        const SCEV *op, Type *ty);
157 
158   public:
159     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)160     static bool classof(const SCEV *S) {
161       return S->getSCEVType() == scZeroExtend;
162     }
163   };
164 
165   /// This class represents a sign extension of a small integer value
166   /// to a larger integer value.
167   class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
168     friend class ScalarEvolution;
169 
170     SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
171                        const SCEV *op, Type *ty);
172 
173   public:
174     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)175     static bool classof(const SCEV *S) {
176       return S->getSCEVType() == scSignExtend;
177     }
178   };
179 
180   /// This node is a base class providing common functionality for
181   /// n'ary operators.
182   class SCEVNAryExpr : public SCEV {
183   protected:
184     // Since SCEVs are immutable, ScalarEvolution allocates operand
185     // arrays with its SCEVAllocator, so this class just needs a simple
186     // pointer rather than a more elaborate vector-like data structure.
187     // This also avoids the need for a non-trivial destructor.
188     const SCEV *const *Operands;
189     size_t NumOperands;
190 
SCEVNAryExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)191     SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
192                  const SCEV *const *O, size_t N)
193         : SCEV(ID, T, computeExpressionSize(makeArrayRef(O, N))), Operands(O),
194           NumOperands(N) {}
195 
196   public:
getNumOperands()197     size_t getNumOperands() const { return NumOperands; }
198 
getOperand(unsigned i)199     const SCEV *getOperand(unsigned i) const {
200       assert(i < NumOperands && "Operand index out of range!");
201       return Operands[i];
202     }
203 
204     using op_iterator = const SCEV *const *;
205     using op_range = iterator_range<op_iterator>;
206 
op_begin()207     op_iterator op_begin() const { return Operands; }
op_end()208     op_iterator op_end() const { return Operands + NumOperands; }
operands()209     op_range operands() const {
210       return make_range(op_begin(), op_end());
211     }
212 
getType()213     Type *getType() const { return getOperand(0)->getType(); }
214 
215     NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
216       return (NoWrapFlags)(SubclassData & Mask);
217     }
218 
hasNoUnsignedWrap()219     bool hasNoUnsignedWrap() const {
220       return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
221     }
222 
hasNoSignedWrap()223     bool hasNoSignedWrap() const {
224       return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
225     }
226 
hasNoSelfWrap()227     bool hasNoSelfWrap() const {
228       return getNoWrapFlags(FlagNW) != FlagAnyWrap;
229     }
230 
231     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)232     static bool classof(const SCEV *S) {
233       return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
234              S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
235              S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
236              S->getSCEVType() == scAddRecExpr;
237     }
238   };
239 
240   /// This node is the base class for n'ary commutative operators.
241   class SCEVCommutativeExpr : public SCEVNAryExpr {
242   protected:
SCEVCommutativeExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)243     SCEVCommutativeExpr(const FoldingSetNodeIDRef ID,
244                         enum SCEVTypes T, const SCEV *const *O, size_t N)
245       : SCEVNAryExpr(ID, T, O, N) {}
246 
247   public:
248     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)249     static bool classof(const SCEV *S) {
250       return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
251              S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
252              S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
253     }
254 
255     /// Set flags for a non-recurrence without clearing previously set flags.
setNoWrapFlags(NoWrapFlags Flags)256     void setNoWrapFlags(NoWrapFlags Flags) {
257       SubclassData |= Flags;
258     }
259   };
260 
261   /// This node represents an addition of some number of SCEVs.
262   class SCEVAddExpr : public SCEVCommutativeExpr {
263     friend class ScalarEvolution;
264 
265     Type *Ty;
266 
SCEVAddExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)267     SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
268         : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
269       auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
270         return Op->getType()->isPointerTy();
271       });
272       if (FirstPointerTypedOp != operands().end())
273         Ty = (*FirstPointerTypedOp)->getType();
274       else
275         Ty = getOperand(0)->getType();
276     }
277 
278   public:
getType()279     Type *getType() const { return Ty; }
280 
281     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)282     static bool classof(const SCEV *S) {
283       return S->getSCEVType() == scAddExpr;
284     }
285   };
286 
287   /// This node represents multiplication of some number of SCEVs.
288   class SCEVMulExpr : public SCEVCommutativeExpr {
289     friend class ScalarEvolution;
290 
SCEVMulExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)291     SCEVMulExpr(const FoldingSetNodeIDRef ID,
292                 const SCEV *const *O, size_t N)
293       : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
294 
295   public:
296     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)297     static bool classof(const SCEV *S) {
298       return S->getSCEVType() == scMulExpr;
299     }
300   };
301 
302   /// This class represents a binary unsigned division operation.
303   class SCEVUDivExpr : public SCEV {
304     friend class ScalarEvolution;
305 
306     std::array<const SCEV *, 2> Operands;
307 
SCEVUDivExpr(const FoldingSetNodeIDRef ID,const SCEV * lhs,const SCEV * rhs)308     SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
309         : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) {
310         Operands[0] = lhs;
311         Operands[1] = rhs;
312       }
313 
314   public:
getLHS()315     const SCEV *getLHS() const { return Operands[0]; }
getRHS()316     const SCEV *getRHS() const { return Operands[1]; }
getNumOperands()317     size_t getNumOperands() const { return 2; }
getOperand(unsigned i)318     const SCEV *getOperand(unsigned i) const {
319       assert((i == 0 || i == 1) && "Operand index out of range!");
320       return i == 0 ? getLHS() : getRHS();
321     }
322 
323     using op_iterator = std::array<const SCEV *, 2>::const_iterator;
324     using op_range = iterator_range<op_iterator>;
operands()325     op_range operands() const {
326       return make_range(Operands.begin(), Operands.end());
327     }
328 
getType()329     Type *getType() const {
330       // In most cases the types of LHS and RHS will be the same, but in some
331       // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
332       // depend on the type for correctness, but handling types carefully can
333       // avoid extra casts in the SCEVExpander. The LHS is more likely to be
334       // a pointer type than the RHS, so use the RHS' type here.
335       return getRHS()->getType();
336     }
337 
338     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)339     static bool classof(const SCEV *S) {
340       return S->getSCEVType() == scUDivExpr;
341     }
342   };
343 
344   /// This node represents a polynomial recurrence on the trip count
345   /// of the specified loop.  This is the primary focus of the
346   /// ScalarEvolution framework; all the other SCEV subclasses are
347   /// mostly just supporting infrastructure to allow SCEVAddRecExpr
348   /// expressions to be created and analyzed.
349   ///
350   /// All operands of an AddRec are required to be loop invariant.
351   ///
352   class SCEVAddRecExpr : public SCEVNAryExpr {
353     friend class ScalarEvolution;
354 
355     const Loop *L;
356 
SCEVAddRecExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N,const Loop * l)357     SCEVAddRecExpr(const FoldingSetNodeIDRef ID,
358                    const SCEV *const *O, size_t N, const Loop *l)
359       : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
360 
361   public:
getStart()362     const SCEV *getStart() const { return Operands[0]; }
getLoop()363     const Loop *getLoop() const { return L; }
364 
365     /// Constructs and returns the recurrence indicating how much this
366     /// expression steps by.  If this is a polynomial of degree N, it
367     /// returns a chrec of degree N-1.  We cannot determine whether
368     /// the step recurrence has self-wraparound.
getStepRecurrence(ScalarEvolution & SE)369     const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
370       if (isAffine()) return getOperand(1);
371       return SE.getAddRecExpr(SmallVector<const SCEV *, 3>(op_begin()+1,
372                                                            op_end()),
373                               getLoop(), FlagAnyWrap);
374     }
375 
376     /// Return true if this represents an expression A + B*x where A
377     /// and B are loop invariant values.
isAffine()378     bool isAffine() const {
379       // We know that the start value is invariant.  This expression is thus
380       // affine iff the step is also invariant.
381       return getNumOperands() == 2;
382     }
383 
384     /// Return true if this represents an expression A + B*x + C*x^2
385     /// where A, B and C are loop invariant values.  This corresponds
386     /// to an addrec of the form {L,+,M,+,N}
isQuadratic()387     bool isQuadratic() const {
388       return getNumOperands() == 3;
389     }
390 
391     /// Set flags for a recurrence without clearing any previously set flags.
392     /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
393     /// to make it easier to propagate flags.
setNoWrapFlags(NoWrapFlags Flags)394     void setNoWrapFlags(NoWrapFlags Flags) {
395       if (Flags & (FlagNUW | FlagNSW))
396         Flags = ScalarEvolution::setFlags(Flags, FlagNW);
397       SubclassData |= Flags;
398     }
399 
400     /// Return the value of this chain of recurrences at the specified
401     /// iteration number.
402     const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
403 
404     /// Return the number of iterations of this loop that produce
405     /// values in the specified constant range.  Another way of
406     /// looking at this is that it returns the first iteration number
407     /// where the value is not in the condition, thus computing the
408     /// exit count.  If the iteration count can't be computed, an
409     /// instance of SCEVCouldNotCompute is returned.
410     const SCEV *getNumIterationsInRange(const ConstantRange &Range,
411                                         ScalarEvolution &SE) const;
412 
413     /// Return an expression representing the value of this expression
414     /// one iteration of the loop ahead.
415     const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
416 
417     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)418     static bool classof(const SCEV *S) {
419       return S->getSCEVType() == scAddRecExpr;
420     }
421   };
422 
423   /// This node is the base class min/max selections.
424   class SCEVMinMaxExpr : public SCEVCommutativeExpr {
425     friend class ScalarEvolution;
426 
isMinMaxType(enum SCEVTypes T)427     static bool isMinMaxType(enum SCEVTypes T) {
428       return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
429              T == scUMinExpr;
430     }
431 
432   protected:
433     /// Note: Constructing subclasses via this constructor is allowed
SCEVMinMaxExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)434     SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
435                    const SCEV *const *O, size_t N)
436         : SCEVCommutativeExpr(ID, T, O, N) {
437       assert(isMinMaxType(T));
438       // Min and max never overflow
439       setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
440     }
441 
442   public:
classof(const SCEV * S)443     static bool classof(const SCEV *S) {
444       return isMinMaxType(S->getSCEVType());
445     }
446 
negate(enum SCEVTypes T)447     static enum SCEVTypes negate(enum SCEVTypes T) {
448       switch (T) {
449       case scSMaxExpr:
450         return scSMinExpr;
451       case scSMinExpr:
452         return scSMaxExpr;
453       case scUMaxExpr:
454         return scUMinExpr;
455       case scUMinExpr:
456         return scUMaxExpr;
457       default:
458         llvm_unreachable("Not a min or max SCEV type!");
459       }
460     }
461   };
462 
463   /// This class represents a signed maximum selection.
464   class SCEVSMaxExpr : public SCEVMinMaxExpr {
465     friend class ScalarEvolution;
466 
SCEVSMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)467     SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
468         : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
469 
470   public:
471     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)472     static bool classof(const SCEV *S) {
473       return S->getSCEVType() == scSMaxExpr;
474     }
475   };
476 
477   /// This class represents an unsigned maximum selection.
478   class SCEVUMaxExpr : public SCEVMinMaxExpr {
479     friend class ScalarEvolution;
480 
SCEVUMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)481     SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
482         : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
483 
484   public:
485     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)486     static bool classof(const SCEV *S) {
487       return S->getSCEVType() == scUMaxExpr;
488     }
489   };
490 
491   /// This class represents a signed minimum selection.
492   class SCEVSMinExpr : public SCEVMinMaxExpr {
493     friend class ScalarEvolution;
494 
SCEVSMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)495     SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
496         : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
497 
498   public:
499     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)500     static bool classof(const SCEV *S) {
501       return S->getSCEVType() == scSMinExpr;
502     }
503   };
504 
505   /// This class represents an unsigned minimum selection.
506   class SCEVUMinExpr : public SCEVMinMaxExpr {
507     friend class ScalarEvolution;
508 
SCEVUMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)509     SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
510         : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
511 
512   public:
513     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)514     static bool classof(const SCEV *S) {
515       return S->getSCEVType() == scUMinExpr;
516     }
517   };
518 
519   /// This means that we are dealing with an entirely unknown SCEV
520   /// value, and only represent it as its LLVM Value.  This is the
521   /// "bottom" value for the analysis.
522   class SCEVUnknown final : public SCEV, private CallbackVH {
523     friend class ScalarEvolution;
524 
525     /// The parent ScalarEvolution value. This is used to update the
526     /// parent's maps when the value associated with a SCEVUnknown is
527     /// deleted or RAUW'd.
528     ScalarEvolution *SE;
529 
530     /// The next pointer in the linked list of all SCEVUnknown
531     /// instances owned by a ScalarEvolution.
532     SCEVUnknown *Next;
533 
SCEVUnknown(const FoldingSetNodeIDRef ID,Value * V,ScalarEvolution * se,SCEVUnknown * next)534     SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V,
535                 ScalarEvolution *se, SCEVUnknown *next) :
536       SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
537 
538     // Implement CallbackVH.
539     void deleted() override;
540     void allUsesReplacedWith(Value *New) override;
541 
542   public:
getValue()543     Value *getValue() const { return getValPtr(); }
544 
545     /// @{
546     /// Test whether this is a special constant representing a type
547     /// size, alignment, or field offset in a target-independent
548     /// manner, and hasn't happened to have been folded with other
549     /// operations into something unrecognizable. This is mainly only
550     /// useful for pretty-printing and other situations where it isn't
551     /// absolutely required for these to succeed.
552     bool isSizeOf(Type *&AllocTy) const;
553     bool isAlignOf(Type *&AllocTy) const;
554     bool isOffsetOf(Type *&STy, Constant *&FieldNo) const;
555     /// @}
556 
getType()557     Type *getType() const { return getValPtr()->getType(); }
558 
559     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)560     static bool classof(const SCEV *S) {
561       return S->getSCEVType() == scUnknown;
562     }
563   };
564 
565   /// This class defines a simple visitor class that may be used for
566   /// various SCEV analysis purposes.
567   template<typename SC, typename RetVal=void>
568   struct SCEVVisitor {
visitSCEVVisitor569     RetVal visit(const SCEV *S) {
570       switch (S->getSCEVType()) {
571       case scConstant:
572         return ((SC*)this)->visitConstant((const SCEVConstant*)S);
573       case scPtrToInt:
574         return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
575       case scTruncate:
576         return ((SC*)this)->visitTruncateExpr((const SCEVTruncateExpr*)S);
577       case scZeroExtend:
578         return ((SC*)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr*)S);
579       case scSignExtend:
580         return ((SC*)this)->visitSignExtendExpr((const SCEVSignExtendExpr*)S);
581       case scAddExpr:
582         return ((SC*)this)->visitAddExpr((const SCEVAddExpr*)S);
583       case scMulExpr:
584         return ((SC*)this)->visitMulExpr((const SCEVMulExpr*)S);
585       case scUDivExpr:
586         return ((SC*)this)->visitUDivExpr((const SCEVUDivExpr*)S);
587       case scAddRecExpr:
588         return ((SC*)this)->visitAddRecExpr((const SCEVAddRecExpr*)S);
589       case scSMaxExpr:
590         return ((SC*)this)->visitSMaxExpr((const SCEVSMaxExpr*)S);
591       case scUMaxExpr:
592         return ((SC*)this)->visitUMaxExpr((const SCEVUMaxExpr*)S);
593       case scSMinExpr:
594         return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
595       case scUMinExpr:
596         return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
597       case scUnknown:
598         return ((SC*)this)->visitUnknown((const SCEVUnknown*)S);
599       case scCouldNotCompute:
600         return ((SC*)this)->visitCouldNotCompute((const SCEVCouldNotCompute*)S);
601       }
602       llvm_unreachable("Unknown SCEV kind!");
603     }
604 
visitCouldNotComputeSCEVVisitor605     RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
606       llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
607     }
608   };
609 
610   /// Visit all nodes in the expression tree using worklist traversal.
611   ///
612   /// Visitor implements:
613   ///   // return true to follow this node.
614   ///   bool follow(const SCEV *S);
615   ///   // return true to terminate the search.
616   ///   bool isDone();
617   template<typename SV>
618   class SCEVTraversal {
619     SV &Visitor;
620     SmallVector<const SCEV *, 8> Worklist;
621     SmallPtrSet<const SCEV *, 8> Visited;
622 
push(const SCEV * S)623     void push(const SCEV *S) {
624       if (Visited.insert(S).second && Visitor.follow(S))
625         Worklist.push_back(S);
626     }
627 
628   public:
SCEVTraversal(SV & V)629     SCEVTraversal(SV& V): Visitor(V) {}
630 
visitAll(const SCEV * Root)631     void visitAll(const SCEV *Root) {
632       push(Root);
633       while (!Worklist.empty() && !Visitor.isDone()) {
634         const SCEV *S = Worklist.pop_back_val();
635 
636         switch (S->getSCEVType()) {
637         case scConstant:
638         case scUnknown:
639           continue;
640         case scPtrToInt:
641         case scTruncate:
642         case scZeroExtend:
643         case scSignExtend:
644           push(cast<SCEVCastExpr>(S)->getOperand());
645           continue;
646         case scAddExpr:
647         case scMulExpr:
648         case scSMaxExpr:
649         case scUMaxExpr:
650         case scSMinExpr:
651         case scUMinExpr:
652         case scAddRecExpr:
653           for (const auto *Op : cast<SCEVNAryExpr>(S)->operands())
654             push(Op);
655           continue;
656         case scUDivExpr: {
657           const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
658           push(UDiv->getLHS());
659           push(UDiv->getRHS());
660           continue;
661         }
662         case scCouldNotCompute:
663           llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
664         }
665         llvm_unreachable("Unknown SCEV kind!");
666       }
667     }
668   };
669 
670   /// Use SCEVTraversal to visit all nodes in the given expression tree.
671   template<typename SV>
visitAll(const SCEV * Root,SV & Visitor)672   void visitAll(const SCEV *Root, SV& Visitor) {
673     SCEVTraversal<SV> T(Visitor);
674     T.visitAll(Root);
675   }
676 
677   /// Return true if any node in \p Root satisfies the predicate \p Pred.
678   template <typename PredTy>
SCEVExprContains(const SCEV * Root,PredTy Pred)679   bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
680     struct FindClosure {
681       bool Found = false;
682       PredTy Pred;
683 
684       FindClosure(PredTy Pred) : Pred(Pred) {}
685 
686       bool follow(const SCEV *S) {
687         if (!Pred(S))
688           return true;
689 
690         Found = true;
691         return false;
692       }
693 
694       bool isDone() const { return Found; }
695     };
696 
697     FindClosure FC(Pred);
698     visitAll(Root, FC);
699     return FC.Found;
700   }
701 
702   /// This visitor recursively visits a SCEV expression and re-writes it.
703   /// The result from each visit is cached, so it will return the same
704   /// SCEV for the same input.
705   template<typename SC>
706   class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
707   protected:
708     ScalarEvolution &SE;
709     // Memoize the result of each visit so that we only compute once for
710     // the same input SCEV. This is to avoid redundant computations when
711     // a SCEV is referenced by multiple SCEVs. Without memoization, this
712     // visit algorithm would have exponential time complexity in the worst
713     // case, causing the compiler to hang on certain tests.
714     DenseMap<const SCEV *, const SCEV *> RewriteResults;
715 
716   public:
SCEVRewriteVisitor(ScalarEvolution & SE)717     SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
718 
visit(const SCEV * S)719     const SCEV *visit(const SCEV *S) {
720       auto It = RewriteResults.find(S);
721       if (It != RewriteResults.end())
722         return It->second;
723       auto* Visited = SCEVVisitor<SC, const SCEV *>::visit(S);
724       auto Result = RewriteResults.try_emplace(S, Visited);
725       assert(Result.second && "Should insert a new entry");
726       return Result.first->second;
727     }
728 
visitConstant(const SCEVConstant * Constant)729     const SCEV *visitConstant(const SCEVConstant *Constant) {
730       return Constant;
731     }
732 
visitPtrToIntExpr(const SCEVPtrToIntExpr * Expr)733     const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
734       const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
735       return Operand == Expr->getOperand()
736                  ? Expr
737                  : SE.getPtrToIntExpr(Operand, Expr->getType());
738     }
739 
visitTruncateExpr(const SCEVTruncateExpr * Expr)740     const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
741       const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
742       return Operand == Expr->getOperand()
743                  ? Expr
744                  : SE.getTruncateExpr(Operand, Expr->getType());
745     }
746 
visitZeroExtendExpr(const SCEVZeroExtendExpr * Expr)747     const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
748       const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
749       return Operand == Expr->getOperand()
750                  ? Expr
751                  : SE.getZeroExtendExpr(Operand, Expr->getType());
752     }
753 
visitSignExtendExpr(const SCEVSignExtendExpr * Expr)754     const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
755       const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
756       return Operand == Expr->getOperand()
757                  ? Expr
758                  : SE.getSignExtendExpr(Operand, Expr->getType());
759     }
760 
visitAddExpr(const SCEVAddExpr * Expr)761     const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
762       SmallVector<const SCEV *, 2> Operands;
763       bool Changed = false;
764       for (auto *Op : Expr->operands()) {
765         Operands.push_back(((SC*)this)->visit(Op));
766         Changed |= Op != Operands.back();
767       }
768       return !Changed ? Expr : SE.getAddExpr(Operands);
769     }
770 
visitMulExpr(const SCEVMulExpr * Expr)771     const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
772       SmallVector<const SCEV *, 2> Operands;
773       bool Changed = false;
774       for (auto *Op : Expr->operands()) {
775         Operands.push_back(((SC*)this)->visit(Op));
776         Changed |= Op != Operands.back();
777       }
778       return !Changed ? Expr : SE.getMulExpr(Operands);
779     }
780 
visitUDivExpr(const SCEVUDivExpr * Expr)781     const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
782       auto *LHS = ((SC *)this)->visit(Expr->getLHS());
783       auto *RHS = ((SC *)this)->visit(Expr->getRHS());
784       bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
785       return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
786     }
787 
visitAddRecExpr(const SCEVAddRecExpr * Expr)788     const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
789       SmallVector<const SCEV *, 2> Operands;
790       bool Changed = false;
791       for (auto *Op : Expr->operands()) {
792         Operands.push_back(((SC*)this)->visit(Op));
793         Changed |= Op != Operands.back();
794       }
795       return !Changed ? Expr
796                       : SE.getAddRecExpr(Operands, Expr->getLoop(),
797                                          Expr->getNoWrapFlags());
798     }
799 
visitSMaxExpr(const SCEVSMaxExpr * Expr)800     const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
801       SmallVector<const SCEV *, 2> Operands;
802       bool Changed = false;
803       for (auto *Op : Expr->operands()) {
804         Operands.push_back(((SC *)this)->visit(Op));
805         Changed |= Op != Operands.back();
806       }
807       return !Changed ? Expr : SE.getSMaxExpr(Operands);
808     }
809 
visitUMaxExpr(const SCEVUMaxExpr * Expr)810     const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
811       SmallVector<const SCEV *, 2> Operands;
812       bool Changed = false;
813       for (auto *Op : Expr->operands()) {
814         Operands.push_back(((SC*)this)->visit(Op));
815         Changed |= Op != Operands.back();
816       }
817       return !Changed ? Expr : SE.getUMaxExpr(Operands);
818     }
819 
visitSMinExpr(const SCEVSMinExpr * Expr)820     const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
821       SmallVector<const SCEV *, 2> Operands;
822       bool Changed = false;
823       for (auto *Op : Expr->operands()) {
824         Operands.push_back(((SC *)this)->visit(Op));
825         Changed |= Op != Operands.back();
826       }
827       return !Changed ? Expr : SE.getSMinExpr(Operands);
828     }
829 
visitUMinExpr(const SCEVUMinExpr * Expr)830     const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
831       SmallVector<const SCEV *, 2> Operands;
832       bool Changed = false;
833       for (auto *Op : Expr->operands()) {
834         Operands.push_back(((SC *)this)->visit(Op));
835         Changed |= Op != Operands.back();
836       }
837       return !Changed ? Expr : SE.getUMinExpr(Operands);
838     }
839 
visitUnknown(const SCEVUnknown * Expr)840     const SCEV *visitUnknown(const SCEVUnknown *Expr) {
841       return Expr;
842     }
843 
visitCouldNotCompute(const SCEVCouldNotCompute * Expr)844     const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
845       return Expr;
846     }
847   };
848 
849   using ValueToValueMap = DenseMap<const Value *, Value *>;
850   using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
851 
852   /// The SCEVParameterRewriter takes a scalar evolution expression and updates
853   /// the SCEVUnknown components following the Map (Value -> SCEV).
854   class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
855   public:
rewrite(const SCEV * Scev,ScalarEvolution & SE,ValueToSCEVMapTy & Map)856     static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
857                                ValueToSCEVMapTy &Map) {
858       SCEVParameterRewriter Rewriter(SE, Map);
859       return Rewriter.visit(Scev);
860     }
861 
SCEVParameterRewriter(ScalarEvolution & SE,ValueToSCEVMapTy & M)862     SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
863         : SCEVRewriteVisitor(SE), Map(M) {}
864 
visitUnknown(const SCEVUnknown * Expr)865     const SCEV *visitUnknown(const SCEVUnknown *Expr) {
866       auto I = Map.find(Expr->getValue());
867       if (I == Map.end())
868         return Expr;
869       return I->second;
870     }
871 
872   private:
873     ValueToSCEVMapTy &Map;
874   };
875 
876   using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
877 
878   /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
879   /// the Map (Loop -> SCEV) to all AddRecExprs.
880   class SCEVLoopAddRecRewriter
881       : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
882   public:
SCEVLoopAddRecRewriter(ScalarEvolution & SE,LoopToScevMapT & M)883     SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
884         : SCEVRewriteVisitor(SE), Map(M) {}
885 
rewrite(const SCEV * Scev,LoopToScevMapT & Map,ScalarEvolution & SE)886     static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
887                                ScalarEvolution &SE) {
888       SCEVLoopAddRecRewriter Rewriter(SE, Map);
889       return Rewriter.visit(Scev);
890     }
891 
visitAddRecExpr(const SCEVAddRecExpr * Expr)892     const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
893       SmallVector<const SCEV *, 2> Operands;
894       for (const SCEV *Op : Expr->operands())
895         Operands.push_back(visit(Op));
896 
897       const Loop *L = Expr->getLoop();
898       const SCEV *Res = SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
899 
900       if (0 == Map.count(L))
901         return Res;
902 
903       const SCEVAddRecExpr *Rec = cast<SCEVAddRecExpr>(Res);
904       return Rec->evaluateAtIteration(Map[L], SE);
905     }
906 
907   private:
908     LoopToScevMapT &Map;
909   };
910 
911 } // end namespace llvm
912 
913 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
914