1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the classes used to represent and build scalar expressions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 15 16 #include "llvm/ADT/DenseMap.h" 17 #include "llvm/ADT/SmallPtrSet.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include "llvm/ADT/iterator_range.h" 20 #include "llvm/Analysis/ScalarEvolution.h" 21 #include "llvm/IR/Constants.h" 22 #include "llvm/IR/ValueHandle.h" 23 #include "llvm/Support/Casting.h" 24 #include "llvm/Support/ErrorHandling.h" 25 #include <cassert> 26 #include <cstddef> 27 28 namespace llvm { 29 30 class APInt; 31 class Constant; 32 class ConstantInt; 33 class ConstantRange; 34 class Loop; 35 class Type; 36 class Value; 37 38 enum SCEVTypes : unsigned short { 39 // These should be ordered in terms of increasing complexity to make the 40 // folders simpler. 41 scConstant, 42 scTruncate, 43 scZeroExtend, 44 scSignExtend, 45 scAddExpr, 46 scMulExpr, 47 scUDivExpr, 48 scAddRecExpr, 49 scUMaxExpr, 50 scSMaxExpr, 51 scUMinExpr, 52 scSMinExpr, 53 scSequentialUMinExpr, 54 scPtrToInt, 55 scUnknown, 56 scCouldNotCompute 57 }; 58 59 /// This class represents a constant integer value. 60 class SCEVConstant : public SCEV { 61 friend class ScalarEvolution; 62 63 ConstantInt *V; 64 65 SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) 66 : SCEV(ID, scConstant, 1), V(v) {} 67 68 public: 69 ConstantInt *getValue() const { return V; } 70 const APInt &getAPInt() const { return getValue()->getValue(); } 71 72 Type *getType() const { return V->getType(); } 73 74 /// Methods for support type inquiry through isa, cast, and dyn_cast: 75 static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; } 76 }; 77 78 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) { 79 APInt Size(16, 1); 80 for (const auto *Arg : Args) 81 Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize())); 82 return (unsigned short)Size.getZExtValue(); 83 } 84 85 /// This is the base class for unary cast operator classes. 86 class SCEVCastExpr : public SCEV { 87 protected: 88 const SCEV *Op; 89 Type *Ty; 90 91 SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, 92 Type *ty); 93 94 public: 95 const SCEV *getOperand() const { return Op; } 96 const SCEV *getOperand(unsigned i) const { 97 assert(i == 0 && "Operand index out of range!"); 98 return Op; 99 } 100 ArrayRef<const SCEV *> operands() const { return Op; } 101 size_t getNumOperands() const { return 1; } 102 Type *getType() const { return Ty; } 103 104 /// Methods for support type inquiry through isa, cast, and dyn_cast: 105 static bool classof(const SCEV *S) { 106 return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate || 107 S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend; 108 } 109 }; 110 111 /// This class represents a cast from a pointer to a pointer-sized integer 112 /// value. 113 class SCEVPtrToIntExpr : public SCEVCastExpr { 114 friend class ScalarEvolution; 115 116 SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy); 117 118 public: 119 /// Methods for support type inquiry through isa, cast, and dyn_cast: 120 static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; } 121 }; 122 123 /// This is the base class for unary integral cast operator classes. 124 class SCEVIntegralCastExpr : public SCEVCastExpr { 125 protected: 126 SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, 127 const SCEV *op, Type *ty); 128 129 public: 130 /// Methods for support type inquiry through isa, cast, and dyn_cast: 131 static bool classof(const SCEV *S) { 132 return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend || 133 S->getSCEVType() == scSignExtend; 134 } 135 }; 136 137 /// This class represents a truncation of an integer value to a 138 /// smaller integer value. 139 class SCEVTruncateExpr : public SCEVIntegralCastExpr { 140 friend class ScalarEvolution; 141 142 SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 143 144 public: 145 /// Methods for support type inquiry through isa, cast, and dyn_cast: 146 static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; } 147 }; 148 149 /// This class represents a zero extension of a small integer value 150 /// to a larger integer value. 151 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr { 152 friend class ScalarEvolution; 153 154 SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 155 156 public: 157 /// Methods for support type inquiry through isa, cast, and dyn_cast: 158 static bool classof(const SCEV *S) { 159 return S->getSCEVType() == scZeroExtend; 160 } 161 }; 162 163 /// This class represents a sign extension of a small integer value 164 /// to a larger integer value. 165 class SCEVSignExtendExpr : public SCEVIntegralCastExpr { 166 friend class ScalarEvolution; 167 168 SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 169 170 public: 171 /// Methods for support type inquiry through isa, cast, and dyn_cast: 172 static bool classof(const SCEV *S) { 173 return S->getSCEVType() == scSignExtend; 174 } 175 }; 176 177 /// This node is a base class providing common functionality for 178 /// n'ary operators. 179 class SCEVNAryExpr : public SCEV { 180 protected: 181 // Since SCEVs are immutable, ScalarEvolution allocates operand 182 // arrays with its SCEVAllocator, so this class just needs a simple 183 // pointer rather than a more elaborate vector-like data structure. 184 // This also avoids the need for a non-trivial destructor. 185 const SCEV *const *Operands; 186 size_t NumOperands; 187 188 SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 189 const SCEV *const *O, size_t N) 190 : SCEV(ID, T, computeExpressionSize(ArrayRef(O, N))), Operands(O), 191 NumOperands(N) {} 192 193 public: 194 size_t getNumOperands() const { return NumOperands; } 195 196 const SCEV *getOperand(unsigned i) const { 197 assert(i < NumOperands && "Operand index out of range!"); 198 return Operands[i]; 199 } 200 201 ArrayRef<const SCEV *> operands() const { 202 return ArrayRef(Operands, NumOperands); 203 } 204 205 NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const { 206 return (NoWrapFlags)(SubclassData & Mask); 207 } 208 209 bool hasNoUnsignedWrap() const { 210 return getNoWrapFlags(FlagNUW) != FlagAnyWrap; 211 } 212 213 bool hasNoSignedWrap() const { 214 return getNoWrapFlags(FlagNSW) != FlagAnyWrap; 215 } 216 217 bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; } 218 219 /// Methods for support type inquiry through isa, cast, and dyn_cast: 220 static bool classof(const SCEV *S) { 221 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || 222 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || 223 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr || 224 S->getSCEVType() == scSequentialUMinExpr || 225 S->getSCEVType() == scAddRecExpr; 226 } 227 }; 228 229 /// This node is the base class for n'ary commutative operators. 230 class SCEVCommutativeExpr : public SCEVNAryExpr { 231 protected: 232 SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 233 const SCEV *const *O, size_t N) 234 : SCEVNAryExpr(ID, T, O, N) {} 235 236 public: 237 /// Methods for support type inquiry through isa, cast, and dyn_cast: 238 static bool classof(const SCEV *S) { 239 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || 240 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || 241 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr; 242 } 243 244 /// Set flags for a non-recurrence without clearing previously set flags. 245 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; } 246 }; 247 248 /// This node represents an addition of some number of SCEVs. 249 class SCEVAddExpr : public SCEVCommutativeExpr { 250 friend class ScalarEvolution; 251 252 Type *Ty; 253 254 SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 255 : SCEVCommutativeExpr(ID, scAddExpr, O, N) { 256 auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) { 257 return Op->getType()->isPointerTy(); 258 }); 259 if (FirstPointerTypedOp != operands().end()) 260 Ty = (*FirstPointerTypedOp)->getType(); 261 else 262 Ty = getOperand(0)->getType(); 263 } 264 265 public: 266 Type *getType() const { return Ty; } 267 268 /// Methods for support type inquiry through isa, cast, and dyn_cast: 269 static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; } 270 }; 271 272 /// This node represents multiplication of some number of SCEVs. 273 class SCEVMulExpr : public SCEVCommutativeExpr { 274 friend class ScalarEvolution; 275 276 SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 277 : SCEVCommutativeExpr(ID, scMulExpr, O, N) {} 278 279 public: 280 Type *getType() const { return getOperand(0)->getType(); } 281 282 /// Methods for support type inquiry through isa, cast, and dyn_cast: 283 static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; } 284 }; 285 286 /// This class represents a binary unsigned division operation. 287 class SCEVUDivExpr : public SCEV { 288 friend class ScalarEvolution; 289 290 std::array<const SCEV *, 2> Operands; 291 292 SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs) 293 : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) { 294 Operands[0] = lhs; 295 Operands[1] = rhs; 296 } 297 298 public: 299 const SCEV *getLHS() const { return Operands[0]; } 300 const SCEV *getRHS() const { return Operands[1]; } 301 size_t getNumOperands() const { return 2; } 302 const SCEV *getOperand(unsigned i) const { 303 assert((i == 0 || i == 1) && "Operand index out of range!"); 304 return i == 0 ? getLHS() : getRHS(); 305 } 306 307 ArrayRef<const SCEV *> operands() const { return Operands; } 308 309 Type *getType() const { 310 // In most cases the types of LHS and RHS will be the same, but in some 311 // crazy cases one or the other may be a pointer. ScalarEvolution doesn't 312 // depend on the type for correctness, but handling types carefully can 313 // avoid extra casts in the SCEVExpander. The LHS is more likely to be 314 // a pointer type than the RHS, so use the RHS' type here. 315 return getRHS()->getType(); 316 } 317 318 /// Methods for support type inquiry through isa, cast, and dyn_cast: 319 static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; } 320 }; 321 322 /// This node represents a polynomial recurrence on the trip count 323 /// of the specified loop. This is the primary focus of the 324 /// ScalarEvolution framework; all the other SCEV subclasses are 325 /// mostly just supporting infrastructure to allow SCEVAddRecExpr 326 /// expressions to be created and analyzed. 327 /// 328 /// All operands of an AddRec are required to be loop invariant. 329 /// 330 class SCEVAddRecExpr : public SCEVNAryExpr { 331 friend class ScalarEvolution; 332 333 const Loop *L; 334 335 SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N, 336 const Loop *l) 337 : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {} 338 339 public: 340 Type *getType() const { return getStart()->getType(); } 341 const SCEV *getStart() const { return Operands[0]; } 342 const Loop *getLoop() const { return L; } 343 344 /// Constructs and returns the recurrence indicating how much this 345 /// expression steps by. If this is a polynomial of degree N, it 346 /// returns a chrec of degree N-1. We cannot determine whether 347 /// the step recurrence has self-wraparound. 348 const SCEV *getStepRecurrence(ScalarEvolution &SE) const { 349 if (isAffine()) 350 return getOperand(1); 351 return SE.getAddRecExpr( 352 SmallVector<const SCEV *, 3>(operands().drop_front()), getLoop(), 353 FlagAnyWrap); 354 } 355 356 /// Return true if this represents an expression A + B*x where A 357 /// and B are loop invariant values. 358 bool isAffine() const { 359 // We know that the start value is invariant. This expression is thus 360 // affine iff the step is also invariant. 361 return getNumOperands() == 2; 362 } 363 364 /// Return true if this represents an expression A + B*x + C*x^2 365 /// where A, B and C are loop invariant values. This corresponds 366 /// to an addrec of the form {L,+,M,+,N} 367 bool isQuadratic() const { return getNumOperands() == 3; } 368 369 /// Set flags for a recurrence without clearing any previously set flags. 370 /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here 371 /// to make it easier to propagate flags. 372 void setNoWrapFlags(NoWrapFlags Flags) { 373 if (Flags & (FlagNUW | FlagNSW)) 374 Flags = ScalarEvolution::setFlags(Flags, FlagNW); 375 SubclassData |= Flags; 376 } 377 378 /// Return the value of this chain of recurrences at the specified 379 /// iteration number. 380 const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const; 381 382 /// Return the value of this chain of recurrences at the specified iteration 383 /// number. Takes an explicit list of operands to represent an AddRec. 384 static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands, 385 const SCEV *It, ScalarEvolution &SE); 386 387 /// Return the number of iterations of this loop that produce 388 /// values in the specified constant range. Another way of 389 /// looking at this is that it returns the first iteration number 390 /// where the value is not in the condition, thus computing the 391 /// exit count. If the iteration count can't be computed, an 392 /// instance of SCEVCouldNotCompute is returned. 393 const SCEV *getNumIterationsInRange(const ConstantRange &Range, 394 ScalarEvolution &SE) const; 395 396 /// Return an expression representing the value of this expression 397 /// one iteration of the loop ahead. 398 const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const; 399 400 /// Methods for support type inquiry through isa, cast, and dyn_cast: 401 static bool classof(const SCEV *S) { 402 return S->getSCEVType() == scAddRecExpr; 403 } 404 }; 405 406 /// This node is the base class min/max selections. 407 class SCEVMinMaxExpr : public SCEVCommutativeExpr { 408 friend class ScalarEvolution; 409 410 static bool isMinMaxType(enum SCEVTypes T) { 411 return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr || 412 T == scUMinExpr; 413 } 414 415 protected: 416 /// Note: Constructing subclasses via this constructor is allowed 417 SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 418 const SCEV *const *O, size_t N) 419 : SCEVCommutativeExpr(ID, T, O, N) { 420 assert(isMinMaxType(T)); 421 // Min and max never overflow 422 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 423 } 424 425 public: 426 Type *getType() const { return getOperand(0)->getType(); } 427 428 static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); } 429 430 static enum SCEVTypes negate(enum SCEVTypes T) { 431 switch (T) { 432 case scSMaxExpr: 433 return scSMinExpr; 434 case scSMinExpr: 435 return scSMaxExpr; 436 case scUMaxExpr: 437 return scUMinExpr; 438 case scUMinExpr: 439 return scUMaxExpr; 440 default: 441 llvm_unreachable("Not a min or max SCEV type!"); 442 } 443 } 444 }; 445 446 /// This class represents a signed maximum selection. 447 class SCEVSMaxExpr : public SCEVMinMaxExpr { 448 friend class ScalarEvolution; 449 450 SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 451 : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {} 452 453 public: 454 /// Methods for support type inquiry through isa, cast, and dyn_cast: 455 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; } 456 }; 457 458 /// This class represents an unsigned maximum selection. 459 class SCEVUMaxExpr : public SCEVMinMaxExpr { 460 friend class ScalarEvolution; 461 462 SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 463 : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {} 464 465 public: 466 /// Methods for support type inquiry through isa, cast, and dyn_cast: 467 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; } 468 }; 469 470 /// This class represents a signed minimum selection. 471 class SCEVSMinExpr : public SCEVMinMaxExpr { 472 friend class ScalarEvolution; 473 474 SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 475 : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {} 476 477 public: 478 /// Methods for support type inquiry through isa, cast, and dyn_cast: 479 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; } 480 }; 481 482 /// This class represents an unsigned minimum selection. 483 class SCEVUMinExpr : public SCEVMinMaxExpr { 484 friend class ScalarEvolution; 485 486 SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 487 : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {} 488 489 public: 490 /// Methods for support type inquiry through isa, cast, and dyn_cast: 491 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; } 492 }; 493 494 /// This node is the base class for sequential/in-order min/max selections. 495 /// Note that their fundamental difference from SCEVMinMaxExpr's is that they 496 /// are early-returning upon reaching saturation point. 497 /// I.e. given `0 umin_seq poison`, the result will be `0`, 498 /// while the result of `0 umin poison` is `poison`. 499 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr { 500 friend class ScalarEvolution; 501 502 static bool isSequentialMinMaxType(enum SCEVTypes T) { 503 return T == scSequentialUMinExpr; 504 } 505 506 /// Set flags for a non-recurrence without clearing previously set flags. 507 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; } 508 509 protected: 510 /// Note: Constructing subclasses via this constructor is allowed 511 SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 512 const SCEV *const *O, size_t N) 513 : SCEVNAryExpr(ID, T, O, N) { 514 assert(isSequentialMinMaxType(T)); 515 // Min and max never overflow 516 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 517 } 518 519 public: 520 Type *getType() const { return getOperand(0)->getType(); } 521 522 static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) { 523 assert(isSequentialMinMaxType(Ty)); 524 switch (Ty) { 525 case scSequentialUMinExpr: 526 return scUMinExpr; 527 default: 528 llvm_unreachable("Not a sequential min/max type."); 529 } 530 } 531 532 SCEVTypes getEquivalentNonSequentialSCEVType() const { 533 return getEquivalentNonSequentialSCEVType(getSCEVType()); 534 } 535 536 static bool classof(const SCEV *S) { 537 return isSequentialMinMaxType(S->getSCEVType()); 538 } 539 }; 540 541 /// This class represents a sequential/in-order unsigned minimum selection. 542 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr { 543 friend class ScalarEvolution; 544 545 SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, 546 size_t N) 547 : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {} 548 549 public: 550 /// Methods for support type inquiry through isa, cast, and dyn_cast: 551 static bool classof(const SCEV *S) { 552 return S->getSCEVType() == scSequentialUMinExpr; 553 } 554 }; 555 556 /// This means that we are dealing with an entirely unknown SCEV 557 /// value, and only represent it as its LLVM Value. This is the 558 /// "bottom" value for the analysis. 559 class SCEVUnknown final : public SCEV, private CallbackVH { 560 friend class ScalarEvolution; 561 562 /// The parent ScalarEvolution value. This is used to update the 563 /// parent's maps when the value associated with a SCEVUnknown is 564 /// deleted or RAUW'd. 565 ScalarEvolution *SE; 566 567 /// The next pointer in the linked list of all SCEVUnknown 568 /// instances owned by a ScalarEvolution. 569 SCEVUnknown *Next; 570 571 SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se, 572 SCEVUnknown *next) 573 : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {} 574 575 // Implement CallbackVH. 576 void deleted() override; 577 void allUsesReplacedWith(Value *New) override; 578 579 public: 580 Value *getValue() const { return getValPtr(); } 581 582 /// @{ 583 /// Test whether this is a special constant representing a type 584 /// size, alignment, or field offset in a target-independent 585 /// manner, and hasn't happened to have been folded with other 586 /// operations into something unrecognizable. This is mainly only 587 /// useful for pretty-printing and other situations where it isn't 588 /// absolutely required for these to succeed. 589 bool isSizeOf(Type *&AllocTy) const; 590 bool isAlignOf(Type *&AllocTy) const; 591 bool isOffsetOf(Type *&STy, Constant *&FieldNo) const; 592 /// @} 593 594 Type *getType() const { return getValPtr()->getType(); } 595 596 /// Methods for support type inquiry through isa, cast, and dyn_cast: 597 static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; } 598 }; 599 600 /// This class defines a simple visitor class that may be used for 601 /// various SCEV analysis purposes. 602 template <typename SC, typename RetVal = void> struct SCEVVisitor { 603 RetVal visit(const SCEV *S) { 604 switch (S->getSCEVType()) { 605 case scConstant: 606 return ((SC *)this)->visitConstant((const SCEVConstant *)S); 607 case scPtrToInt: 608 return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S); 609 case scTruncate: 610 return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S); 611 case scZeroExtend: 612 return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S); 613 case scSignExtend: 614 return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S); 615 case scAddExpr: 616 return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S); 617 case scMulExpr: 618 return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S); 619 case scUDivExpr: 620 return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S); 621 case scAddRecExpr: 622 return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S); 623 case scSMaxExpr: 624 return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S); 625 case scUMaxExpr: 626 return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S); 627 case scSMinExpr: 628 return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S); 629 case scUMinExpr: 630 return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S); 631 case scSequentialUMinExpr: 632 return ((SC *)this) 633 ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S); 634 case scUnknown: 635 return ((SC *)this)->visitUnknown((const SCEVUnknown *)S); 636 case scCouldNotCompute: 637 return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S); 638 } 639 llvm_unreachable("Unknown SCEV kind!"); 640 } 641 642 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) { 643 llvm_unreachable("Invalid use of SCEVCouldNotCompute!"); 644 } 645 }; 646 647 /// Visit all nodes in the expression tree using worklist traversal. 648 /// 649 /// Visitor implements: 650 /// // return true to follow this node. 651 /// bool follow(const SCEV *S); 652 /// // return true to terminate the search. 653 /// bool isDone(); 654 template <typename SV> class SCEVTraversal { 655 SV &Visitor; 656 SmallVector<const SCEV *, 8> Worklist; 657 SmallPtrSet<const SCEV *, 8> Visited; 658 659 void push(const SCEV *S) { 660 if (Visited.insert(S).second && Visitor.follow(S)) 661 Worklist.push_back(S); 662 } 663 664 public: 665 SCEVTraversal(SV &V) : Visitor(V) {} 666 667 void visitAll(const SCEV *Root) { 668 push(Root); 669 while (!Worklist.empty() && !Visitor.isDone()) { 670 const SCEV *S = Worklist.pop_back_val(); 671 672 switch (S->getSCEVType()) { 673 case scConstant: 674 case scUnknown: 675 continue; 676 case scPtrToInt: 677 case scTruncate: 678 case scZeroExtend: 679 case scSignExtend: 680 case scAddExpr: 681 case scMulExpr: 682 case scUDivExpr: 683 case scSMaxExpr: 684 case scUMaxExpr: 685 case scSMinExpr: 686 case scUMinExpr: 687 case scSequentialUMinExpr: 688 case scAddRecExpr: 689 for (const auto *Op : S->operands()) { 690 push(Op); 691 if (Visitor.isDone()) 692 break; 693 } 694 continue; 695 case scCouldNotCompute: 696 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); 697 } 698 llvm_unreachable("Unknown SCEV kind!"); 699 } 700 } 701 }; 702 703 /// Use SCEVTraversal to visit all nodes in the given expression tree. 704 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) { 705 SCEVTraversal<SV> T(Visitor); 706 T.visitAll(Root); 707 } 708 709 /// Return true if any node in \p Root satisfies the predicate \p Pred. 710 template <typename PredTy> 711 bool SCEVExprContains(const SCEV *Root, PredTy Pred) { 712 struct FindClosure { 713 bool Found = false; 714 PredTy Pred; 715 716 FindClosure(PredTy Pred) : Pred(Pred) {} 717 718 bool follow(const SCEV *S) { 719 if (!Pred(S)) 720 return true; 721 722 Found = true; 723 return false; 724 } 725 726 bool isDone() const { return Found; } 727 }; 728 729 FindClosure FC(Pred); 730 visitAll(Root, FC); 731 return FC.Found; 732 } 733 734 /// This visitor recursively visits a SCEV expression and re-writes it. 735 /// The result from each visit is cached, so it will return the same 736 /// SCEV for the same input. 737 template <typename SC> 738 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> { 739 protected: 740 ScalarEvolution &SE; 741 // Memoize the result of each visit so that we only compute once for 742 // the same input SCEV. This is to avoid redundant computations when 743 // a SCEV is referenced by multiple SCEVs. Without memoization, this 744 // visit algorithm would have exponential time complexity in the worst 745 // case, causing the compiler to hang on certain tests. 746 DenseMap<const SCEV *, const SCEV *> RewriteResults; 747 748 public: 749 SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {} 750 751 const SCEV *visit(const SCEV *S) { 752 auto It = RewriteResults.find(S); 753 if (It != RewriteResults.end()) 754 return It->second; 755 auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S); 756 auto Result = RewriteResults.try_emplace(S, Visited); 757 assert(Result.second && "Should insert a new entry"); 758 return Result.first->second; 759 } 760 761 const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; } 762 763 const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { 764 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 765 return Operand == Expr->getOperand() 766 ? Expr 767 : SE.getPtrToIntExpr(Operand, Expr->getType()); 768 } 769 770 const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { 771 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 772 return Operand == Expr->getOperand() 773 ? Expr 774 : SE.getTruncateExpr(Operand, Expr->getType()); 775 } 776 777 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 778 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 779 return Operand == Expr->getOperand() 780 ? Expr 781 : SE.getZeroExtendExpr(Operand, Expr->getType()); 782 } 783 784 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 785 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 786 return Operand == Expr->getOperand() 787 ? Expr 788 : SE.getSignExtendExpr(Operand, Expr->getType()); 789 } 790 791 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { 792 SmallVector<const SCEV *, 2> Operands; 793 bool Changed = false; 794 for (const auto *Op : Expr->operands()) { 795 Operands.push_back(((SC *)this)->visit(Op)); 796 Changed |= Op != Operands.back(); 797 } 798 return !Changed ? Expr : SE.getAddExpr(Operands); 799 } 800 801 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { 802 SmallVector<const SCEV *, 2> Operands; 803 bool Changed = false; 804 for (const auto *Op : Expr->operands()) { 805 Operands.push_back(((SC *)this)->visit(Op)); 806 Changed |= Op != Operands.back(); 807 } 808 return !Changed ? Expr : SE.getMulExpr(Operands); 809 } 810 811 const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { 812 auto *LHS = ((SC *)this)->visit(Expr->getLHS()); 813 auto *RHS = ((SC *)this)->visit(Expr->getRHS()); 814 bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS(); 815 return !Changed ? Expr : SE.getUDivExpr(LHS, RHS); 816 } 817 818 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 819 SmallVector<const SCEV *, 2> Operands; 820 bool Changed = false; 821 for (const auto *Op : Expr->operands()) { 822 Operands.push_back(((SC *)this)->visit(Op)); 823 Changed |= Op != Operands.back(); 824 } 825 return !Changed ? Expr 826 : SE.getAddRecExpr(Operands, Expr->getLoop(), 827 Expr->getNoWrapFlags()); 828 } 829 830 const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { 831 SmallVector<const SCEV *, 2> Operands; 832 bool Changed = false; 833 for (const auto *Op : Expr->operands()) { 834 Operands.push_back(((SC *)this)->visit(Op)); 835 Changed |= Op != Operands.back(); 836 } 837 return !Changed ? Expr : SE.getSMaxExpr(Operands); 838 } 839 840 const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { 841 SmallVector<const SCEV *, 2> Operands; 842 bool Changed = false; 843 for (const auto *Op : Expr->operands()) { 844 Operands.push_back(((SC *)this)->visit(Op)); 845 Changed |= Op != Operands.back(); 846 } 847 return !Changed ? Expr : SE.getUMaxExpr(Operands); 848 } 849 850 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) { 851 SmallVector<const SCEV *, 2> Operands; 852 bool Changed = false; 853 for (const auto *Op : Expr->operands()) { 854 Operands.push_back(((SC *)this)->visit(Op)); 855 Changed |= Op != Operands.back(); 856 } 857 return !Changed ? Expr : SE.getSMinExpr(Operands); 858 } 859 860 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) { 861 SmallVector<const SCEV *, 2> Operands; 862 bool Changed = false; 863 for (const auto *Op : Expr->operands()) { 864 Operands.push_back(((SC *)this)->visit(Op)); 865 Changed |= Op != Operands.back(); 866 } 867 return !Changed ? Expr : SE.getUMinExpr(Operands); 868 } 869 870 const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) { 871 SmallVector<const SCEV *, 2> Operands; 872 bool Changed = false; 873 for (const auto *Op : Expr->operands()) { 874 Operands.push_back(((SC *)this)->visit(Op)); 875 Changed |= Op != Operands.back(); 876 } 877 return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true); 878 } 879 880 const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; } 881 882 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { 883 return Expr; 884 } 885 }; 886 887 using ValueToValueMap = DenseMap<const Value *, Value *>; 888 using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>; 889 890 /// The SCEVParameterRewriter takes a scalar evolution expression and updates 891 /// the SCEVUnknown components following the Map (Value -> SCEV). 892 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> { 893 public: 894 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, 895 ValueToSCEVMapTy &Map) { 896 SCEVParameterRewriter Rewriter(SE, Map); 897 return Rewriter.visit(Scev); 898 } 899 900 SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M) 901 : SCEVRewriteVisitor(SE), Map(M) {} 902 903 const SCEV *visitUnknown(const SCEVUnknown *Expr) { 904 auto I = Map.find(Expr->getValue()); 905 if (I == Map.end()) 906 return Expr; 907 return I->second; 908 } 909 910 private: 911 ValueToSCEVMapTy ⤅ 912 }; 913 914 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>; 915 916 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies 917 /// the Map (Loop -> SCEV) to all AddRecExprs. 918 class SCEVLoopAddRecRewriter 919 : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> { 920 public: 921 SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M) 922 : SCEVRewriteVisitor(SE), Map(M) {} 923 924 static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map, 925 ScalarEvolution &SE) { 926 SCEVLoopAddRecRewriter Rewriter(SE, Map); 927 return Rewriter.visit(Scev); 928 } 929 930 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 931 SmallVector<const SCEV *, 2> Operands; 932 for (const SCEV *Op : Expr->operands()) 933 Operands.push_back(visit(Op)); 934 935 const Loop *L = Expr->getLoop(); 936 if (0 == Map.count(L)) 937 return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags()); 938 939 return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE); 940 } 941 942 private: 943 LoopToScevMapT ⤅ 944 }; 945 946 } // end namespace llvm 947 948 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 949