1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the classes used to represent and build scalar expressions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 15 16 #include "llvm/ADT/DenseMap.h" 17 #include "llvm/ADT/SmallPtrSet.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include "llvm/ADT/iterator_range.h" 20 #include "llvm/Analysis/ScalarEvolution.h" 21 #include "llvm/IR/Constants.h" 22 #include "llvm/IR/ValueHandle.h" 23 #include "llvm/Support/Casting.h" 24 #include "llvm/Support/ErrorHandling.h" 25 #include <cassert> 26 #include <cstddef> 27 28 namespace llvm { 29 30 class APInt; 31 class Constant; 32 class ConstantInt; 33 class ConstantRange; 34 class Loop; 35 class Type; 36 class Value; 37 38 enum SCEVTypes : unsigned short { 39 // These should be ordered in terms of increasing complexity to make the 40 // folders simpler. 41 scConstant, 42 scTruncate, 43 scZeroExtend, 44 scSignExtend, 45 scAddExpr, 46 scMulExpr, 47 scUDivExpr, 48 scAddRecExpr, 49 scUMaxExpr, 50 scSMaxExpr, 51 scUMinExpr, 52 scSMinExpr, 53 scSequentialUMinExpr, 54 scPtrToInt, 55 scUnknown, 56 scCouldNotCompute 57 }; 58 59 /// This class represents a constant integer value. 60 class SCEVConstant : public SCEV { 61 friend class ScalarEvolution; 62 63 ConstantInt *V; 64 65 SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) 66 : SCEV(ID, scConstant, 1), V(v) {} 67 68 public: 69 ConstantInt *getValue() const { return V; } 70 const APInt &getAPInt() const { return getValue()->getValue(); } 71 72 Type *getType() const { return V->getType(); } 73 74 /// Methods for support type inquiry through isa, cast, and dyn_cast: 75 static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; } 76 }; 77 78 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) { 79 APInt Size(16, 1); 80 for (auto *Arg : Args) 81 Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize())); 82 return (unsigned short)Size.getZExtValue(); 83 } 84 85 /// This is the base class for unary cast operator classes. 86 class SCEVCastExpr : public SCEV { 87 protected: 88 std::array<const SCEV *, 1> Operands; 89 Type *Ty; 90 91 SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, 92 Type *ty); 93 94 public: 95 const SCEV *getOperand() const { return Operands[0]; } 96 const SCEV *getOperand(unsigned i) const { 97 assert(i == 0 && "Operand index out of range!"); 98 return Operands[0]; 99 } 100 using op_iterator = std::array<const SCEV *, 1>::const_iterator; 101 using op_range = iterator_range<op_iterator>; 102 103 op_range operands() const { 104 return make_range(Operands.begin(), Operands.end()); 105 } 106 size_t getNumOperands() const { return 1; } 107 Type *getType() const { return Ty; } 108 109 /// Methods for support type inquiry through isa, cast, and dyn_cast: 110 static bool classof(const SCEV *S) { 111 return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate || 112 S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend; 113 } 114 }; 115 116 /// This class represents a cast from a pointer to a pointer-sized integer 117 /// value. 118 class SCEVPtrToIntExpr : public SCEVCastExpr { 119 friend class ScalarEvolution; 120 121 SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy); 122 123 public: 124 /// Methods for support type inquiry through isa, cast, and dyn_cast: 125 static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; } 126 }; 127 128 /// This is the base class for unary integral cast operator classes. 129 class SCEVIntegralCastExpr : public SCEVCastExpr { 130 protected: 131 SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, 132 const SCEV *op, Type *ty); 133 134 public: 135 /// Methods for support type inquiry through isa, cast, and dyn_cast: 136 static bool classof(const SCEV *S) { 137 return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend || 138 S->getSCEVType() == scSignExtend; 139 } 140 }; 141 142 /// This class represents a truncation of an integer value to a 143 /// smaller integer value. 144 class SCEVTruncateExpr : public SCEVIntegralCastExpr { 145 friend class ScalarEvolution; 146 147 SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 148 149 public: 150 /// Methods for support type inquiry through isa, cast, and dyn_cast: 151 static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; } 152 }; 153 154 /// This class represents a zero extension of a small integer value 155 /// to a larger integer value. 156 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr { 157 friend class ScalarEvolution; 158 159 SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 160 161 public: 162 /// Methods for support type inquiry through isa, cast, and dyn_cast: 163 static bool classof(const SCEV *S) { 164 return S->getSCEVType() == scZeroExtend; 165 } 166 }; 167 168 /// This class represents a sign extension of a small integer value 169 /// to a larger integer value. 170 class SCEVSignExtendExpr : public SCEVIntegralCastExpr { 171 friend class ScalarEvolution; 172 173 SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 174 175 public: 176 /// Methods for support type inquiry through isa, cast, and dyn_cast: 177 static bool classof(const SCEV *S) { 178 return S->getSCEVType() == scSignExtend; 179 } 180 }; 181 182 /// This node is a base class providing common functionality for 183 /// n'ary operators. 184 class SCEVNAryExpr : public SCEV { 185 protected: 186 // Since SCEVs are immutable, ScalarEvolution allocates operand 187 // arrays with its SCEVAllocator, so this class just needs a simple 188 // pointer rather than a more elaborate vector-like data structure. 189 // This also avoids the need for a non-trivial destructor. 190 const SCEV *const *Operands; 191 size_t NumOperands; 192 193 SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 194 const SCEV *const *O, size_t N) 195 : SCEV(ID, T, computeExpressionSize(makeArrayRef(O, N))), Operands(O), 196 NumOperands(N) {} 197 198 public: 199 size_t getNumOperands() const { return NumOperands; } 200 201 const SCEV *getOperand(unsigned i) const { 202 assert(i < NumOperands && "Operand index out of range!"); 203 return Operands[i]; 204 } 205 206 using op_iterator = const SCEV *const *; 207 using op_range = iterator_range<op_iterator>; 208 209 op_iterator op_begin() const { return Operands; } 210 op_iterator op_end() const { return Operands + NumOperands; } 211 op_range operands() const { return make_range(op_begin(), op_end()); } 212 213 NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const { 214 return (NoWrapFlags)(SubclassData & Mask); 215 } 216 217 bool hasNoUnsignedWrap() const { 218 return getNoWrapFlags(FlagNUW) != FlagAnyWrap; 219 } 220 221 bool hasNoSignedWrap() const { 222 return getNoWrapFlags(FlagNSW) != FlagAnyWrap; 223 } 224 225 bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; } 226 227 /// Methods for support type inquiry through isa, cast, and dyn_cast: 228 static bool classof(const SCEV *S) { 229 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || 230 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || 231 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr || 232 S->getSCEVType() == scSequentialUMinExpr || 233 S->getSCEVType() == scAddRecExpr; 234 } 235 }; 236 237 /// This node is the base class for n'ary commutative operators. 238 class SCEVCommutativeExpr : public SCEVNAryExpr { 239 protected: 240 SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 241 const SCEV *const *O, size_t N) 242 : SCEVNAryExpr(ID, T, O, N) {} 243 244 public: 245 /// Methods for support type inquiry through isa, cast, and dyn_cast: 246 static bool classof(const SCEV *S) { 247 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || 248 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || 249 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr; 250 } 251 252 /// Set flags for a non-recurrence without clearing previously set flags. 253 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; } 254 }; 255 256 /// This node represents an addition of some number of SCEVs. 257 class SCEVAddExpr : public SCEVCommutativeExpr { 258 friend class ScalarEvolution; 259 260 Type *Ty; 261 262 SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 263 : SCEVCommutativeExpr(ID, scAddExpr, O, N) { 264 auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) { 265 return Op->getType()->isPointerTy(); 266 }); 267 if (FirstPointerTypedOp != operands().end()) 268 Ty = (*FirstPointerTypedOp)->getType(); 269 else 270 Ty = getOperand(0)->getType(); 271 } 272 273 public: 274 Type *getType() const { return Ty; } 275 276 /// Methods for support type inquiry through isa, cast, and dyn_cast: 277 static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; } 278 }; 279 280 /// This node represents multiplication of some number of SCEVs. 281 class SCEVMulExpr : public SCEVCommutativeExpr { 282 friend class ScalarEvolution; 283 284 SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 285 : SCEVCommutativeExpr(ID, scMulExpr, O, N) {} 286 287 public: 288 Type *getType() const { return getOperand(0)->getType(); } 289 290 /// Methods for support type inquiry through isa, cast, and dyn_cast: 291 static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; } 292 }; 293 294 /// This class represents a binary unsigned division operation. 295 class SCEVUDivExpr : public SCEV { 296 friend class ScalarEvolution; 297 298 std::array<const SCEV *, 2> Operands; 299 300 SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs) 301 : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) { 302 Operands[0] = lhs; 303 Operands[1] = rhs; 304 } 305 306 public: 307 const SCEV *getLHS() const { return Operands[0]; } 308 const SCEV *getRHS() const { return Operands[1]; } 309 size_t getNumOperands() const { return 2; } 310 const SCEV *getOperand(unsigned i) const { 311 assert((i == 0 || i == 1) && "Operand index out of range!"); 312 return i == 0 ? getLHS() : getRHS(); 313 } 314 315 using op_iterator = std::array<const SCEV *, 2>::const_iterator; 316 using op_range = iterator_range<op_iterator>; 317 op_range operands() const { 318 return make_range(Operands.begin(), Operands.end()); 319 } 320 321 Type *getType() const { 322 // In most cases the types of LHS and RHS will be the same, but in some 323 // crazy cases one or the other may be a pointer. ScalarEvolution doesn't 324 // depend on the type for correctness, but handling types carefully can 325 // avoid extra casts in the SCEVExpander. The LHS is more likely to be 326 // a pointer type than the RHS, so use the RHS' type here. 327 return getRHS()->getType(); 328 } 329 330 /// Methods for support type inquiry through isa, cast, and dyn_cast: 331 static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; } 332 }; 333 334 /// This node represents a polynomial recurrence on the trip count 335 /// of the specified loop. This is the primary focus of the 336 /// ScalarEvolution framework; all the other SCEV subclasses are 337 /// mostly just supporting infrastructure to allow SCEVAddRecExpr 338 /// expressions to be created and analyzed. 339 /// 340 /// All operands of an AddRec are required to be loop invariant. 341 /// 342 class SCEVAddRecExpr : public SCEVNAryExpr { 343 friend class ScalarEvolution; 344 345 const Loop *L; 346 347 SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N, 348 const Loop *l) 349 : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {} 350 351 public: 352 Type *getType() const { return getStart()->getType(); } 353 const SCEV *getStart() const { return Operands[0]; } 354 const Loop *getLoop() const { return L; } 355 356 /// Constructs and returns the recurrence indicating how much this 357 /// expression steps by. If this is a polynomial of degree N, it 358 /// returns a chrec of degree N-1. We cannot determine whether 359 /// the step recurrence has self-wraparound. 360 const SCEV *getStepRecurrence(ScalarEvolution &SE) const { 361 if (isAffine()) 362 return getOperand(1); 363 return SE.getAddRecExpr( 364 SmallVector<const SCEV *, 3>(op_begin() + 1, op_end()), getLoop(), 365 FlagAnyWrap); 366 } 367 368 /// Return true if this represents an expression A + B*x where A 369 /// and B are loop invariant values. 370 bool isAffine() const { 371 // We know that the start value is invariant. This expression is thus 372 // affine iff the step is also invariant. 373 return getNumOperands() == 2; 374 } 375 376 /// Return true if this represents an expression A + B*x + C*x^2 377 /// where A, B and C are loop invariant values. This corresponds 378 /// to an addrec of the form {L,+,M,+,N} 379 bool isQuadratic() const { return getNumOperands() == 3; } 380 381 /// Set flags for a recurrence without clearing any previously set flags. 382 /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here 383 /// to make it easier to propagate flags. 384 void setNoWrapFlags(NoWrapFlags Flags) { 385 if (Flags & (FlagNUW | FlagNSW)) 386 Flags = ScalarEvolution::setFlags(Flags, FlagNW); 387 SubclassData |= Flags; 388 } 389 390 /// Return the value of this chain of recurrences at the specified 391 /// iteration number. 392 const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const; 393 394 /// Return the value of this chain of recurrences at the specified iteration 395 /// number. Takes an explicit list of operands to represent an AddRec. 396 static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands, 397 const SCEV *It, ScalarEvolution &SE); 398 399 /// Return the number of iterations of this loop that produce 400 /// values in the specified constant range. Another way of 401 /// looking at this is that it returns the first iteration number 402 /// where the value is not in the condition, thus computing the 403 /// exit count. If the iteration count can't be computed, an 404 /// instance of SCEVCouldNotCompute is returned. 405 const SCEV *getNumIterationsInRange(const ConstantRange &Range, 406 ScalarEvolution &SE) const; 407 408 /// Return an expression representing the value of this expression 409 /// one iteration of the loop ahead. 410 const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const; 411 412 /// Methods for support type inquiry through isa, cast, and dyn_cast: 413 static bool classof(const SCEV *S) { 414 return S->getSCEVType() == scAddRecExpr; 415 } 416 }; 417 418 /// This node is the base class min/max selections. 419 class SCEVMinMaxExpr : public SCEVCommutativeExpr { 420 friend class ScalarEvolution; 421 422 static bool isMinMaxType(enum SCEVTypes T) { 423 return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr || 424 T == scUMinExpr; 425 } 426 427 protected: 428 /// Note: Constructing subclasses via this constructor is allowed 429 SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 430 const SCEV *const *O, size_t N) 431 : SCEVCommutativeExpr(ID, T, O, N) { 432 assert(isMinMaxType(T)); 433 // Min and max never overflow 434 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 435 } 436 437 public: 438 Type *getType() const { return getOperand(0)->getType(); } 439 440 static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); } 441 442 static enum SCEVTypes negate(enum SCEVTypes T) { 443 switch (T) { 444 case scSMaxExpr: 445 return scSMinExpr; 446 case scSMinExpr: 447 return scSMaxExpr; 448 case scUMaxExpr: 449 return scUMinExpr; 450 case scUMinExpr: 451 return scUMaxExpr; 452 default: 453 llvm_unreachable("Not a min or max SCEV type!"); 454 } 455 } 456 }; 457 458 /// This class represents a signed maximum selection. 459 class SCEVSMaxExpr : public SCEVMinMaxExpr { 460 friend class ScalarEvolution; 461 462 SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 463 : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {} 464 465 public: 466 /// Methods for support type inquiry through isa, cast, and dyn_cast: 467 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; } 468 }; 469 470 /// This class represents an unsigned maximum selection. 471 class SCEVUMaxExpr : public SCEVMinMaxExpr { 472 friend class ScalarEvolution; 473 474 SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 475 : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {} 476 477 public: 478 /// Methods for support type inquiry through isa, cast, and dyn_cast: 479 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; } 480 }; 481 482 /// This class represents a signed minimum selection. 483 class SCEVSMinExpr : public SCEVMinMaxExpr { 484 friend class ScalarEvolution; 485 486 SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 487 : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {} 488 489 public: 490 /// Methods for support type inquiry through isa, cast, and dyn_cast: 491 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; } 492 }; 493 494 /// This class represents an unsigned minimum selection. 495 class SCEVUMinExpr : public SCEVMinMaxExpr { 496 friend class ScalarEvolution; 497 498 SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 499 : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {} 500 501 public: 502 /// Methods for support type inquiry through isa, cast, and dyn_cast: 503 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; } 504 }; 505 506 /// This node is the base class for sequential/in-order min/max selections. 507 /// Note that their fundamental difference from SCEVMinMaxExpr's is that they 508 /// are early-returning upon reaching saturation point. 509 /// I.e. given `0 umin_seq poison`, the result will be `0`, 510 /// while the result of `0 umin poison` is `poison`. 511 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr { 512 friend class ScalarEvolution; 513 514 static bool isSequentialMinMaxType(enum SCEVTypes T) { 515 return T == scSequentialUMinExpr; 516 } 517 518 /// Set flags for a non-recurrence without clearing previously set flags. 519 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; } 520 521 protected: 522 /// Note: Constructing subclasses via this constructor is allowed 523 SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 524 const SCEV *const *O, size_t N) 525 : SCEVNAryExpr(ID, T, O, N) { 526 assert(isSequentialMinMaxType(T)); 527 // Min and max never overflow 528 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 529 } 530 531 public: 532 Type *getType() const { return getOperand(0)->getType(); } 533 534 static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) { 535 assert(isSequentialMinMaxType(Ty)); 536 switch (Ty) { 537 case scSequentialUMinExpr: 538 return scUMinExpr; 539 default: 540 llvm_unreachable("Not a sequential min/max type."); 541 } 542 } 543 544 SCEVTypes getEquivalentNonSequentialSCEVType() const { 545 return getEquivalentNonSequentialSCEVType(getSCEVType()); 546 } 547 548 static bool classof(const SCEV *S) { 549 return isSequentialMinMaxType(S->getSCEVType()); 550 } 551 }; 552 553 /// This class represents a sequential/in-order unsigned minimum selection. 554 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr { 555 friend class ScalarEvolution; 556 557 SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, 558 size_t N) 559 : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {} 560 561 public: 562 /// Methods for support type inquiry through isa, cast, and dyn_cast: 563 static bool classof(const SCEV *S) { 564 return S->getSCEVType() == scSequentialUMinExpr; 565 } 566 }; 567 568 /// This means that we are dealing with an entirely unknown SCEV 569 /// value, and only represent it as its LLVM Value. This is the 570 /// "bottom" value for the analysis. 571 class SCEVUnknown final : public SCEV, private CallbackVH { 572 friend class ScalarEvolution; 573 574 /// The parent ScalarEvolution value. This is used to update the 575 /// parent's maps when the value associated with a SCEVUnknown is 576 /// deleted or RAUW'd. 577 ScalarEvolution *SE; 578 579 /// The next pointer in the linked list of all SCEVUnknown 580 /// instances owned by a ScalarEvolution. 581 SCEVUnknown *Next; 582 583 SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se, 584 SCEVUnknown *next) 585 : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {} 586 587 // Implement CallbackVH. 588 void deleted() override; 589 void allUsesReplacedWith(Value *New) override; 590 591 public: 592 Value *getValue() const { return getValPtr(); } 593 594 /// @{ 595 /// Test whether this is a special constant representing a type 596 /// size, alignment, or field offset in a target-independent 597 /// manner, and hasn't happened to have been folded with other 598 /// operations into something unrecognizable. This is mainly only 599 /// useful for pretty-printing and other situations where it isn't 600 /// absolutely required for these to succeed. 601 bool isSizeOf(Type *&AllocTy) const; 602 bool isAlignOf(Type *&AllocTy) const; 603 bool isOffsetOf(Type *&STy, Constant *&FieldNo) const; 604 /// @} 605 606 Type *getType() const { return getValPtr()->getType(); } 607 608 /// Methods for support type inquiry through isa, cast, and dyn_cast: 609 static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; } 610 }; 611 612 /// This class defines a simple visitor class that may be used for 613 /// various SCEV analysis purposes. 614 template <typename SC, typename RetVal = void> struct SCEVVisitor { 615 RetVal visit(const SCEV *S) { 616 switch (S->getSCEVType()) { 617 case scConstant: 618 return ((SC *)this)->visitConstant((const SCEVConstant *)S); 619 case scPtrToInt: 620 return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S); 621 case scTruncate: 622 return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S); 623 case scZeroExtend: 624 return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S); 625 case scSignExtend: 626 return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S); 627 case scAddExpr: 628 return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S); 629 case scMulExpr: 630 return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S); 631 case scUDivExpr: 632 return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S); 633 case scAddRecExpr: 634 return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S); 635 case scSMaxExpr: 636 return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S); 637 case scUMaxExpr: 638 return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S); 639 case scSMinExpr: 640 return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S); 641 case scUMinExpr: 642 return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S); 643 case scSequentialUMinExpr: 644 return ((SC *)this) 645 ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S); 646 case scUnknown: 647 return ((SC *)this)->visitUnknown((const SCEVUnknown *)S); 648 case scCouldNotCompute: 649 return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S); 650 } 651 llvm_unreachable("Unknown SCEV kind!"); 652 } 653 654 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) { 655 llvm_unreachable("Invalid use of SCEVCouldNotCompute!"); 656 } 657 }; 658 659 /// Visit all nodes in the expression tree using worklist traversal. 660 /// 661 /// Visitor implements: 662 /// // return true to follow this node. 663 /// bool follow(const SCEV *S); 664 /// // return true to terminate the search. 665 /// bool isDone(); 666 template <typename SV> class SCEVTraversal { 667 SV &Visitor; 668 SmallVector<const SCEV *, 8> Worklist; 669 SmallPtrSet<const SCEV *, 8> Visited; 670 671 void push(const SCEV *S) { 672 if (Visited.insert(S).second && Visitor.follow(S)) 673 Worklist.push_back(S); 674 } 675 676 public: 677 SCEVTraversal(SV &V) : Visitor(V) {} 678 679 void visitAll(const SCEV *Root) { 680 push(Root); 681 while (!Worklist.empty() && !Visitor.isDone()) { 682 const SCEV *S = Worklist.pop_back_val(); 683 684 switch (S->getSCEVType()) { 685 case scConstant: 686 case scUnknown: 687 continue; 688 case scPtrToInt: 689 case scTruncate: 690 case scZeroExtend: 691 case scSignExtend: 692 push(cast<SCEVCastExpr>(S)->getOperand()); 693 continue; 694 case scAddExpr: 695 case scMulExpr: 696 case scSMaxExpr: 697 case scUMaxExpr: 698 case scSMinExpr: 699 case scUMinExpr: 700 case scSequentialUMinExpr: 701 case scAddRecExpr: 702 for (const auto *Op : cast<SCEVNAryExpr>(S)->operands()) { 703 push(Op); 704 if (Visitor.isDone()) 705 break; 706 } 707 continue; 708 case scUDivExpr: { 709 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S); 710 push(UDiv->getLHS()); 711 push(UDiv->getRHS()); 712 continue; 713 } 714 case scCouldNotCompute: 715 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); 716 } 717 llvm_unreachable("Unknown SCEV kind!"); 718 } 719 } 720 }; 721 722 /// Use SCEVTraversal to visit all nodes in the given expression tree. 723 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) { 724 SCEVTraversal<SV> T(Visitor); 725 T.visitAll(Root); 726 } 727 728 /// Return true if any node in \p Root satisfies the predicate \p Pred. 729 template <typename PredTy> 730 bool SCEVExprContains(const SCEV *Root, PredTy Pred) { 731 struct FindClosure { 732 bool Found = false; 733 PredTy Pred; 734 735 FindClosure(PredTy Pred) : Pred(Pred) {} 736 737 bool follow(const SCEV *S) { 738 if (!Pred(S)) 739 return true; 740 741 Found = true; 742 return false; 743 } 744 745 bool isDone() const { return Found; } 746 }; 747 748 FindClosure FC(Pred); 749 visitAll(Root, FC); 750 return FC.Found; 751 } 752 753 /// This visitor recursively visits a SCEV expression and re-writes it. 754 /// The result from each visit is cached, so it will return the same 755 /// SCEV for the same input. 756 template <typename SC> 757 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> { 758 protected: 759 ScalarEvolution &SE; 760 // Memoize the result of each visit so that we only compute once for 761 // the same input SCEV. This is to avoid redundant computations when 762 // a SCEV is referenced by multiple SCEVs. Without memoization, this 763 // visit algorithm would have exponential time complexity in the worst 764 // case, causing the compiler to hang on certain tests. 765 DenseMap<const SCEV *, const SCEV *> RewriteResults; 766 767 public: 768 SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {} 769 770 const SCEV *visit(const SCEV *S) { 771 auto It = RewriteResults.find(S); 772 if (It != RewriteResults.end()) 773 return It->second; 774 auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S); 775 auto Result = RewriteResults.try_emplace(S, Visited); 776 assert(Result.second && "Should insert a new entry"); 777 return Result.first->second; 778 } 779 780 const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; } 781 782 const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { 783 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 784 return Operand == Expr->getOperand() 785 ? Expr 786 : SE.getPtrToIntExpr(Operand, Expr->getType()); 787 } 788 789 const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { 790 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 791 return Operand == Expr->getOperand() 792 ? Expr 793 : SE.getTruncateExpr(Operand, Expr->getType()); 794 } 795 796 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 797 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 798 return Operand == Expr->getOperand() 799 ? Expr 800 : SE.getZeroExtendExpr(Operand, Expr->getType()); 801 } 802 803 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 804 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 805 return Operand == Expr->getOperand() 806 ? Expr 807 : SE.getSignExtendExpr(Operand, Expr->getType()); 808 } 809 810 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { 811 SmallVector<const SCEV *, 2> Operands; 812 bool Changed = false; 813 for (auto *Op : Expr->operands()) { 814 Operands.push_back(((SC *)this)->visit(Op)); 815 Changed |= Op != Operands.back(); 816 } 817 return !Changed ? Expr : SE.getAddExpr(Operands); 818 } 819 820 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { 821 SmallVector<const SCEV *, 2> Operands; 822 bool Changed = false; 823 for (auto *Op : Expr->operands()) { 824 Operands.push_back(((SC *)this)->visit(Op)); 825 Changed |= Op != Operands.back(); 826 } 827 return !Changed ? Expr : SE.getMulExpr(Operands); 828 } 829 830 const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { 831 auto *LHS = ((SC *)this)->visit(Expr->getLHS()); 832 auto *RHS = ((SC *)this)->visit(Expr->getRHS()); 833 bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS(); 834 return !Changed ? Expr : SE.getUDivExpr(LHS, RHS); 835 } 836 837 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 838 SmallVector<const SCEV *, 2> Operands; 839 bool Changed = false; 840 for (auto *Op : Expr->operands()) { 841 Operands.push_back(((SC *)this)->visit(Op)); 842 Changed |= Op != Operands.back(); 843 } 844 return !Changed ? Expr 845 : SE.getAddRecExpr(Operands, Expr->getLoop(), 846 Expr->getNoWrapFlags()); 847 } 848 849 const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { 850 SmallVector<const SCEV *, 2> Operands; 851 bool Changed = false; 852 for (auto *Op : Expr->operands()) { 853 Operands.push_back(((SC *)this)->visit(Op)); 854 Changed |= Op != Operands.back(); 855 } 856 return !Changed ? Expr : SE.getSMaxExpr(Operands); 857 } 858 859 const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { 860 SmallVector<const SCEV *, 2> Operands; 861 bool Changed = false; 862 for (auto *Op : Expr->operands()) { 863 Operands.push_back(((SC *)this)->visit(Op)); 864 Changed |= Op != Operands.back(); 865 } 866 return !Changed ? Expr : SE.getUMaxExpr(Operands); 867 } 868 869 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) { 870 SmallVector<const SCEV *, 2> Operands; 871 bool Changed = false; 872 for (auto *Op : Expr->operands()) { 873 Operands.push_back(((SC *)this)->visit(Op)); 874 Changed |= Op != Operands.back(); 875 } 876 return !Changed ? Expr : SE.getSMinExpr(Operands); 877 } 878 879 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) { 880 SmallVector<const SCEV *, 2> Operands; 881 bool Changed = false; 882 for (auto *Op : Expr->operands()) { 883 Operands.push_back(((SC *)this)->visit(Op)); 884 Changed |= Op != Operands.back(); 885 } 886 return !Changed ? Expr : SE.getUMinExpr(Operands); 887 } 888 889 const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) { 890 SmallVector<const SCEV *, 2> Operands; 891 bool Changed = false; 892 for (auto *Op : Expr->operands()) { 893 Operands.push_back(((SC *)this)->visit(Op)); 894 Changed |= Op != Operands.back(); 895 } 896 return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true); 897 } 898 899 const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; } 900 901 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { 902 return Expr; 903 } 904 }; 905 906 using ValueToValueMap = DenseMap<const Value *, Value *>; 907 using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>; 908 909 /// The SCEVParameterRewriter takes a scalar evolution expression and updates 910 /// the SCEVUnknown components following the Map (Value -> SCEV). 911 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> { 912 public: 913 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, 914 ValueToSCEVMapTy &Map) { 915 SCEVParameterRewriter Rewriter(SE, Map); 916 return Rewriter.visit(Scev); 917 } 918 919 SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M) 920 : SCEVRewriteVisitor(SE), Map(M) {} 921 922 const SCEV *visitUnknown(const SCEVUnknown *Expr) { 923 auto I = Map.find(Expr->getValue()); 924 if (I == Map.end()) 925 return Expr; 926 return I->second; 927 } 928 929 private: 930 ValueToSCEVMapTy ⤅ 931 }; 932 933 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>; 934 935 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies 936 /// the Map (Loop -> SCEV) to all AddRecExprs. 937 class SCEVLoopAddRecRewriter 938 : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> { 939 public: 940 SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M) 941 : SCEVRewriteVisitor(SE), Map(M) {} 942 943 static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map, 944 ScalarEvolution &SE) { 945 SCEVLoopAddRecRewriter Rewriter(SE, Map); 946 return Rewriter.visit(Scev); 947 } 948 949 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 950 SmallVector<const SCEV *, 2> Operands; 951 for (const SCEV *Op : Expr->operands()) 952 Operands.push_back(visit(Op)); 953 954 const Loop *L = Expr->getLoop(); 955 if (0 == Map.count(L)) 956 return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags()); 957 958 return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE); 959 } 960 961 private: 962 LoopToScevMapT ⤅ 963 }; 964 965 } // end namespace llvm 966 967 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 968