1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the classes used to represent and build scalar expressions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
15 
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/iterator_range.h"
20 #include "llvm/Analysis/ScalarEvolution.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/ValueHandle.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/ErrorHandling.h"
25 #include <cassert>
26 #include <cstddef>
27 
28 namespace llvm {
29 
30 class APInt;
31 class Constant;
32 class ConstantInt;
33 class ConstantRange;
34 class Loop;
35 class Type;
36 class Value;
37 
/// Discriminator identifying the concrete kind of a SCEV node. This is the
/// value SCEV::getSCEVType() yields, and every classof() below compares
/// against it.
enum SCEVTypes : unsigned short {
  // These should be ordered in terms of increasing complexity to make the
  // folders simpler.
  // The enumerators take implicit, consecutive values, so inserting or
  // reordering an entry renumbers every kind that follows it.
  scConstant,
  scVScale,
  scTruncate,
  scZeroExtend,
  scSignExtend,
  scAddExpr,
  scMulExpr,
  scUDivExpr,
  scAddRecExpr,
  scUMaxExpr,
  scSMaxExpr,
  scUMinExpr,
  scSMinExpr,
  scSequentialUMinExpr,
  scPtrToInt,
  scUnknown,
  scCouldNotCompute
};
59 
60 /// This class represents a constant integer value.
61 class SCEVConstant : public SCEV {
62   friend class ScalarEvolution;
63 
64   ConstantInt *V;
65 
66   SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v)
67       : SCEV(ID, scConstant, 1), V(v) {}
68 
69 public:
70   ConstantInt *getValue() const { return V; }
71   const APInt &getAPInt() const { return getValue()->getValue(); }
72 
73   Type *getType() const { return V->getType(); }
74 
75   /// Methods for support type inquiry through isa, cast, and dyn_cast:
76   static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; }
77 };
78 
79 /// This class represents the value of vscale, as used when defining the length
80 /// of a scalable vector or returned by the llvm.vscale() intrinsic.
81 class SCEVVScale : public SCEV {
82   friend class ScalarEvolution;
83 
84   SCEVVScale(const FoldingSetNodeIDRef ID, Type *ty)
85       : SCEV(ID, scVScale, 0), Ty(ty) {}
86 
87   Type *Ty;
88 
89 public:
90   Type *getType() const { return Ty; }
91 
92   /// Methods for support type inquiry through isa, cast, and dyn_cast:
93   static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; }
94 };
95 
96 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
97   APInt Size(16, 1);
98   for (const auto *Arg : Args)
99     Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
100   return (unsigned short)Size.getZExtValue();
101 }
102 
103 /// This is the base class for unary cast operator classes.
104 class SCEVCastExpr : public SCEV {
105 protected:
106   const SCEV *Op;
107   Type *Ty;
108 
109   SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op,
110                Type *ty);
111 
112 public:
113   const SCEV *getOperand() const { return Op; }
114   const SCEV *getOperand(unsigned i) const {
115     assert(i == 0 && "Operand index out of range!");
116     return Op;
117   }
118   ArrayRef<const SCEV *> operands() const { return Op; }
119   size_t getNumOperands() const { return 1; }
120   Type *getType() const { return Ty; }
121 
122   /// Methods for support type inquiry through isa, cast, and dyn_cast:
123   static bool classof(const SCEV *S) {
124     return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
125            S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend;
126   }
127 };
128 
129 /// This class represents a cast from a pointer to a pointer-sized integer
130 /// value.
131 class SCEVPtrToIntExpr : public SCEVCastExpr {
132   friend class ScalarEvolution;
133 
134   SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
135 
136 public:
137   /// Methods for support type inquiry through isa, cast, and dyn_cast:
138   static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; }
139 };
140 
141 /// This is the base class for unary integral cast operator classes.
142 class SCEVIntegralCastExpr : public SCEVCastExpr {
143 protected:
144   SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
145                        const SCEV *op, Type *ty);
146 
147 public:
148   /// Methods for support type inquiry through isa, cast, and dyn_cast:
149   static bool classof(const SCEV *S) {
150     return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend ||
151            S->getSCEVType() == scSignExtend;
152   }
153 };
154 
155 /// This class represents a truncation of an integer value to a
156 /// smaller integer value.
157 class SCEVTruncateExpr : public SCEVIntegralCastExpr {
158   friend class ScalarEvolution;
159 
160   SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
161 
162 public:
163   /// Methods for support type inquiry through isa, cast, and dyn_cast:
164   static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; }
165 };
166 
167 /// This class represents a zero extension of a small integer value
168 /// to a larger integer value.
169 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
170   friend class ScalarEvolution;
171 
172   SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
173 
174 public:
175   /// Methods for support type inquiry through isa, cast, and dyn_cast:
176   static bool classof(const SCEV *S) {
177     return S->getSCEVType() == scZeroExtend;
178   }
179 };
180 
181 /// This class represents a sign extension of a small integer value
182 /// to a larger integer value.
183 class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
184   friend class ScalarEvolution;
185 
186   SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
187 
188 public:
189   /// Methods for support type inquiry through isa, cast, and dyn_cast:
190   static bool classof(const SCEV *S) {
191     return S->getSCEVType() == scSignExtend;
192   }
193 };
194 
195 /// This node is a base class providing common functionality for
196 /// n'ary operators.
197 class SCEVNAryExpr : public SCEV {
198 protected:
199   // Since SCEVs are immutable, ScalarEvolution allocates operand
200   // arrays with its SCEVAllocator, so this class just needs a simple
201   // pointer rather than a more elaborate vector-like data structure.
202   // This also avoids the need for a non-trivial destructor.
203   const SCEV *const *Operands;
204   size_t NumOperands;
205 
206   SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
207                const SCEV *const *O, size_t N)
208       : SCEV(ID, T, computeExpressionSize(ArrayRef(O, N))), Operands(O),
209         NumOperands(N) {}
210 
211 public:
212   size_t getNumOperands() const { return NumOperands; }
213 
214   const SCEV *getOperand(unsigned i) const {
215     assert(i < NumOperands && "Operand index out of range!");
216     return Operands[i];
217   }
218 
219   ArrayRef<const SCEV *> operands() const {
220     return ArrayRef(Operands, NumOperands);
221   }
222 
223   NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
224     return (NoWrapFlags)(SubclassData & Mask);
225   }
226 
227   bool hasNoUnsignedWrap() const {
228     return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
229   }
230 
231   bool hasNoSignedWrap() const {
232     return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
233   }
234 
235   bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; }
236 
237   /// Methods for support type inquiry through isa, cast, and dyn_cast:
238   static bool classof(const SCEV *S) {
239     return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
240            S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
241            S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
242            S->getSCEVType() == scSequentialUMinExpr ||
243            S->getSCEVType() == scAddRecExpr;
244   }
245 };
246 
247 /// This node is the base class for n'ary commutative operators.
248 class SCEVCommutativeExpr : public SCEVNAryExpr {
249 protected:
250   SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
251                       const SCEV *const *O, size_t N)
252       : SCEVNAryExpr(ID, T, O, N) {}
253 
254 public:
255   /// Methods for support type inquiry through isa, cast, and dyn_cast:
256   static bool classof(const SCEV *S) {
257     return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
258            S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
259            S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
260   }
261 
262   /// Set flags for a non-recurrence without clearing previously set flags.
263   void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
264 };
265 
266 /// This node represents an addition of some number of SCEVs.
267 class SCEVAddExpr : public SCEVCommutativeExpr {
268   friend class ScalarEvolution;
269 
270   Type *Ty;
271 
272   SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
273       : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
274     auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
275       return Op->getType()->isPointerTy();
276     });
277     if (FirstPointerTypedOp != operands().end())
278       Ty = (*FirstPointerTypedOp)->getType();
279     else
280       Ty = getOperand(0)->getType();
281   }
282 
283 public:
284   Type *getType() const { return Ty; }
285 
286   /// Methods for support type inquiry through isa, cast, and dyn_cast:
287   static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; }
288 };
289 
290 /// This node represents multiplication of some number of SCEVs.
291 class SCEVMulExpr : public SCEVCommutativeExpr {
292   friend class ScalarEvolution;
293 
294   SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
295       : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
296 
297 public:
298   Type *getType() const { return getOperand(0)->getType(); }
299 
300   /// Methods for support type inquiry through isa, cast, and dyn_cast:
301   static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; }
302 };
303 
304 /// This class represents a binary unsigned division operation.
305 class SCEVUDivExpr : public SCEV {
306   friend class ScalarEvolution;
307 
308   std::array<const SCEV *, 2> Operands;
309 
310   SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
311       : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) {
312     Operands[0] = lhs;
313     Operands[1] = rhs;
314   }
315 
316 public:
317   const SCEV *getLHS() const { return Operands[0]; }
318   const SCEV *getRHS() const { return Operands[1]; }
319   size_t getNumOperands() const { return 2; }
320   const SCEV *getOperand(unsigned i) const {
321     assert((i == 0 || i == 1) && "Operand index out of range!");
322     return i == 0 ? getLHS() : getRHS();
323   }
324 
325   ArrayRef<const SCEV *> operands() const { return Operands; }
326 
327   Type *getType() const {
328     // In most cases the types of LHS and RHS will be the same, but in some
329     // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
330     // depend on the type for correctness, but handling types carefully can
331     // avoid extra casts in the SCEVExpander. The LHS is more likely to be
332     // a pointer type than the RHS, so use the RHS' type here.
333     return getRHS()->getType();
334   }
335 
336   /// Methods for support type inquiry through isa, cast, and dyn_cast:
337   static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; }
338 };
339 
340 /// This node represents a polynomial recurrence on the trip count
341 /// of the specified loop.  This is the primary focus of the
342 /// ScalarEvolution framework; all the other SCEV subclasses are
343 /// mostly just supporting infrastructure to allow SCEVAddRecExpr
344 /// expressions to be created and analyzed.
345 ///
346 /// All operands of an AddRec are required to be loop invariant.
347 ///
348 class SCEVAddRecExpr : public SCEVNAryExpr {
349   friend class ScalarEvolution;
350 
351   const Loop *L;
352 
353   SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N,
354                  const Loop *l)
355       : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
356 
357 public:
358   Type *getType() const { return getStart()->getType(); }
359   const SCEV *getStart() const { return Operands[0]; }
360   const Loop *getLoop() const { return L; }
361 
362   /// Constructs and returns the recurrence indicating how much this
363   /// expression steps by.  If this is a polynomial of degree N, it
364   /// returns a chrec of degree N-1.  We cannot determine whether
365   /// the step recurrence has self-wraparound.
366   const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
367     if (isAffine())
368       return getOperand(1);
369     return SE.getAddRecExpr(
370         SmallVector<const SCEV *, 3>(operands().drop_front()), getLoop(),
371         FlagAnyWrap);
372   }
373 
374   /// Return true if this represents an expression A + B*x where A
375   /// and B are loop invariant values.
376   bool isAffine() const {
377     // We know that the start value is invariant.  This expression is thus
378     // affine iff the step is also invariant.
379     return getNumOperands() == 2;
380   }
381 
382   /// Return true if this represents an expression A + B*x + C*x^2
383   /// where A, B and C are loop invariant values.  This corresponds
384   /// to an addrec of the form {L,+,M,+,N}
385   bool isQuadratic() const { return getNumOperands() == 3; }
386 
387   /// Set flags for a recurrence without clearing any previously set flags.
388   /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
389   /// to make it easier to propagate flags.
390   void setNoWrapFlags(NoWrapFlags Flags) {
391     if (Flags & (FlagNUW | FlagNSW))
392       Flags = ScalarEvolution::setFlags(Flags, FlagNW);
393     SubclassData |= Flags;
394   }
395 
396   /// Return the value of this chain of recurrences at the specified
397   /// iteration number.
398   const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
399 
400   /// Return the value of this chain of recurrences at the specified iteration
401   /// number. Takes an explicit list of operands to represent an AddRec.
402   static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands,
403                                          const SCEV *It, ScalarEvolution &SE);
404 
405   /// Return the number of iterations of this loop that produce
406   /// values in the specified constant range.  Another way of
407   /// looking at this is that it returns the first iteration number
408   /// where the value is not in the condition, thus computing the
409   /// exit count.  If the iteration count can't be computed, an
410   /// instance of SCEVCouldNotCompute is returned.
411   const SCEV *getNumIterationsInRange(const ConstantRange &Range,
412                                       ScalarEvolution &SE) const;
413 
414   /// Return an expression representing the value of this expression
415   /// one iteration of the loop ahead.
416   const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
417 
418   /// Methods for support type inquiry through isa, cast, and dyn_cast:
419   static bool classof(const SCEV *S) {
420     return S->getSCEVType() == scAddRecExpr;
421   }
422 };
423 
424 /// This node is the base class min/max selections.
425 class SCEVMinMaxExpr : public SCEVCommutativeExpr {
426   friend class ScalarEvolution;
427 
428   static bool isMinMaxType(enum SCEVTypes T) {
429     return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
430            T == scUMinExpr;
431   }
432 
433 protected:
434   /// Note: Constructing subclasses via this constructor is allowed
435   SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
436                  const SCEV *const *O, size_t N)
437       : SCEVCommutativeExpr(ID, T, O, N) {
438     assert(isMinMaxType(T));
439     // Min and max never overflow
440     setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
441   }
442 
443 public:
444   Type *getType() const { return getOperand(0)->getType(); }
445 
446   static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); }
447 
448   static enum SCEVTypes negate(enum SCEVTypes T) {
449     switch (T) {
450     case scSMaxExpr:
451       return scSMinExpr;
452     case scSMinExpr:
453       return scSMaxExpr;
454     case scUMaxExpr:
455       return scUMinExpr;
456     case scUMinExpr:
457       return scUMaxExpr;
458     default:
459       llvm_unreachable("Not a min or max SCEV type!");
460     }
461   }
462 };
463 
464 /// This class represents a signed maximum selection.
465 class SCEVSMaxExpr : public SCEVMinMaxExpr {
466   friend class ScalarEvolution;
467 
468   SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
469       : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
470 
471 public:
472   /// Methods for support type inquiry through isa, cast, and dyn_cast:
473   static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; }
474 };
475 
476 /// This class represents an unsigned maximum selection.
477 class SCEVUMaxExpr : public SCEVMinMaxExpr {
478   friend class ScalarEvolution;
479 
480   SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
481       : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
482 
483 public:
484   /// Methods for support type inquiry through isa, cast, and dyn_cast:
485   static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; }
486 };
487 
488 /// This class represents a signed minimum selection.
489 class SCEVSMinExpr : public SCEVMinMaxExpr {
490   friend class ScalarEvolution;
491 
492   SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
493       : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
494 
495 public:
496   /// Methods for support type inquiry through isa, cast, and dyn_cast:
497   static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; }
498 };
499 
500 /// This class represents an unsigned minimum selection.
501 class SCEVUMinExpr : public SCEVMinMaxExpr {
502   friend class ScalarEvolution;
503 
504   SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
505       : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
506 
507 public:
508   /// Methods for support type inquiry through isa, cast, and dyn_cast:
509   static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; }
510 };
511 
512 /// This node is the base class for sequential/in-order min/max selections.
513 /// Note that their fundamental difference from SCEVMinMaxExpr's is that they
514 /// are early-returning upon reaching saturation point.
515 /// I.e. given `0 umin_seq poison`, the result will be `0`,
516 /// while the result of `0 umin poison` is `poison`.
517 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
518   friend class ScalarEvolution;
519 
520   static bool isSequentialMinMaxType(enum SCEVTypes T) {
521     return T == scSequentialUMinExpr;
522   }
523 
524   /// Set flags for a non-recurrence without clearing previously set flags.
525   void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
526 
527 protected:
528   /// Note: Constructing subclasses via this constructor is allowed
529   SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
530                            const SCEV *const *O, size_t N)
531       : SCEVNAryExpr(ID, T, O, N) {
532     assert(isSequentialMinMaxType(T));
533     // Min and max never overflow
534     setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
535   }
536 
537 public:
538   Type *getType() const { return getOperand(0)->getType(); }
539 
540   static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) {
541     assert(isSequentialMinMaxType(Ty));
542     switch (Ty) {
543     case scSequentialUMinExpr:
544       return scUMinExpr;
545     default:
546       llvm_unreachable("Not a sequential min/max type.");
547     }
548   }
549 
550   SCEVTypes getEquivalentNonSequentialSCEVType() const {
551     return getEquivalentNonSequentialSCEVType(getSCEVType());
552   }
553 
554   static bool classof(const SCEV *S) {
555     return isSequentialMinMaxType(S->getSCEVType());
556   }
557 };
558 
559 /// This class represents a sequential/in-order unsigned minimum selection.
560 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr {
561   friend class ScalarEvolution;
562 
563   SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O,
564                          size_t N)
565       : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {}
566 
567 public:
568   /// Methods for support type inquiry through isa, cast, and dyn_cast:
569   static bool classof(const SCEV *S) {
570     return S->getSCEVType() == scSequentialUMinExpr;
571   }
572 };
573 
574 /// This means that we are dealing with an entirely unknown SCEV
575 /// value, and only represent it as its LLVM Value.  This is the
576 /// "bottom" value for the analysis.
577 class SCEVUnknown final : public SCEV, private CallbackVH {
578   friend class ScalarEvolution;
579 
580   /// The parent ScalarEvolution value. This is used to update the
581   /// parent's maps when the value associated with a SCEVUnknown is
582   /// deleted or RAUW'd.
583   ScalarEvolution *SE;
584 
585   /// The next pointer in the linked list of all SCEVUnknown
586   /// instances owned by a ScalarEvolution.
587   SCEVUnknown *Next;
588 
589   SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se,
590               SCEVUnknown *next)
591       : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
592 
593   // Implement CallbackVH.
594   void deleted() override;
595   void allUsesReplacedWith(Value *New) override;
596 
597 public:
598   Value *getValue() const { return getValPtr(); }
599 
600   Type *getType() const { return getValPtr()->getType(); }
601 
602   /// Methods for support type inquiry through isa, cast, and dyn_cast:
603   static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; }
604 };
605 
606 /// This class defines a simple visitor class that may be used for
607 /// various SCEV analysis purposes.
608 template <typename SC, typename RetVal = void> struct SCEVVisitor {
609   RetVal visit(const SCEV *S) {
610     switch (S->getSCEVType()) {
611     case scConstant:
612       return ((SC *)this)->visitConstant((const SCEVConstant *)S);
613     case scVScale:
614       return ((SC *)this)->visitVScale((const SCEVVScale *)S);
615     case scPtrToInt:
616       return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
617     case scTruncate:
618       return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S);
619     case scZeroExtend:
620       return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S);
621     case scSignExtend:
622       return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S);
623     case scAddExpr:
624       return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S);
625     case scMulExpr:
626       return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S);
627     case scUDivExpr:
628       return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S);
629     case scAddRecExpr:
630       return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S);
631     case scSMaxExpr:
632       return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S);
633     case scUMaxExpr:
634       return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S);
635     case scSMinExpr:
636       return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
637     case scUMinExpr:
638       return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
639     case scSequentialUMinExpr:
640       return ((SC *)this)
641           ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S);
642     case scUnknown:
643       return ((SC *)this)->visitUnknown((const SCEVUnknown *)S);
644     case scCouldNotCompute:
645       return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S);
646     }
647     llvm_unreachable("Unknown SCEV kind!");
648   }
649 
650   RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
651     llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
652   }
653 };
654 
655 /// Visit all nodes in the expression tree using worklist traversal.
656 ///
657 /// Visitor implements:
658 ///   // return true to follow this node.
659 ///   bool follow(const SCEV *S);
660 ///   // return true to terminate the search.
661 ///   bool isDone();
662 template <typename SV> class SCEVTraversal {
663   SV &Visitor;
664   SmallVector<const SCEV *, 8> Worklist;
665   SmallPtrSet<const SCEV *, 8> Visited;
666 
667   void push(const SCEV *S) {
668     if (Visited.insert(S).second && Visitor.follow(S))
669       Worklist.push_back(S);
670   }
671 
672 public:
673   SCEVTraversal(SV &V) : Visitor(V) {}
674 
675   void visitAll(const SCEV *Root) {
676     push(Root);
677     while (!Worklist.empty() && !Visitor.isDone()) {
678       const SCEV *S = Worklist.pop_back_val();
679 
680       switch (S->getSCEVType()) {
681       case scConstant:
682       case scVScale:
683       case scUnknown:
684         continue;
685       case scPtrToInt:
686       case scTruncate:
687       case scZeroExtend:
688       case scSignExtend:
689       case scAddExpr:
690       case scMulExpr:
691       case scUDivExpr:
692       case scSMaxExpr:
693       case scUMaxExpr:
694       case scSMinExpr:
695       case scUMinExpr:
696       case scSequentialUMinExpr:
697       case scAddRecExpr:
698         for (const auto *Op : S->operands()) {
699           push(Op);
700           if (Visitor.isDone())
701             break;
702         }
703         continue;
704       case scCouldNotCompute:
705         llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
706       }
707       llvm_unreachable("Unknown SCEV kind!");
708     }
709   }
710 };
711 
712 /// Use SCEVTraversal to visit all nodes in the given expression tree.
713 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) {
714   SCEVTraversal<SV> T(Visitor);
715   T.visitAll(Root);
716 }
717 
718 /// Return true if any node in \p Root satisfies the predicate \p Pred.
719 template <typename PredTy>
720 bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
721   struct FindClosure {
722     bool Found = false;
723     PredTy Pred;
724 
725     FindClosure(PredTy Pred) : Pred(Pred) {}
726 
727     bool follow(const SCEV *S) {
728       if (!Pred(S))
729         return true;
730 
731       Found = true;
732       return false;
733     }
734 
735     bool isDone() const { return Found; }
736   };
737 
738   FindClosure FC(Pred);
739   visitAll(Root, FC);
740   return FC.Found;
741 }
742 
743 /// This visitor recursively visits a SCEV expression and re-writes it.
744 /// The result from each visit is cached, so it will return the same
745 /// SCEV for the same input.
746 template <typename SC>
747 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
748 protected:
749   ScalarEvolution &SE;
750   // Memoize the result of each visit so that we only compute once for
751   // the same input SCEV. This is to avoid redundant computations when
752   // a SCEV is referenced by multiple SCEVs. Without memoization, this
753   // visit algorithm would have exponential time complexity in the worst
754   // case, causing the compiler to hang on certain tests.
755   SmallDenseMap<const SCEV *, const SCEV *> RewriteResults;
756 
757 public:
758   SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
759 
760   const SCEV *visit(const SCEV *S) {
761     auto It = RewriteResults.find(S);
762     if (It != RewriteResults.end())
763       return It->second;
764     auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S);
765     auto Result = RewriteResults.try_emplace(S, Visited);
766     assert(Result.second && "Should insert a new entry");
767     return Result.first->second;
768   }
769 
770   const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }
771 
772   const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; }
773 
774   const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
775     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
776     return Operand == Expr->getOperand()
777                ? Expr
778                : SE.getPtrToIntExpr(Operand, Expr->getType());
779   }
780 
781   const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
782     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
783     return Operand == Expr->getOperand()
784                ? Expr
785                : SE.getTruncateExpr(Operand, Expr->getType());
786   }
787 
788   const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
789     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
790     return Operand == Expr->getOperand()
791                ? Expr
792                : SE.getZeroExtendExpr(Operand, Expr->getType());
793   }
794 
795   const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
796     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
797     return Operand == Expr->getOperand()
798                ? Expr
799                : SE.getSignExtendExpr(Operand, Expr->getType());
800   }
801 
802   const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
803     SmallVector<const SCEV *, 2> Operands;
804     bool Changed = false;
805     for (const auto *Op : Expr->operands()) {
806       Operands.push_back(((SC *)this)->visit(Op));
807       Changed |= Op != Operands.back();
808     }
809     return !Changed ? Expr : SE.getAddExpr(Operands);
810   }
811 
812   const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
813     SmallVector<const SCEV *, 2> Operands;
814     bool Changed = false;
815     for (const auto *Op : Expr->operands()) {
816       Operands.push_back(((SC *)this)->visit(Op));
817       Changed |= Op != Operands.back();
818     }
819     return !Changed ? Expr : SE.getMulExpr(Operands);
820   }
821 
822   const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
823     auto *LHS = ((SC *)this)->visit(Expr->getLHS());
824     auto *RHS = ((SC *)this)->visit(Expr->getRHS());
825     bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
826     return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
827   }
828 
829   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
830     SmallVector<const SCEV *, 2> Operands;
831     bool Changed = false;
832     for (const auto *Op : Expr->operands()) {
833       Operands.push_back(((SC *)this)->visit(Op));
834       Changed |= Op != Operands.back();
835     }
836     return !Changed ? Expr
837                     : SE.getAddRecExpr(Operands, Expr->getLoop(),
838                                        Expr->getNoWrapFlags());
839   }
840 
841   const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
842     SmallVector<const SCEV *, 2> Operands;
843     bool Changed = false;
844     for (const auto *Op : Expr->operands()) {
845       Operands.push_back(((SC *)this)->visit(Op));
846       Changed |= Op != Operands.back();
847     }
848     return !Changed ? Expr : SE.getSMaxExpr(Operands);
849   }
850 
851   const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
852     SmallVector<const SCEV *, 2> Operands;
853     bool Changed = false;
854     for (const auto *Op : Expr->operands()) {
855       Operands.push_back(((SC *)this)->visit(Op));
856       Changed |= Op != Operands.back();
857     }
858     return !Changed ? Expr : SE.getUMaxExpr(Operands);
859   }
860 
861   const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
862     SmallVector<const SCEV *, 2> Operands;
863     bool Changed = false;
864     for (const auto *Op : Expr->operands()) {
865       Operands.push_back(((SC *)this)->visit(Op));
866       Changed |= Op != Operands.back();
867     }
868     return !Changed ? Expr : SE.getSMinExpr(Operands);
869   }
870 
871   const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
872     SmallVector<const SCEV *, 2> Operands;
873     bool Changed = false;
874     for (const auto *Op : Expr->operands()) {
875       Operands.push_back(((SC *)this)->visit(Op));
876       Changed |= Op != Operands.back();
877     }
878     return !Changed ? Expr : SE.getUMinExpr(Operands);
879   }
880 
881   const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
882     SmallVector<const SCEV *, 2> Operands;
883     bool Changed = false;
884     for (const auto *Op : Expr->operands()) {
885       Operands.push_back(((SC *)this)->visit(Op));
886       Changed |= Op != Operands.back();
887     }
888     return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true);
889   }
890 
891   const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
892 
893   const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
894     return Expr;
895   }
896 };
897 
898 using ValueToValueMap = DenseMap<const Value *, Value *>;
899 using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
900 
901 /// The SCEVParameterRewriter takes a scalar evolution expression and updates
902 /// the SCEVUnknown components following the Map (Value -> SCEV).
903 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
904 public:
905   static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
906                              ValueToSCEVMapTy &Map) {
907     SCEVParameterRewriter Rewriter(SE, Map);
908     return Rewriter.visit(Scev);
909   }
910 
911   SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
912       : SCEVRewriteVisitor(SE), Map(M) {}
913 
914   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
915     auto I = Map.find(Expr->getValue());
916     if (I == Map.end())
917       return Expr;
918     return I->second;
919   }
920 
921 private:
922   ValueToSCEVMapTy &Map;
923 };
924 
925 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
926 
927 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
928 /// the Map (Loop -> SCEV) to all AddRecExprs.
929 class SCEVLoopAddRecRewriter
930     : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
931 public:
932   SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
933       : SCEVRewriteVisitor(SE), Map(M) {}
934 
935   static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
936                              ScalarEvolution &SE) {
937     SCEVLoopAddRecRewriter Rewriter(SE, Map);
938     return Rewriter.visit(Scev);
939   }
940 
941   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
942     SmallVector<const SCEV *, 2> Operands;
943     for (const SCEV *Op : Expr->operands())
944       Operands.push_back(visit(Op));
945 
946     const Loop *L = Expr->getLoop();
947     if (0 == Map.count(L))
948       return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
949 
950     return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE);
951   }
952 
953 private:
954   LoopToScevMapT &Map;
955 };
956 
957 } // end namespace llvm
958 
959 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
960