1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the classes used to represent and build scalar expressions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
15 
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/FoldingSet.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Analysis/ScalarEvolution.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/Value.h"
24 #include "llvm/IR/ValueHandle.h"
25 #include "llvm/Support/Casting.h"
26 #include "llvm/Support/ErrorHandling.h"
27 #include <cassert>
28 #include <cstddef>
29 
30 namespace llvm {
31 
32 class APInt;
33 class Constant;
34 class ConstantRange;
35 class Loop;
36 class Type;
37 
38   enum SCEVTypes {
39     // These should be ordered in terms of increasing complexity to make the
40     // folders simpler.
41     scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr,
42     scUDivExpr, scAddRecExpr, scUMaxExpr, scSMaxExpr, scUMinExpr, scSMinExpr,
43     scUnknown, scCouldNotCompute
44   };
45 
46   /// This class represents a constant integer value.
47   class SCEVConstant : public SCEV {
48     friend class ScalarEvolution;
49 
50     ConstantInt *V;
51 
SCEVConstant(const FoldingSetNodeIDRef ID,ConstantInt * v)52     SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) :
53       SCEV(ID, scConstant, 1), V(v) {}
54 
55   public:
getValue()56     ConstantInt *getValue() const { return V; }
getAPInt()57     const APInt &getAPInt() const { return getValue()->getValue(); }
58 
getType()59     Type *getType() const { return V->getType(); }
60 
61     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)62     static bool classof(const SCEV *S) {
63       return S->getSCEVType() == scConstant;
64     }
65   };
66 
computeExpressionSize(ArrayRef<const SCEV * > Args)67   inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
68     APInt Size(16, 1);
69     for (auto *Arg : Args)
70       Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
71     return (unsigned short)Size.getZExtValue();
72   }
73 
74   /// This is the base class for unary cast operator classes.
75   class SCEVCastExpr : public SCEV {
76   protected:
77     const SCEV *Op;
78     Type *Ty;
79 
80     SCEVCastExpr(const FoldingSetNodeIDRef ID,
81                  unsigned SCEVTy, const SCEV *op, Type *ty);
82 
83   public:
getOperand()84     const SCEV *getOperand() const { return Op; }
getType()85     Type *getType() const { return Ty; }
86 
87     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)88     static bool classof(const SCEV *S) {
89       return S->getSCEVType() == scTruncate ||
90              S->getSCEVType() == scZeroExtend ||
91              S->getSCEVType() == scSignExtend;
92     }
93   };
94 
95   /// This class represents a truncation of an integer value to a
96   /// smaller integer value.
97   class SCEVTruncateExpr : public SCEVCastExpr {
98     friend class ScalarEvolution;
99 
100     SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
101                      const SCEV *op, Type *ty);
102 
103   public:
104     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)105     static bool classof(const SCEV *S) {
106       return S->getSCEVType() == scTruncate;
107     }
108   };
109 
110   /// This class represents a zero extension of a small integer value
111   /// to a larger integer value.
112   class SCEVZeroExtendExpr : public SCEVCastExpr {
113     friend class ScalarEvolution;
114 
115     SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
116                        const SCEV *op, Type *ty);
117 
118   public:
119     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)120     static bool classof(const SCEV *S) {
121       return S->getSCEVType() == scZeroExtend;
122     }
123   };
124 
125   /// This class represents a sign extension of a small integer value
126   /// to a larger integer value.
127   class SCEVSignExtendExpr : public SCEVCastExpr {
128     friend class ScalarEvolution;
129 
130     SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
131                        const SCEV *op, Type *ty);
132 
133   public:
134     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)135     static bool classof(const SCEV *S) {
136       return S->getSCEVType() == scSignExtend;
137     }
138   };
139 
140   /// This node is a base class providing common functionality for
141   /// n'ary operators.
142   class SCEVNAryExpr : public SCEV {
143   protected:
144     // Since SCEVs are immutable, ScalarEvolution allocates operand
145     // arrays with its SCEVAllocator, so this class just needs a simple
146     // pointer rather than a more elaborate vector-like data structure.
147     // This also avoids the need for a non-trivial destructor.
148     const SCEV *const *Operands;
149     size_t NumOperands;
150 
SCEVNAryExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)151     SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
152                  const SCEV *const *O, size_t N)
153         : SCEV(ID, T, computeExpressionSize(makeArrayRef(O, N))), Operands(O),
154           NumOperands(N) {}
155 
156   public:
getNumOperands()157     size_t getNumOperands() const { return NumOperands; }
158 
getOperand(unsigned i)159     const SCEV *getOperand(unsigned i) const {
160       assert(i < NumOperands && "Operand index out of range!");
161       return Operands[i];
162     }
163 
164     using op_iterator = const SCEV *const *;
165     using op_range = iterator_range<op_iterator>;
166 
op_begin()167     op_iterator op_begin() const { return Operands; }
op_end()168     op_iterator op_end() const { return Operands + NumOperands; }
operands()169     op_range operands() const {
170       return make_range(op_begin(), op_end());
171     }
172 
getType()173     Type *getType() const { return getOperand(0)->getType(); }
174 
175     NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
176       return (NoWrapFlags)(SubclassData & Mask);
177     }
178 
hasNoUnsignedWrap()179     bool hasNoUnsignedWrap() const {
180       return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
181     }
182 
hasNoSignedWrap()183     bool hasNoSignedWrap() const {
184       return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
185     }
186 
hasNoSelfWrap()187     bool hasNoSelfWrap() const {
188       return getNoWrapFlags(FlagNW) != FlagAnyWrap;
189     }
190 
191     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)192     static bool classof(const SCEV *S) {
193       return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
194              S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
195              S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
196              S->getSCEVType() == scAddRecExpr;
197     }
198   };
199 
200   /// This node is the base class for n'ary commutative operators.
201   class SCEVCommutativeExpr : public SCEVNAryExpr {
202   protected:
SCEVCommutativeExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)203     SCEVCommutativeExpr(const FoldingSetNodeIDRef ID,
204                         enum SCEVTypes T, const SCEV *const *O, size_t N)
205       : SCEVNAryExpr(ID, T, O, N) {}
206 
207   public:
208     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)209     static bool classof(const SCEV *S) {
210       return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
211              S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
212              S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
213     }
214 
215     /// Set flags for a non-recurrence without clearing previously set flags.
setNoWrapFlags(NoWrapFlags Flags)216     void setNoWrapFlags(NoWrapFlags Flags) {
217       SubclassData |= Flags;
218     }
219   };
220 
221   /// This node represents an addition of some number of SCEVs.
222   class SCEVAddExpr : public SCEVCommutativeExpr {
223     friend class ScalarEvolution;
224 
225     Type *Ty;
226 
SCEVAddExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)227     SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
228         : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
229       auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
230         return Op->getType()->isPointerTy();
231       });
232       if (FirstPointerTypedOp != operands().end())
233         Ty = (*FirstPointerTypedOp)->getType();
234       else
235         Ty = getOperand(0)->getType();
236     }
237 
238   public:
getType()239     Type *getType() const { return Ty; }
240 
241     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)242     static bool classof(const SCEV *S) {
243       return S->getSCEVType() == scAddExpr;
244     }
245   };
246 
247   /// This node represents multiplication of some number of SCEVs.
248   class SCEVMulExpr : public SCEVCommutativeExpr {
249     friend class ScalarEvolution;
250 
SCEVMulExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)251     SCEVMulExpr(const FoldingSetNodeIDRef ID,
252                 const SCEV *const *O, size_t N)
253       : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
254 
255   public:
256     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)257     static bool classof(const SCEV *S) {
258       return S->getSCEVType() == scMulExpr;
259     }
260   };
261 
262   /// This class represents a binary unsigned division operation.
263   class SCEVUDivExpr : public SCEV {
264     friend class ScalarEvolution;
265 
266     const SCEV *LHS;
267     const SCEV *RHS;
268 
SCEVUDivExpr(const FoldingSetNodeIDRef ID,const SCEV * lhs,const SCEV * rhs)269     SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
270         : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})), LHS(lhs),
271           RHS(rhs) {}
272 
273   public:
getLHS()274     const SCEV *getLHS() const { return LHS; }
getRHS()275     const SCEV *getRHS() const { return RHS; }
276 
getType()277     Type *getType() const {
278       // In most cases the types of LHS and RHS will be the same, but in some
279       // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
280       // depend on the type for correctness, but handling types carefully can
281       // avoid extra casts in the SCEVExpander. The LHS is more likely to be
282       // a pointer type than the RHS, so use the RHS' type here.
283       return getRHS()->getType();
284     }
285 
286     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)287     static bool classof(const SCEV *S) {
288       return S->getSCEVType() == scUDivExpr;
289     }
290   };
291 
292   /// This node represents a polynomial recurrence on the trip count
293   /// of the specified loop.  This is the primary focus of the
294   /// ScalarEvolution framework; all the other SCEV subclasses are
295   /// mostly just supporting infrastructure to allow SCEVAddRecExpr
296   /// expressions to be created and analyzed.
297   ///
298   /// All operands of an AddRec are required to be loop invariant.
299   ///
300   class SCEVAddRecExpr : public SCEVNAryExpr {
301     friend class ScalarEvolution;
302 
303     const Loop *L;
304 
SCEVAddRecExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N,const Loop * l)305     SCEVAddRecExpr(const FoldingSetNodeIDRef ID,
306                    const SCEV *const *O, size_t N, const Loop *l)
307       : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
308 
309   public:
getStart()310     const SCEV *getStart() const { return Operands[0]; }
getLoop()311     const Loop *getLoop() const { return L; }
312 
313     /// Constructs and returns the recurrence indicating how much this
314     /// expression steps by.  If this is a polynomial of degree N, it
315     /// returns a chrec of degree N-1.  We cannot determine whether
316     /// the step recurrence has self-wraparound.
getStepRecurrence(ScalarEvolution & SE)317     const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
318       if (isAffine()) return getOperand(1);
319       return SE.getAddRecExpr(SmallVector<const SCEV *, 3>(op_begin()+1,
320                                                            op_end()),
321                               getLoop(), FlagAnyWrap);
322     }
323 
324     /// Return true if this represents an expression A + B*x where A
325     /// and B are loop invariant values.
isAffine()326     bool isAffine() const {
327       // We know that the start value is invariant.  This expression is thus
328       // affine iff the step is also invariant.
329       return getNumOperands() == 2;
330     }
331 
332     /// Return true if this represents an expression A + B*x + C*x^2
333     /// where A, B and C are loop invariant values.  This corresponds
334     /// to an addrec of the form {L,+,M,+,N}
isQuadratic()335     bool isQuadratic() const {
336       return getNumOperands() == 3;
337     }
338 
339     /// Set flags for a recurrence without clearing any previously set flags.
340     /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
341     /// to make it easier to propagate flags.
setNoWrapFlags(NoWrapFlags Flags)342     void setNoWrapFlags(NoWrapFlags Flags) {
343       if (Flags & (FlagNUW | FlagNSW))
344         Flags = ScalarEvolution::setFlags(Flags, FlagNW);
345       SubclassData |= Flags;
346     }
347 
348     /// Return the value of this chain of recurrences at the specified
349     /// iteration number.
350     const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
351 
352     /// Return the number of iterations of this loop that produce
353     /// values in the specified constant range.  Another way of
354     /// looking at this is that it returns the first iteration number
355     /// where the value is not in the condition, thus computing the
356     /// exit count.  If the iteration count can't be computed, an
357     /// instance of SCEVCouldNotCompute is returned.
358     const SCEV *getNumIterationsInRange(const ConstantRange &Range,
359                                         ScalarEvolution &SE) const;
360 
361     /// Return an expression representing the value of this expression
362     /// one iteration of the loop ahead.
363     const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
364 
365     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)366     static bool classof(const SCEV *S) {
367       return S->getSCEVType() == scAddRecExpr;
368     }
369   };
370 
371   /// This node is the base class min/max selections.
372   class SCEVMinMaxExpr : public SCEVCommutativeExpr {
373     friend class ScalarEvolution;
374 
isMinMaxType(enum SCEVTypes T)375     static bool isMinMaxType(enum SCEVTypes T) {
376       return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
377              T == scUMinExpr;
378     }
379 
380   protected:
381     /// Note: Constructing subclasses via this constructor is allowed
SCEVMinMaxExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)382     SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
383                    const SCEV *const *O, size_t N)
384         : SCEVCommutativeExpr(ID, T, O, N) {
385       assert(isMinMaxType(T));
386       // Min and max never overflow
387       setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
388     }
389 
390   public:
classof(const SCEV * S)391     static bool classof(const SCEV *S) {
392       return isMinMaxType(static_cast<SCEVTypes>(S->getSCEVType()));
393     }
394 
negate(enum SCEVTypes T)395     static enum SCEVTypes negate(enum SCEVTypes T) {
396       switch (T) {
397       case scSMaxExpr:
398         return scSMinExpr;
399       case scSMinExpr:
400         return scSMaxExpr;
401       case scUMaxExpr:
402         return scUMinExpr;
403       case scUMinExpr:
404         return scUMaxExpr;
405       default:
406         llvm_unreachable("Not a min or max SCEV type!");
407       }
408     }
409   };
410 
411   /// This class represents a signed maximum selection.
412   class SCEVSMaxExpr : public SCEVMinMaxExpr {
413     friend class ScalarEvolution;
414 
SCEVSMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)415     SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
416         : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
417 
418   public:
419     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)420     static bool classof(const SCEV *S) {
421       return S->getSCEVType() == scSMaxExpr;
422     }
423   };
424 
425   /// This class represents an unsigned maximum selection.
426   class SCEVUMaxExpr : public SCEVMinMaxExpr {
427     friend class ScalarEvolution;
428 
SCEVUMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)429     SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
430         : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
431 
432   public:
433     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)434     static bool classof(const SCEV *S) {
435       return S->getSCEVType() == scUMaxExpr;
436     }
437   };
438 
439   /// This class represents a signed minimum selection.
440   class SCEVSMinExpr : public SCEVMinMaxExpr {
441     friend class ScalarEvolution;
442 
SCEVSMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)443     SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
444         : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
445 
446   public:
447     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)448     static bool classof(const SCEV *S) {
449       return S->getSCEVType() == scSMinExpr;
450     }
451   };
452 
453   /// This class represents an unsigned minimum selection.
454   class SCEVUMinExpr : public SCEVMinMaxExpr {
455     friend class ScalarEvolution;
456 
SCEVUMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)457     SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
458         : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
459 
460   public:
461     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)462     static bool classof(const SCEV *S) {
463       return S->getSCEVType() == scUMinExpr;
464     }
465   };
466 
467   /// This means that we are dealing with an entirely unknown SCEV
468   /// value, and only represent it as its LLVM Value.  This is the
469   /// "bottom" value for the analysis.
470   class SCEVUnknown final : public SCEV, private CallbackVH {
471     friend class ScalarEvolution;
472 
473     /// The parent ScalarEvolution value. This is used to update the
474     /// parent's maps when the value associated with a SCEVUnknown is
475     /// deleted or RAUW'd.
476     ScalarEvolution *SE;
477 
478     /// The next pointer in the linked list of all SCEVUnknown
479     /// instances owned by a ScalarEvolution.
480     SCEVUnknown *Next;
481 
SCEVUnknown(const FoldingSetNodeIDRef ID,Value * V,ScalarEvolution * se,SCEVUnknown * next)482     SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V,
483                 ScalarEvolution *se, SCEVUnknown *next) :
484       SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
485 
486     // Implement CallbackVH.
487     void deleted() override;
488     void allUsesReplacedWith(Value *New) override;
489 
490   public:
getValue()491     Value *getValue() const { return getValPtr(); }
492 
493     /// @{
494     /// Test whether this is a special constant representing a type
495     /// size, alignment, or field offset in a target-independent
496     /// manner, and hasn't happened to have been folded with other
497     /// operations into something unrecognizable. This is mainly only
498     /// useful for pretty-printing and other situations where it isn't
499     /// absolutely required for these to succeed.
500     bool isSizeOf(Type *&AllocTy) const;
501     bool isAlignOf(Type *&AllocTy) const;
502     bool isOffsetOf(Type *&STy, Constant *&FieldNo) const;
503     /// @}
504 
getType()505     Type *getType() const { return getValPtr()->getType(); }
506 
507     /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)508     static bool classof(const SCEV *S) {
509       return S->getSCEVType() == scUnknown;
510     }
511   };
512 
513   /// This class defines a simple visitor class that may be used for
514   /// various SCEV analysis purposes.
515   template<typename SC, typename RetVal=void>
516   struct SCEVVisitor {
visitSCEVVisitor517     RetVal visit(const SCEV *S) {
518       switch (S->getSCEVType()) {
519       case scConstant:
520         return ((SC*)this)->visitConstant((const SCEVConstant*)S);
521       case scTruncate:
522         return ((SC*)this)->visitTruncateExpr((const SCEVTruncateExpr*)S);
523       case scZeroExtend:
524         return ((SC*)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr*)S);
525       case scSignExtend:
526         return ((SC*)this)->visitSignExtendExpr((const SCEVSignExtendExpr*)S);
527       case scAddExpr:
528         return ((SC*)this)->visitAddExpr((const SCEVAddExpr*)S);
529       case scMulExpr:
530         return ((SC*)this)->visitMulExpr((const SCEVMulExpr*)S);
531       case scUDivExpr:
532         return ((SC*)this)->visitUDivExpr((const SCEVUDivExpr*)S);
533       case scAddRecExpr:
534         return ((SC*)this)->visitAddRecExpr((const SCEVAddRecExpr*)S);
535       case scSMaxExpr:
536         return ((SC*)this)->visitSMaxExpr((const SCEVSMaxExpr*)S);
537       case scUMaxExpr:
538         return ((SC*)this)->visitUMaxExpr((const SCEVUMaxExpr*)S);
539       case scSMinExpr:
540         return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
541       case scUMinExpr:
542         return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
543       case scUnknown:
544         return ((SC*)this)->visitUnknown((const SCEVUnknown*)S);
545       case scCouldNotCompute:
546         return ((SC*)this)->visitCouldNotCompute((const SCEVCouldNotCompute*)S);
547       default:
548         llvm_unreachable("Unknown SCEV type!");
549       }
550     }
551 
visitCouldNotComputeSCEVVisitor552     RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
553       llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
554     }
555   };
556 
557   /// Visit all nodes in the expression tree using worklist traversal.
558   ///
559   /// Visitor implements:
560   ///   // return true to follow this node.
561   ///   bool follow(const SCEV *S);
562   ///   // return true to terminate the search.
563   ///   bool isDone();
564   template<typename SV>
565   class SCEVTraversal {
566     SV &Visitor;
567     SmallVector<const SCEV *, 8> Worklist;
568     SmallPtrSet<const SCEV *, 8> Visited;
569 
push(const SCEV * S)570     void push(const SCEV *S) {
571       if (Visited.insert(S).second && Visitor.follow(S))
572         Worklist.push_back(S);
573     }
574 
575   public:
SCEVTraversal(SV & V)576     SCEVTraversal(SV& V): Visitor(V) {}
577 
visitAll(const SCEV * Root)578     void visitAll(const SCEV *Root) {
579       push(Root);
580       while (!Worklist.empty() && !Visitor.isDone()) {
581         const SCEV *S = Worklist.pop_back_val();
582 
583         switch (S->getSCEVType()) {
584         case scConstant:
585         case scUnknown:
586           break;
587         case scTruncate:
588         case scZeroExtend:
589         case scSignExtend:
590           push(cast<SCEVCastExpr>(S)->getOperand());
591           break;
592         case scAddExpr:
593         case scMulExpr:
594         case scSMaxExpr:
595         case scUMaxExpr:
596         case scSMinExpr:
597         case scUMinExpr:
598         case scAddRecExpr:
599           for (const auto *Op : cast<SCEVNAryExpr>(S)->operands())
600             push(Op);
601           break;
602         case scUDivExpr: {
603           const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
604           push(UDiv->getLHS());
605           push(UDiv->getRHS());
606           break;
607         }
608         case scCouldNotCompute:
609           llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
610         default:
611           llvm_unreachable("Unknown SCEV kind!");
612         }
613       }
614     }
615   };
616 
617   /// Use SCEVTraversal to visit all nodes in the given expression tree.
618   template<typename SV>
visitAll(const SCEV * Root,SV & Visitor)619   void visitAll(const SCEV *Root, SV& Visitor) {
620     SCEVTraversal<SV> T(Visitor);
621     T.visitAll(Root);
622   }
623 
624   /// Return true if any node in \p Root satisfies the predicate \p Pred.
625   template <typename PredTy>
SCEVExprContains(const SCEV * Root,PredTy Pred)626   bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
627     struct FindClosure {
628       bool Found = false;
629       PredTy Pred;
630 
631       FindClosure(PredTy Pred) : Pred(Pred) {}
632 
633       bool follow(const SCEV *S) {
634         if (!Pred(S))
635           return true;
636 
637         Found = true;
638         return false;
639       }
640 
641       bool isDone() const { return Found; }
642     };
643 
644     FindClosure FC(Pred);
645     visitAll(Root, FC);
646     return FC.Found;
647   }
648 
649   /// This visitor recursively visits a SCEV expression and re-writes it.
650   /// The result from each visit is cached, so it will return the same
651   /// SCEV for the same input.
652   template<typename SC>
653   class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
654   protected:
655     ScalarEvolution &SE;
656     // Memoize the result of each visit so that we only compute once for
657     // the same input SCEV. This is to avoid redundant computations when
658     // a SCEV is referenced by multiple SCEVs. Without memoization, this
659     // visit algorithm would have exponential time complexity in the worst
660     // case, causing the compiler to hang on certain tests.
661     DenseMap<const SCEV *, const SCEV *> RewriteResults;
662 
663   public:
SCEVRewriteVisitor(ScalarEvolution & SE)664     SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
665 
visit(const SCEV * S)666     const SCEV *visit(const SCEV *S) {
667       auto It = RewriteResults.find(S);
668       if (It != RewriteResults.end())
669         return It->second;
670       auto* Visited = SCEVVisitor<SC, const SCEV *>::visit(S);
671       auto Result = RewriteResults.try_emplace(S, Visited);
672       assert(Result.second && "Should insert a new entry");
673       return Result.first->second;
674     }
675 
visitConstant(const SCEVConstant * Constant)676     const SCEV *visitConstant(const SCEVConstant *Constant) {
677       return Constant;
678     }
679 
visitTruncateExpr(const SCEVTruncateExpr * Expr)680     const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
681       const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
682       return Operand == Expr->getOperand()
683                  ? Expr
684                  : SE.getTruncateExpr(Operand, Expr->getType());
685     }
686 
visitZeroExtendExpr(const SCEVZeroExtendExpr * Expr)687     const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
688       const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
689       return Operand == Expr->getOperand()
690                  ? Expr
691                  : SE.getZeroExtendExpr(Operand, Expr->getType());
692     }
693 
visitSignExtendExpr(const SCEVSignExtendExpr * Expr)694     const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
695       const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
696       return Operand == Expr->getOperand()
697                  ? Expr
698                  : SE.getSignExtendExpr(Operand, Expr->getType());
699     }
700 
visitAddExpr(const SCEVAddExpr * Expr)701     const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
702       SmallVector<const SCEV *, 2> Operands;
703       bool Changed = false;
704       for (auto *Op : Expr->operands()) {
705         Operands.push_back(((SC*)this)->visit(Op));
706         Changed |= Op != Operands.back();
707       }
708       return !Changed ? Expr : SE.getAddExpr(Operands);
709     }
710 
visitMulExpr(const SCEVMulExpr * Expr)711     const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
712       SmallVector<const SCEV *, 2> Operands;
713       bool Changed = false;
714       for (auto *Op : Expr->operands()) {
715         Operands.push_back(((SC*)this)->visit(Op));
716         Changed |= Op != Operands.back();
717       }
718       return !Changed ? Expr : SE.getMulExpr(Operands);
719     }
720 
visitUDivExpr(const SCEVUDivExpr * Expr)721     const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
722       auto *LHS = ((SC *)this)->visit(Expr->getLHS());
723       auto *RHS = ((SC *)this)->visit(Expr->getRHS());
724       bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
725       return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
726     }
727 
visitAddRecExpr(const SCEVAddRecExpr * Expr)728     const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
729       SmallVector<const SCEV *, 2> Operands;
730       bool Changed = false;
731       for (auto *Op : Expr->operands()) {
732         Operands.push_back(((SC*)this)->visit(Op));
733         Changed |= Op != Operands.back();
734       }
735       return !Changed ? Expr
736                       : SE.getAddRecExpr(Operands, Expr->getLoop(),
737                                          Expr->getNoWrapFlags());
738     }
739 
visitSMaxExpr(const SCEVSMaxExpr * Expr)740     const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
741       SmallVector<const SCEV *, 2> Operands;
742       bool Changed = false;
743       for (auto *Op : Expr->operands()) {
744         Operands.push_back(((SC *)this)->visit(Op));
745         Changed |= Op != Operands.back();
746       }
747       return !Changed ? Expr : SE.getSMaxExpr(Operands);
748     }
749 
visitUMaxExpr(const SCEVUMaxExpr * Expr)750     const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
751       SmallVector<const SCEV *, 2> Operands;
752       bool Changed = false;
753       for (auto *Op : Expr->operands()) {
754         Operands.push_back(((SC*)this)->visit(Op));
755         Changed |= Op != Operands.back();
756       }
757       return !Changed ? Expr : SE.getUMaxExpr(Operands);
758     }
759 
visitSMinExpr(const SCEVSMinExpr * Expr)760     const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
761       SmallVector<const SCEV *, 2> Operands;
762       bool Changed = false;
763       for (auto *Op : Expr->operands()) {
764         Operands.push_back(((SC *)this)->visit(Op));
765         Changed |= Op != Operands.back();
766       }
767       return !Changed ? Expr : SE.getSMinExpr(Operands);
768     }
769 
visitUMinExpr(const SCEVUMinExpr * Expr)770     const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
771       SmallVector<const SCEV *, 2> Operands;
772       bool Changed = false;
773       for (auto *Op : Expr->operands()) {
774         Operands.push_back(((SC *)this)->visit(Op));
775         Changed |= Op != Operands.back();
776       }
777       return !Changed ? Expr : SE.getUMinExpr(Operands);
778     }
779 
visitUnknown(const SCEVUnknown * Expr)780     const SCEV *visitUnknown(const SCEVUnknown *Expr) {
781       return Expr;
782     }
783 
visitCouldNotCompute(const SCEVCouldNotCompute * Expr)784     const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
785       return Expr;
786     }
787   };
788 
789   using ValueToValueMap = DenseMap<const Value *, Value *>;
790 
791   /// The SCEVParameterRewriter takes a scalar evolution expression and updates
792   /// the SCEVUnknown components following the Map (Value -> Value).
793   class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
794   public:
795     static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
796                                ValueToValueMap &Map,
797                                bool InterpretConsts = false) {
798       SCEVParameterRewriter Rewriter(SE, Map, InterpretConsts);
799       return Rewriter.visit(Scev);
800     }
801 
SCEVParameterRewriter(ScalarEvolution & SE,ValueToValueMap & M,bool C)802     SCEVParameterRewriter(ScalarEvolution &SE, ValueToValueMap &M, bool C)
803       : SCEVRewriteVisitor(SE), Map(M), InterpretConsts(C) {}
804 
visitUnknown(const SCEVUnknown * Expr)805     const SCEV *visitUnknown(const SCEVUnknown *Expr) {
806       Value *V = Expr->getValue();
807       if (Map.count(V)) {
808         Value *NV = Map[V];
809         if (InterpretConsts && isa<ConstantInt>(NV))
810           return SE.getConstant(cast<ConstantInt>(NV));
811         return SE.getUnknown(NV);
812       }
813       return Expr;
814     }
815 
816   private:
817     ValueToValueMap &Map;
818     bool InterpretConsts;
819   };
820 
821   using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
822 
823   /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
824   /// the Map (Loop -> SCEV) to all AddRecExprs.
825   class SCEVLoopAddRecRewriter
826       : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
827   public:
SCEVLoopAddRecRewriter(ScalarEvolution & SE,LoopToScevMapT & M)828     SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
829         : SCEVRewriteVisitor(SE), Map(M) {}
830 
rewrite(const SCEV * Scev,LoopToScevMapT & Map,ScalarEvolution & SE)831     static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
832                                ScalarEvolution &SE) {
833       SCEVLoopAddRecRewriter Rewriter(SE, Map);
834       return Rewriter.visit(Scev);
835     }
836 
visitAddRecExpr(const SCEVAddRecExpr * Expr)837     const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
838       SmallVector<const SCEV *, 2> Operands;
839       for (const SCEV *Op : Expr->operands())
840         Operands.push_back(visit(Op));
841 
842       const Loop *L = Expr->getLoop();
843       const SCEV *Res = SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
844 
845       if (0 == Map.count(L))
846         return Res;
847 
848       const SCEVAddRecExpr *Rec = cast<SCEVAddRecExpr>(Res);
849       return Rec->evaluateAtIteration(Map[L], SE);
850     }
851 
852   private:
853     LoopToScevMapT &Map;
854   };
855 
856 } // end namespace llvm
857 
858 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
859