1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the classes used to represent and build scalar expressions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
15 
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/iterator_range.h"
20 #include "llvm/Analysis/ScalarEvolution.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/ValueHandle.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/ErrorHandling.h"
25 #include <cassert>
26 #include <cstddef>
27 
28 namespace llvm {
29 
30 class APInt;
31 class Constant;
32 class ConstantInt;
33 class ConstantRange;
34 class Loop;
35 class Type;
36 class Value;
37 
/// Discriminator identifying the concrete kind of a SCEV node. It is the value
/// returned by SCEV::getSCEVType() and the value every classof() in this file
/// tests. Each enumerator is annotated with the class it identifies.
enum SCEVTypes : unsigned short {
  // These should be ordered in terms of increasing complexity to make the
  // folders simpler.
  // NOTE(review): the relative order appears to be relied upon by code outside
  // this header; do not reorder without checking ScalarEvolution's users.
  scConstant,           // SCEVConstant
  scTruncate,           // SCEVTruncateExpr
  scZeroExtend,         // SCEVZeroExtendExpr
  scSignExtend,         // SCEVSignExtendExpr
  scAddExpr,            // SCEVAddExpr
  scMulExpr,            // SCEVMulExpr
  scUDivExpr,           // SCEVUDivExpr
  scAddRecExpr,         // SCEVAddRecExpr
  scUMaxExpr,           // SCEVUMaxExpr
  scSMaxExpr,           // SCEVSMaxExpr
  scUMinExpr,           // SCEVUMinExpr
  scSMinExpr,           // SCEVSMinExpr
  scSequentialUMinExpr, // SCEVSequentialUMinExpr
  scPtrToInt,           // SCEVPtrToIntExpr
  scUnknown,            // SCEVUnknown
  scCouldNotCompute     // SCEVCouldNotCompute (not defined in this header)
};
58 
59 /// This class represents a constant integer value.
60 class SCEVConstant : public SCEV {
61   friend class ScalarEvolution;
62 
63   ConstantInt *V;
64 
65   SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v)
66       : SCEV(ID, scConstant, 1), V(v) {}
67 
68 public:
69   ConstantInt *getValue() const { return V; }
70   const APInt &getAPInt() const { return getValue()->getValue(); }
71 
72   Type *getType() const { return V->getType(); }
73 
74   /// Methods for support type inquiry through isa, cast, and dyn_cast:
75   static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; }
76 };
77 
78 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
79   APInt Size(16, 1);
80   for (const auto *Arg : Args)
81     Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
82   return (unsigned short)Size.getZExtValue();
83 }
84 
85 /// This is the base class for unary cast operator classes.
86 class SCEVCastExpr : public SCEV {
87 protected:
88   const SCEV *Op;
89   Type *Ty;
90 
91   SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op,
92                Type *ty);
93 
94 public:
95   const SCEV *getOperand() const { return Op; }
96   const SCEV *getOperand(unsigned i) const {
97     assert(i == 0 && "Operand index out of range!");
98     return Op;
99   }
100   ArrayRef<const SCEV *> operands() const { return Op; }
101   size_t getNumOperands() const { return 1; }
102   Type *getType() const { return Ty; }
103 
104   /// Methods for support type inquiry through isa, cast, and dyn_cast:
105   static bool classof(const SCEV *S) {
106     return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
107            S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend;
108   }
109 };
110 
111 /// This class represents a cast from a pointer to a pointer-sized integer
112 /// value.
113 class SCEVPtrToIntExpr : public SCEVCastExpr {
114   friend class ScalarEvolution;
115 
116   SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
117 
118 public:
119   /// Methods for support type inquiry through isa, cast, and dyn_cast:
120   static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; }
121 };
122 
123 /// This is the base class for unary integral cast operator classes.
124 class SCEVIntegralCastExpr : public SCEVCastExpr {
125 protected:
126   SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
127                        const SCEV *op, Type *ty);
128 
129 public:
130   /// Methods for support type inquiry through isa, cast, and dyn_cast:
131   static bool classof(const SCEV *S) {
132     return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend ||
133            S->getSCEVType() == scSignExtend;
134   }
135 };
136 
137 /// This class represents a truncation of an integer value to a
138 /// smaller integer value.
139 class SCEVTruncateExpr : public SCEVIntegralCastExpr {
140   friend class ScalarEvolution;
141 
142   SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
143 
144 public:
145   /// Methods for support type inquiry through isa, cast, and dyn_cast:
146   static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; }
147 };
148 
149 /// This class represents a zero extension of a small integer value
150 /// to a larger integer value.
151 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
152   friend class ScalarEvolution;
153 
154   SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
155 
156 public:
157   /// Methods for support type inquiry through isa, cast, and dyn_cast:
158   static bool classof(const SCEV *S) {
159     return S->getSCEVType() == scZeroExtend;
160   }
161 };
162 
163 /// This class represents a sign extension of a small integer value
164 /// to a larger integer value.
165 class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
166   friend class ScalarEvolution;
167 
168   SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
169 
170 public:
171   /// Methods for support type inquiry through isa, cast, and dyn_cast:
172   static bool classof(const SCEV *S) {
173     return S->getSCEVType() == scSignExtend;
174   }
175 };
176 
177 /// This node is a base class providing common functionality for
178 /// n'ary operators.
179 class SCEVNAryExpr : public SCEV {
180 protected:
181   // Since SCEVs are immutable, ScalarEvolution allocates operand
182   // arrays with its SCEVAllocator, so this class just needs a simple
183   // pointer rather than a more elaborate vector-like data structure.
184   // This also avoids the need for a non-trivial destructor.
185   const SCEV *const *Operands;
186   size_t NumOperands;
187 
188   SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
189                const SCEV *const *O, size_t N)
190       : SCEV(ID, T, computeExpressionSize(ArrayRef(O, N))), Operands(O),
191         NumOperands(N) {}
192 
193 public:
194   size_t getNumOperands() const { return NumOperands; }
195 
196   const SCEV *getOperand(unsigned i) const {
197     assert(i < NumOperands && "Operand index out of range!");
198     return Operands[i];
199   }
200 
201   ArrayRef<const SCEV *> operands() const {
202     return ArrayRef(Operands, NumOperands);
203   }
204 
205   NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
206     return (NoWrapFlags)(SubclassData & Mask);
207   }
208 
209   bool hasNoUnsignedWrap() const {
210     return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
211   }
212 
213   bool hasNoSignedWrap() const {
214     return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
215   }
216 
217   bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; }
218 
219   /// Methods for support type inquiry through isa, cast, and dyn_cast:
220   static bool classof(const SCEV *S) {
221     return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
222            S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
223            S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
224            S->getSCEVType() == scSequentialUMinExpr ||
225            S->getSCEVType() == scAddRecExpr;
226   }
227 };
228 
229 /// This node is the base class for n'ary commutative operators.
230 class SCEVCommutativeExpr : public SCEVNAryExpr {
231 protected:
232   SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
233                       const SCEV *const *O, size_t N)
234       : SCEVNAryExpr(ID, T, O, N) {}
235 
236 public:
237   /// Methods for support type inquiry through isa, cast, and dyn_cast:
238   static bool classof(const SCEV *S) {
239     return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
240            S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
241            S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
242   }
243 
244   /// Set flags for a non-recurrence without clearing previously set flags.
245   void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
246 };
247 
248 /// This node represents an addition of some number of SCEVs.
249 class SCEVAddExpr : public SCEVCommutativeExpr {
250   friend class ScalarEvolution;
251 
252   Type *Ty;
253 
254   SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
255       : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
256     auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
257       return Op->getType()->isPointerTy();
258     });
259     if (FirstPointerTypedOp != operands().end())
260       Ty = (*FirstPointerTypedOp)->getType();
261     else
262       Ty = getOperand(0)->getType();
263   }
264 
265 public:
266   Type *getType() const { return Ty; }
267 
268   /// Methods for support type inquiry through isa, cast, and dyn_cast:
269   static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; }
270 };
271 
272 /// This node represents multiplication of some number of SCEVs.
273 class SCEVMulExpr : public SCEVCommutativeExpr {
274   friend class ScalarEvolution;
275 
276   SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
277       : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
278 
279 public:
280   Type *getType() const { return getOperand(0)->getType(); }
281 
282   /// Methods for support type inquiry through isa, cast, and dyn_cast:
283   static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; }
284 };
285 
286 /// This class represents a binary unsigned division operation.
287 class SCEVUDivExpr : public SCEV {
288   friend class ScalarEvolution;
289 
290   std::array<const SCEV *, 2> Operands;
291 
292   SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
293       : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) {
294     Operands[0] = lhs;
295     Operands[1] = rhs;
296   }
297 
298 public:
299   const SCEV *getLHS() const { return Operands[0]; }
300   const SCEV *getRHS() const { return Operands[1]; }
301   size_t getNumOperands() const { return 2; }
302   const SCEV *getOperand(unsigned i) const {
303     assert((i == 0 || i == 1) && "Operand index out of range!");
304     return i == 0 ? getLHS() : getRHS();
305   }
306 
307   ArrayRef<const SCEV *> operands() const { return Operands; }
308 
309   Type *getType() const {
310     // In most cases the types of LHS and RHS will be the same, but in some
311     // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
312     // depend on the type for correctness, but handling types carefully can
313     // avoid extra casts in the SCEVExpander. The LHS is more likely to be
314     // a pointer type than the RHS, so use the RHS' type here.
315     return getRHS()->getType();
316   }
317 
318   /// Methods for support type inquiry through isa, cast, and dyn_cast:
319   static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; }
320 };
321 
322 /// This node represents a polynomial recurrence on the trip count
323 /// of the specified loop.  This is the primary focus of the
324 /// ScalarEvolution framework; all the other SCEV subclasses are
325 /// mostly just supporting infrastructure to allow SCEVAddRecExpr
326 /// expressions to be created and analyzed.
327 ///
328 /// All operands of an AddRec are required to be loop invariant.
329 ///
330 class SCEVAddRecExpr : public SCEVNAryExpr {
331   friend class ScalarEvolution;
332 
333   const Loop *L;
334 
335   SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N,
336                  const Loop *l)
337       : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
338 
339 public:
340   Type *getType() const { return getStart()->getType(); }
341   const SCEV *getStart() const { return Operands[0]; }
342   const Loop *getLoop() const { return L; }
343 
344   /// Constructs and returns the recurrence indicating how much this
345   /// expression steps by.  If this is a polynomial of degree N, it
346   /// returns a chrec of degree N-1.  We cannot determine whether
347   /// the step recurrence has self-wraparound.
348   const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
349     if (isAffine())
350       return getOperand(1);
351     return SE.getAddRecExpr(
352         SmallVector<const SCEV *, 3>(operands().drop_front()), getLoop(),
353         FlagAnyWrap);
354   }
355 
356   /// Return true if this represents an expression A + B*x where A
357   /// and B are loop invariant values.
358   bool isAffine() const {
359     // We know that the start value is invariant.  This expression is thus
360     // affine iff the step is also invariant.
361     return getNumOperands() == 2;
362   }
363 
364   /// Return true if this represents an expression A + B*x + C*x^2
365   /// where A, B and C are loop invariant values.  This corresponds
366   /// to an addrec of the form {L,+,M,+,N}
367   bool isQuadratic() const { return getNumOperands() == 3; }
368 
369   /// Set flags for a recurrence without clearing any previously set flags.
370   /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
371   /// to make it easier to propagate flags.
372   void setNoWrapFlags(NoWrapFlags Flags) {
373     if (Flags & (FlagNUW | FlagNSW))
374       Flags = ScalarEvolution::setFlags(Flags, FlagNW);
375     SubclassData |= Flags;
376   }
377 
378   /// Return the value of this chain of recurrences at the specified
379   /// iteration number.
380   const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
381 
382   /// Return the value of this chain of recurrences at the specified iteration
383   /// number. Takes an explicit list of operands to represent an AddRec.
384   static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands,
385                                          const SCEV *It, ScalarEvolution &SE);
386 
387   /// Return the number of iterations of this loop that produce
388   /// values in the specified constant range.  Another way of
389   /// looking at this is that it returns the first iteration number
390   /// where the value is not in the condition, thus computing the
391   /// exit count.  If the iteration count can't be computed, an
392   /// instance of SCEVCouldNotCompute is returned.
393   const SCEV *getNumIterationsInRange(const ConstantRange &Range,
394                                       ScalarEvolution &SE) const;
395 
396   /// Return an expression representing the value of this expression
397   /// one iteration of the loop ahead.
398   const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
399 
400   /// Methods for support type inquiry through isa, cast, and dyn_cast:
401   static bool classof(const SCEV *S) {
402     return S->getSCEVType() == scAddRecExpr;
403   }
404 };
405 
406 /// This node is the base class min/max selections.
407 class SCEVMinMaxExpr : public SCEVCommutativeExpr {
408   friend class ScalarEvolution;
409 
410   static bool isMinMaxType(enum SCEVTypes T) {
411     return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
412            T == scUMinExpr;
413   }
414 
415 protected:
416   /// Note: Constructing subclasses via this constructor is allowed
417   SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
418                  const SCEV *const *O, size_t N)
419       : SCEVCommutativeExpr(ID, T, O, N) {
420     assert(isMinMaxType(T));
421     // Min and max never overflow
422     setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
423   }
424 
425 public:
426   Type *getType() const { return getOperand(0)->getType(); }
427 
428   static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); }
429 
430   static enum SCEVTypes negate(enum SCEVTypes T) {
431     switch (T) {
432     case scSMaxExpr:
433       return scSMinExpr;
434     case scSMinExpr:
435       return scSMaxExpr;
436     case scUMaxExpr:
437       return scUMinExpr;
438     case scUMinExpr:
439       return scUMaxExpr;
440     default:
441       llvm_unreachable("Not a min or max SCEV type!");
442     }
443   }
444 };
445 
446 /// This class represents a signed maximum selection.
447 class SCEVSMaxExpr : public SCEVMinMaxExpr {
448   friend class ScalarEvolution;
449 
450   SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
451       : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
452 
453 public:
454   /// Methods for support type inquiry through isa, cast, and dyn_cast:
455   static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; }
456 };
457 
458 /// This class represents an unsigned maximum selection.
459 class SCEVUMaxExpr : public SCEVMinMaxExpr {
460   friend class ScalarEvolution;
461 
462   SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
463       : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
464 
465 public:
466   /// Methods for support type inquiry through isa, cast, and dyn_cast:
467   static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; }
468 };
469 
470 /// This class represents a signed minimum selection.
471 class SCEVSMinExpr : public SCEVMinMaxExpr {
472   friend class ScalarEvolution;
473 
474   SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
475       : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
476 
477 public:
478   /// Methods for support type inquiry through isa, cast, and dyn_cast:
479   static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; }
480 };
481 
482 /// This class represents an unsigned minimum selection.
483 class SCEVUMinExpr : public SCEVMinMaxExpr {
484   friend class ScalarEvolution;
485 
486   SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
487       : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
488 
489 public:
490   /// Methods for support type inquiry through isa, cast, and dyn_cast:
491   static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; }
492 };
493 
494 /// This node is the base class for sequential/in-order min/max selections.
495 /// Note that their fundamental difference from SCEVMinMaxExpr's is that they
496 /// are early-returning upon reaching saturation point.
497 /// I.e. given `0 umin_seq poison`, the result will be `0`,
498 /// while the result of `0 umin poison` is `poison`.
499 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
500   friend class ScalarEvolution;
501 
502   static bool isSequentialMinMaxType(enum SCEVTypes T) {
503     return T == scSequentialUMinExpr;
504   }
505 
506   /// Set flags for a non-recurrence without clearing previously set flags.
507   void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
508 
509 protected:
510   /// Note: Constructing subclasses via this constructor is allowed
511   SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
512                            const SCEV *const *O, size_t N)
513       : SCEVNAryExpr(ID, T, O, N) {
514     assert(isSequentialMinMaxType(T));
515     // Min and max never overflow
516     setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
517   }
518 
519 public:
520   Type *getType() const { return getOperand(0)->getType(); }
521 
522   static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) {
523     assert(isSequentialMinMaxType(Ty));
524     switch (Ty) {
525     case scSequentialUMinExpr:
526       return scUMinExpr;
527     default:
528       llvm_unreachable("Not a sequential min/max type.");
529     }
530   }
531 
532   SCEVTypes getEquivalentNonSequentialSCEVType() const {
533     return getEquivalentNonSequentialSCEVType(getSCEVType());
534   }
535 
536   static bool classof(const SCEV *S) {
537     return isSequentialMinMaxType(S->getSCEVType());
538   }
539 };
540 
541 /// This class represents a sequential/in-order unsigned minimum selection.
542 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr {
543   friend class ScalarEvolution;
544 
545   SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O,
546                          size_t N)
547       : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {}
548 
549 public:
550   /// Methods for support type inquiry through isa, cast, and dyn_cast:
551   static bool classof(const SCEV *S) {
552     return S->getSCEVType() == scSequentialUMinExpr;
553   }
554 };
555 
556 /// This means that we are dealing with an entirely unknown SCEV
557 /// value, and only represent it as its LLVM Value.  This is the
558 /// "bottom" value for the analysis.
559 class SCEVUnknown final : public SCEV, private CallbackVH {
560   friend class ScalarEvolution;
561 
562   /// The parent ScalarEvolution value. This is used to update the
563   /// parent's maps when the value associated with a SCEVUnknown is
564   /// deleted or RAUW'd.
565   ScalarEvolution *SE;
566 
567   /// The next pointer in the linked list of all SCEVUnknown
568   /// instances owned by a ScalarEvolution.
569   SCEVUnknown *Next;
570 
571   SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se,
572               SCEVUnknown *next)
573       : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
574 
575   // Implement CallbackVH.
576   void deleted() override;
577   void allUsesReplacedWith(Value *New) override;
578 
579 public:
580   Value *getValue() const { return getValPtr(); }
581 
582   /// @{
583   /// Test whether this is a special constant representing a type
584   /// size, alignment, or field offset in a target-independent
585   /// manner, and hasn't happened to have been folded with other
586   /// operations into something unrecognizable. This is mainly only
587   /// useful for pretty-printing and other situations where it isn't
588   /// absolutely required for these to succeed.
589   bool isSizeOf(Type *&AllocTy) const;
590   bool isAlignOf(Type *&AllocTy) const;
591   bool isOffsetOf(Type *&STy, Constant *&FieldNo) const;
592   /// @}
593 
594   Type *getType() const { return getValPtr()->getType(); }
595 
596   /// Methods for support type inquiry through isa, cast, and dyn_cast:
597   static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; }
598 };
599 
/// This class defines a simple visitor class that may be used for
/// various SCEV analysis purposes.
///
/// This is a CRTP base: \p SC is the derived visitor type, and visit() calls
/// SC's visitXxx method for the node's kind (visitConstant, visitAddExpr, ...,
/// visitUnknown, visitCouldNotCompute). SC may inherit these methods from an
/// intermediate base such as SCEVRewriteVisitor. \p RetVal is the type
/// returned by visit() and by every visitXxx method.
template <typename SC, typename RetVal = void> struct SCEVVisitor {
  /// Dispatch \p S to the derived visitor's method for its concrete kind.
  RetVal visit(const SCEV *S) {
    // Each cast below is unchecked; it is justified by the kind matched in
    // the corresponding case label.
    switch (S->getSCEVType()) {
    case scConstant:
      return ((SC *)this)->visitConstant((const SCEVConstant *)S);
    case scPtrToInt:
      return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
    case scTruncate:
      return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S);
    case scZeroExtend:
      return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S);
    case scSignExtend:
      return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S);
    case scAddExpr:
      return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S);
    case scMulExpr:
      return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S);
    case scUDivExpr:
      return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S);
    case scAddRecExpr:
      return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S);
    case scSMaxExpr:
      return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S);
    case scUMaxExpr:
      return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S);
    case scSMinExpr:
      return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
    case scUMinExpr:
      return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
    case scSequentialUMinExpr:
      return ((SC *)this)
          ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S);
    case scUnknown:
      return ((SC *)this)->visitUnknown((const SCEVUnknown *)S);
    case scCouldNotCompute:
      return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S);
    }
    // Only reachable if the kind is not one of the SCEVTypes enumerators.
    llvm_unreachable("Unknown SCEV kind!");
  }

  /// Default handler, used when SC does not define its own: reaching a
  /// SCEVCouldNotCompute is treated as a programming error.
  RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
    llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
  }
};
646 
647 /// Visit all nodes in the expression tree using worklist traversal.
648 ///
649 /// Visitor implements:
650 ///   // return true to follow this node.
651 ///   bool follow(const SCEV *S);
652 ///   // return true to terminate the search.
653 ///   bool isDone();
654 template <typename SV> class SCEVTraversal {
655   SV &Visitor;
656   SmallVector<const SCEV *, 8> Worklist;
657   SmallPtrSet<const SCEV *, 8> Visited;
658 
659   void push(const SCEV *S) {
660     if (Visited.insert(S).second && Visitor.follow(S))
661       Worklist.push_back(S);
662   }
663 
664 public:
665   SCEVTraversal(SV &V) : Visitor(V) {}
666 
667   void visitAll(const SCEV *Root) {
668     push(Root);
669     while (!Worklist.empty() && !Visitor.isDone()) {
670       const SCEV *S = Worklist.pop_back_val();
671 
672       switch (S->getSCEVType()) {
673       case scConstant:
674       case scUnknown:
675         continue;
676       case scPtrToInt:
677       case scTruncate:
678       case scZeroExtend:
679       case scSignExtend:
680       case scAddExpr:
681       case scMulExpr:
682       case scUDivExpr:
683       case scSMaxExpr:
684       case scUMaxExpr:
685       case scSMinExpr:
686       case scUMinExpr:
687       case scSequentialUMinExpr:
688       case scAddRecExpr:
689         for (const auto *Op : S->operands()) {
690           push(Op);
691           if (Visitor.isDone())
692             break;
693         }
694         continue;
695       case scCouldNotCompute:
696         llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
697       }
698       llvm_unreachable("Unknown SCEV kind!");
699     }
700   }
701 };
702 
703 /// Use SCEVTraversal to visit all nodes in the given expression tree.
704 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) {
705   SCEVTraversal<SV> T(Visitor);
706   T.visitAll(Root);
707 }
708 
709 /// Return true if any node in \p Root satisfies the predicate \p Pred.
710 template <typename PredTy>
711 bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
712   struct FindClosure {
713     bool Found = false;
714     PredTy Pred;
715 
716     FindClosure(PredTy Pred) : Pred(Pred) {}
717 
718     bool follow(const SCEV *S) {
719       if (!Pred(S))
720         return true;
721 
722       Found = true;
723       return false;
724     }
725 
726     bool isDone() const { return Found; }
727   };
728 
729   FindClosure FC(Pred);
730   visitAll(Root, FC);
731   return FC.Found;
732 }
733 
734 /// This visitor recursively visits a SCEV expression and re-writes it.
735 /// The result from each visit is cached, so it will return the same
736 /// SCEV for the same input.
737 template <typename SC>
738 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
739 protected:
740   ScalarEvolution &SE;
741   // Memoize the result of each visit so that we only compute once for
742   // the same input SCEV. This is to avoid redundant computations when
743   // a SCEV is referenced by multiple SCEVs. Without memoization, this
744   // visit algorithm would have exponential time complexity in the worst
745   // case, causing the compiler to hang on certain tests.
746   DenseMap<const SCEV *, const SCEV *> RewriteResults;
747 
748 public:
749   SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
750 
751   const SCEV *visit(const SCEV *S) {
752     auto It = RewriteResults.find(S);
753     if (It != RewriteResults.end())
754       return It->second;
755     auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S);
756     auto Result = RewriteResults.try_emplace(S, Visited);
757     assert(Result.second && "Should insert a new entry");
758     return Result.first->second;
759   }
760 
761   const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }
762 
763   const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
764     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
765     return Operand == Expr->getOperand()
766                ? Expr
767                : SE.getPtrToIntExpr(Operand, Expr->getType());
768   }
769 
770   const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
771     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
772     return Operand == Expr->getOperand()
773                ? Expr
774                : SE.getTruncateExpr(Operand, Expr->getType());
775   }
776 
777   const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
778     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
779     return Operand == Expr->getOperand()
780                ? Expr
781                : SE.getZeroExtendExpr(Operand, Expr->getType());
782   }
783 
784   const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
785     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
786     return Operand == Expr->getOperand()
787                ? Expr
788                : SE.getSignExtendExpr(Operand, Expr->getType());
789   }
790 
791   const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
792     SmallVector<const SCEV *, 2> Operands;
793     bool Changed = false;
794     for (const auto *Op : Expr->operands()) {
795       Operands.push_back(((SC *)this)->visit(Op));
796       Changed |= Op != Operands.back();
797     }
798     return !Changed ? Expr : SE.getAddExpr(Operands);
799   }
800 
801   const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
802     SmallVector<const SCEV *, 2> Operands;
803     bool Changed = false;
804     for (const auto *Op : Expr->operands()) {
805       Operands.push_back(((SC *)this)->visit(Op));
806       Changed |= Op != Operands.back();
807     }
808     return !Changed ? Expr : SE.getMulExpr(Operands);
809   }
810 
811   const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
812     auto *LHS = ((SC *)this)->visit(Expr->getLHS());
813     auto *RHS = ((SC *)this)->visit(Expr->getRHS());
814     bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
815     return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
816   }
817 
818   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
819     SmallVector<const SCEV *, 2> Operands;
820     bool Changed = false;
821     for (const auto *Op : Expr->operands()) {
822       Operands.push_back(((SC *)this)->visit(Op));
823       Changed |= Op != Operands.back();
824     }
825     return !Changed ? Expr
826                     : SE.getAddRecExpr(Operands, Expr->getLoop(),
827                                        Expr->getNoWrapFlags());
828   }
829 
830   const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
831     SmallVector<const SCEV *, 2> Operands;
832     bool Changed = false;
833     for (const auto *Op : Expr->operands()) {
834       Operands.push_back(((SC *)this)->visit(Op));
835       Changed |= Op != Operands.back();
836     }
837     return !Changed ? Expr : SE.getSMaxExpr(Operands);
838   }
839 
840   const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
841     SmallVector<const SCEV *, 2> Operands;
842     bool Changed = false;
843     for (const auto *Op : Expr->operands()) {
844       Operands.push_back(((SC *)this)->visit(Op));
845       Changed |= Op != Operands.back();
846     }
847     return !Changed ? Expr : SE.getUMaxExpr(Operands);
848   }
849 
850   const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
851     SmallVector<const SCEV *, 2> Operands;
852     bool Changed = false;
853     for (const auto *Op : Expr->operands()) {
854       Operands.push_back(((SC *)this)->visit(Op));
855       Changed |= Op != Operands.back();
856     }
857     return !Changed ? Expr : SE.getSMinExpr(Operands);
858   }
859 
860   const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
861     SmallVector<const SCEV *, 2> Operands;
862     bool Changed = false;
863     for (const auto *Op : Expr->operands()) {
864       Operands.push_back(((SC *)this)->visit(Op));
865       Changed |= Op != Operands.back();
866     }
867     return !Changed ? Expr : SE.getUMinExpr(Operands);
868   }
869 
870   const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
871     SmallVector<const SCEV *, 2> Operands;
872     bool Changed = false;
873     for (const auto *Op : Expr->operands()) {
874       Operands.push_back(((SC *)this)->visit(Op));
875       Changed |= Op != Operands.back();
876     }
877     return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true);
878   }
879 
880   const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
881 
882   const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
883     return Expr;
884   }
885 };
886 
/// Map from an IR value to another IR value.
using ValueToValueMap = DenseMap<const Value *, Value *>;
/// Map from an IR value to the SCEV that should stand in for it; this is the
/// map SCEVParameterRewriter consults when replacing SCEVUnknown leaves.
using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
889 
890 /// The SCEVParameterRewriter takes a scalar evolution expression and updates
891 /// the SCEVUnknown components following the Map (Value -> SCEV).
892 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
893 public:
894   static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
895                              ValueToSCEVMapTy &Map) {
896     SCEVParameterRewriter Rewriter(SE, Map);
897     return Rewriter.visit(Scev);
898   }
899 
900   SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
901       : SCEVRewriteVisitor(SE), Map(M) {}
902 
903   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
904     auto I = Map.find(Expr->getValue());
905     if (I == Map.end())
906       return Expr;
907     return I->second;
908   }
909 
910 private:
911   ValueToSCEVMapTy &Map;
912 };
913 
914 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
915 
916 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
917 /// the Map (Loop -> SCEV) to all AddRecExprs.
918 class SCEVLoopAddRecRewriter
919     : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
920 public:
921   SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
922       : SCEVRewriteVisitor(SE), Map(M) {}
923 
924   static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
925                              ScalarEvolution &SE) {
926     SCEVLoopAddRecRewriter Rewriter(SE, Map);
927     return Rewriter.visit(Scev);
928   }
929 
930   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
931     SmallVector<const SCEV *, 2> Operands;
932     for (const SCEV *Op : Expr->operands())
933       Operands.push_back(visit(Op));
934 
935     const Loop *L = Expr->getLoop();
936     if (0 == Map.count(L))
937       return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
938 
939     return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE);
940   }
941 
942 private:
943   LoopToScevMapT &Map;
944 };
945 
946 } // end namespace llvm
947 
948 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
949