1 
2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopDetection.h"
4 #include "llvm/Analysis/RegionInfo.h"
5 #include "llvm/Analysis/ScalarEvolution.h"
6 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
7 #include "llvm/Support/Debug.h"
8 
9 using namespace llvm;
10 using namespace polly;
11 
12 #define DEBUG_TYPE "polly-scev-validator"
13 
namespace SCEVType {
/// Classification the validator assigns to every SCEV expression.
///
/// The numeric order of the enumerators is significant: it runs from the most
/// restrictive kind (INT) to the least restrictive one (INVALID), and a SCEV
/// of type X may only have sub-expressions whose type is less than or equal
/// to X. Two classifications are combined by taking their maximum.
enum TYPE {
  /// An integer value.
  INT = 0,

  /// An expression that is constant during the execution of the SCoP but
  /// may depend on parameters unknown at compile time.
  PARAM = 1,

  /// An expression that may change during the execution of the SCoP.
  IV = 2,

  /// An expression the validator cannot model.
  INVALID = 3
};
} // namespace SCEVType
36 
37 /// The result the validator returns for a SCEV expression.
38 class ValidatorResult {
39   /// The type of the expression
40   SCEVType::TYPE Type;
41 
42   /// The set of Parameters in the expression.
43   ParameterSetTy Parameters;
44 
45 public:
46   /// The copy constructor
ValidatorResult(const ValidatorResult & Source)47   ValidatorResult(const ValidatorResult &Source) {
48     Type = Source.Type;
49     Parameters = Source.Parameters;
50   }
51 
52   /// Construct a result with a certain type and no parameters.
ValidatorResult(SCEVType::TYPE Type)53   ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
54     assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
55   }
56 
57   /// Construct a result with a certain type and a single parameter.
ValidatorResult(SCEVType::TYPE Type,const SCEV * Expr)58   ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
59     Parameters.insert(Expr);
60   }
61 
62   /// Get the type of the ValidatorResult.
getType()63   SCEVType::TYPE getType() { return Type; }
64 
65   /// Is the analyzed SCEV constant during the execution of the SCoP.
isConstant()66   bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
67 
68   /// Is the analyzed SCEV valid.
isValid()69   bool isValid() { return Type != SCEVType::INVALID; }
70 
71   /// Is the analyzed SCEV of Type IV.
isIV()72   bool isIV() { return Type == SCEVType::IV; }
73 
74   /// Is the analyzed SCEV of Type INT.
isINT()75   bool isINT() { return Type == SCEVType::INT; }
76 
77   /// Is the analyzed SCEV of Type PARAM.
isPARAM()78   bool isPARAM() { return Type == SCEVType::PARAM; }
79 
80   /// Get the parameters of this validator result.
getParameters()81   const ParameterSetTy &getParameters() { return Parameters; }
82 
83   /// Add the parameters of Source to this result.
addParamsFrom(const ValidatorResult & Source)84   void addParamsFrom(const ValidatorResult &Source) {
85     Parameters.insert(Source.Parameters.begin(), Source.Parameters.end());
86   }
87 
88   /// Merge a result.
89   ///
90   /// This means to merge the parameters and to set the Type to the most
91   /// specific Type that matches both.
merge(const ValidatorResult & ToMerge)92   void merge(const ValidatorResult &ToMerge) {
93     Type = std::max(Type, ToMerge.Type);
94     addParamsFrom(ToMerge);
95   }
96 
print(raw_ostream & OS)97   void print(raw_ostream &OS) {
98     switch (Type) {
99     case SCEVType::INT:
100       OS << "SCEVType::INT";
101       break;
102     case SCEVType::PARAM:
103       OS << "SCEVType::PARAM";
104       break;
105     case SCEVType::IV:
106       OS << "SCEVType::IV";
107       break;
108     case SCEVType::INVALID:
109       OS << "SCEVType::INVALID";
110       break;
111     }
112   }
113 };
114 
operator <<(raw_ostream & OS,class ValidatorResult & VR)115 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
116   VR.print(OS);
117   return OS;
118 }
119 
isConstCall(llvm::CallInst * Call)120 bool polly::isConstCall(llvm::CallInst *Call) {
121   if (Call->mayReadOrWriteMemory())
122     return false;
123 
124   for (auto &Operand : Call->arg_operands())
125     if (!isa<ConstantInt>(&Operand))
126       return false;
127 
128   return true;
129 }
130 
131 /// Check if a SCEV is valid in a SCoP.
132 struct SCEVValidator
133     : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
134 private:
135   const Region *R;
136   Loop *Scope;
137   ScalarEvolution &SE;
138   InvariantLoadsSetTy *ILS;
139 
140 public:
SCEVValidatorSCEVValidator141   SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE,
142                 InvariantLoadsSetTy *ILS)
143       : R(R), Scope(Scope), SE(SE), ILS(ILS) {}
144 
visitConstantSCEVValidator145   class ValidatorResult visitConstant(const SCEVConstant *Constant) {
146     return ValidatorResult(SCEVType::INT);
147   }
148 
visitZeroExtendOrTruncateExprSCEVValidator149   class ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr,
150                                                       const SCEV *Operand) {
151     ValidatorResult Op = visit(Operand);
152     auto Type = Op.getType();
153 
154     // If unsigned operations are allowed return the operand, otherwise
155     // check if we can model the expression without unsigned assumptions.
156     if (PollyAllowUnsignedOperations || Type == SCEVType::INVALID)
157       return Op;
158 
159     if (Type == SCEVType::IV)
160       return ValidatorResult(SCEVType::INVALID);
161     return ValidatorResult(SCEVType::PARAM, Expr);
162   }
163 
visitPtrToIntExprSCEVValidator164   class ValidatorResult visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
165     return visit(Expr->getOperand());
166   }
167 
visitTruncateExprSCEVValidator168   class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
169     return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
170   }
171 
visitZeroExtendExprSCEVValidator172   class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
173     return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
174   }
175 
visitSignExtendExprSCEVValidator176   class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
177     return visit(Expr->getOperand());
178   }
179 
visitAddExprSCEVValidator180   class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
181     ValidatorResult Return(SCEVType::INT);
182 
183     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
184       ValidatorResult Op = visit(Expr->getOperand(i));
185       Return.merge(Op);
186 
187       // Early exit.
188       if (!Return.isValid())
189         break;
190     }
191 
192     return Return;
193   }
194 
visitMulExprSCEVValidator195   class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
196     ValidatorResult Return(SCEVType::INT);
197 
198     bool HasMultipleParams = false;
199 
200     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
201       ValidatorResult Op = visit(Expr->getOperand(i));
202 
203       if (Op.isINT())
204         continue;
205 
206       if (Op.isPARAM() && Return.isPARAM()) {
207         HasMultipleParams = true;
208         continue;
209       }
210 
211       if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
212         LLVM_DEBUG(
213             dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
214                    << "\tExpr: " << *Expr << "\n"
215                    << "\tPrevious expression type: " << Return << "\n"
216                    << "\tNext operand (" << Op << "): " << *Expr->getOperand(i)
217                    << "\n");
218 
219         return ValidatorResult(SCEVType::INVALID);
220       }
221 
222       Return.merge(Op);
223     }
224 
225     if (HasMultipleParams && Return.isValid())
226       return ValidatorResult(SCEVType::PARAM, Expr);
227 
228     return Return;
229   }
230 
visitAddRecExprSCEVValidator231   class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
232     if (!Expr->isAffine()) {
233       LLVM_DEBUG(dbgs() << "INVALID: AddRec is not affine");
234       return ValidatorResult(SCEVType::INVALID);
235     }
236 
237     ValidatorResult Start = visit(Expr->getStart());
238     ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
239 
240     if (!Start.isValid())
241       return Start;
242 
243     if (!Recurrence.isValid())
244       return Recurrence;
245 
246     auto *L = Expr->getLoop();
247     if (R->contains(L) && (!Scope || !L->contains(Scope))) {
248       LLVM_DEBUG(
249           dbgs() << "INVALID: Loop of AddRec expression boxed in an a "
250                     "non-affine subregion or has a non-synthesizable exit "
251                     "value.");
252       return ValidatorResult(SCEVType::INVALID);
253     }
254 
255     if (R->contains(L)) {
256       if (Recurrence.isINT()) {
257         ValidatorResult Result(SCEVType::IV);
258         Result.addParamsFrom(Start);
259         return Result;
260       }
261 
262       LLVM_DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
263                            "recurrence part");
264       return ValidatorResult(SCEVType::INVALID);
265     }
266 
267     assert(Recurrence.isConstant() && "Expected 'Recurrence' to be constant");
268 
269     // Directly generate ValidatorResult for Expr if 'start' is zero.
270     if (Expr->getStart()->isZero())
271       return ValidatorResult(SCEVType::PARAM, Expr);
272 
273     // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
274     // if 'start' is not zero.
275     const SCEV *ZeroStartExpr = SE.getAddRecExpr(
276         SE.getConstant(Expr->getStart()->getType(), 0),
277         Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
278 
279     ValidatorResult ZeroStartResult =
280         ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
281     ZeroStartResult.addParamsFrom(Start);
282 
283     return ZeroStartResult;
284   }
285 
visitSMaxExprSCEVValidator286   class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
287     ValidatorResult Return(SCEVType::INT);
288 
289     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
290       ValidatorResult Op = visit(Expr->getOperand(i));
291 
292       if (!Op.isValid())
293         return Op;
294 
295       Return.merge(Op);
296     }
297 
298     return Return;
299   }
300 
visitSMinExprSCEVValidator301   class ValidatorResult visitSMinExpr(const SCEVSMinExpr *Expr) {
302     ValidatorResult Return(SCEVType::INT);
303 
304     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
305       ValidatorResult Op = visit(Expr->getOperand(i));
306 
307       if (!Op.isValid())
308         return Op;
309 
310       Return.merge(Op);
311     }
312 
313     return Return;
314   }
315 
visitUMaxExprSCEVValidator316   class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
317     // We do not support unsigned max operations. If 'Expr' is constant during
318     // Scop execution we treat this as a parameter, otherwise we bail out.
319     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
320       ValidatorResult Op = visit(Expr->getOperand(i));
321 
322       if (!Op.isConstant()) {
323         LLVM_DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
324         return ValidatorResult(SCEVType::INVALID);
325       }
326     }
327 
328     return ValidatorResult(SCEVType::PARAM, Expr);
329   }
330 
visitUMinExprSCEVValidator331   class ValidatorResult visitUMinExpr(const SCEVUMinExpr *Expr) {
332     // We do not support unsigned min operations. If 'Expr' is constant during
333     // Scop execution we treat this as a parameter, otherwise we bail out.
334     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
335       ValidatorResult Op = visit(Expr->getOperand(i));
336 
337       if (!Op.isConstant()) {
338         LLVM_DEBUG(dbgs() << "INVALID: UMinExpr has a non-constant operand");
339         return ValidatorResult(SCEVType::INVALID);
340       }
341     }
342 
343     return ValidatorResult(SCEVType::PARAM, Expr);
344   }
345 
visitGenericInstSCEVValidator346   ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
347     if (R->contains(I)) {
348       LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
349                            "within the region\n");
350       return ValidatorResult(SCEVType::INVALID);
351     }
352 
353     return ValidatorResult(SCEVType::PARAM, S);
354   }
355 
visitCallInstructionSCEVValidator356   ValidatorResult visitCallInstruction(Instruction *I, const SCEV *S) {
357     assert(I->getOpcode() == Instruction::Call && "Call instruction expected");
358 
359     if (R->contains(I)) {
360       auto Call = cast<CallInst>(I);
361 
362       if (!isConstCall(Call))
363         return ValidatorResult(SCEVType::INVALID, S);
364     }
365     return ValidatorResult(SCEVType::PARAM, S);
366   }
367 
visitLoadInstructionSCEVValidator368   ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) {
369     if (R->contains(I) && ILS) {
370       ILS->insert(cast<LoadInst>(I));
371       return ValidatorResult(SCEVType::PARAM, S);
372     }
373 
374     return visitGenericInst(I, S);
375   }
376 
visitDivisionSCEVValidator377   ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor,
378                                 const SCEV *DivExpr,
379                                 Instruction *SDiv = nullptr) {
380 
381     // First check if we might be able to model the division, thus if the
382     // divisor is constant. If so, check the dividend, otherwise check if
383     // the whole division can be seen as a parameter.
384     if (isa<SCEVConstant>(Divisor) && !Divisor->isZero())
385       return visit(Dividend);
386 
387     // For signed divisions use the SDiv instruction to check for a parameter
388     // division, for unsigned divisions check the operands.
389     if (SDiv)
390       return visitGenericInst(SDiv, DivExpr);
391 
392     ValidatorResult LHS = visit(Dividend);
393     ValidatorResult RHS = visit(Divisor);
394     if (LHS.isConstant() && RHS.isConstant())
395       return ValidatorResult(SCEVType::PARAM, DivExpr);
396 
397     LLVM_DEBUG(
398         dbgs() << "INVALID: unsigned division of non-constant expressions");
399     return ValidatorResult(SCEVType::INVALID);
400   }
401 
visitUDivExprSCEVValidator402   ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
403     if (!PollyAllowUnsignedOperations)
404       return ValidatorResult(SCEVType::INVALID);
405 
406     auto *Dividend = Expr->getLHS();
407     auto *Divisor = Expr->getRHS();
408     return visitDivision(Dividend, Divisor, Expr);
409   }
410 
visitSDivInstructionSCEVValidator411   ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) {
412     assert(SDiv->getOpcode() == Instruction::SDiv &&
413            "Assumed SDiv instruction!");
414 
415     auto *Dividend = SE.getSCEV(SDiv->getOperand(0));
416     auto *Divisor = SE.getSCEV(SDiv->getOperand(1));
417     return visitDivision(Dividend, Divisor, Expr, SDiv);
418   }
419 
visitSRemInstructionSCEVValidator420   ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) {
421     assert(SRem->getOpcode() == Instruction::SRem &&
422            "Assumed SRem instruction!");
423 
424     auto *Divisor = SRem->getOperand(1);
425     auto *CI = dyn_cast<ConstantInt>(Divisor);
426     if (!CI || CI->isZeroValue())
427       return visitGenericInst(SRem, S);
428 
429     auto *Dividend = SRem->getOperand(0);
430     auto *DividendSCEV = SE.getSCEV(Dividend);
431     return visit(DividendSCEV);
432   }
433 
visitUnknownSCEVValidator434   ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
435     Value *V = Expr->getValue();
436 
437     if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) {
438       LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer");
439       return ValidatorResult(SCEVType::INVALID);
440     }
441 
442     if (isa<UndefValue>(V)) {
443       LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
444       return ValidatorResult(SCEVType::INVALID);
445     }
446 
447     if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
448       switch (I->getOpcode()) {
449       case Instruction::IntToPtr:
450         return visit(SE.getSCEVAtScope(I->getOperand(0), Scope));
451       case Instruction::Load:
452         return visitLoadInstruction(I, Expr);
453       case Instruction::SDiv:
454         return visitSDivInstruction(I, Expr);
455       case Instruction::SRem:
456         return visitSRemInstruction(I, Expr);
457       case Instruction::Call:
458         return visitCallInstruction(I, Expr);
459       default:
460         return visitGenericInst(I, Expr);
461       }
462     }
463 
464     if (Expr->getType()->isPointerTy()) {
465       if (isa<ConstantPointerNull>(V))
466         return ValidatorResult(SCEVType::INT); // "int"
467     }
468 
469     return ValidatorResult(SCEVType::PARAM, Expr);
470   }
471 };
472 
473 class SCEVHasIVParams {
474   bool HasIVParams = false;
475 
476 public:
SCEVHasIVParams()477   SCEVHasIVParams() {}
478 
follow(const SCEV * S)479   bool follow(const SCEV *S) {
480     const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
481     if (!Unknown)
482       return true;
483 
484     CallInst *Call = dyn_cast<CallInst>(Unknown->getValue());
485 
486     if (!Call)
487       return true;
488 
489     if (isConstCall(Call)) {
490       HasIVParams = true;
491       return false;
492     }
493 
494     return true;
495   }
496 
isDone()497   bool isDone() { return HasIVParams; }
hasIVParams()498   bool hasIVParams() { return HasIVParams; }
499 };
500 
501 /// Check whether a SCEV refers to an SSA name defined inside a region.
502 class SCEVInRegionDependences {
503   const Region *R;
504   Loop *Scope;
505   const InvariantLoadsSetTy &ILS;
506   bool AllowLoops;
507   bool HasInRegionDeps = false;
508 
509 public:
SCEVInRegionDependences(const Region * R,Loop * Scope,bool AllowLoops,const InvariantLoadsSetTy & ILS)510   SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops,
511                           const InvariantLoadsSetTy &ILS)
512       : R(R), Scope(Scope), ILS(ILS), AllowLoops(AllowLoops) {}
513 
follow(const SCEV * S)514   bool follow(const SCEV *S) {
515     if (auto Unknown = dyn_cast<SCEVUnknown>(S)) {
516       Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
517 
518       CallInst *Call = dyn_cast<CallInst>(Unknown->getValue());
519 
520       if (Call && isConstCall(Call))
521         return false;
522 
523       if (Inst) {
524         // When we invariant load hoist a load, we first make sure that there
525         // can be no dependences created by it in the Scop region. So, we should
526         // not consider scalar dependences to `LoadInst`s that are invariant
527         // load hoisted.
528         //
529         // If this check is not present, then we create data dependences which
530         // are strictly not necessary by tracking the invariant load as a
531         // scalar.
532         LoadInst *LI = dyn_cast<LoadInst>(Inst);
533         if (LI && ILS.count(LI) > 0)
534           return false;
535       }
536 
537       // Return true when Inst is defined inside the region R.
538       if (!Inst || !R->contains(Inst))
539         return true;
540 
541       HasInRegionDeps = true;
542       return false;
543     }
544 
545     if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
546       if (AllowLoops)
547         return true;
548 
549       auto *L = AddRec->getLoop();
550       if (R->contains(L) && !L->contains(Scope)) {
551         HasInRegionDeps = true;
552         return false;
553       }
554     }
555 
556     return true;
557   }
isDone()558   bool isDone() { return false; }
hasDependences()559   bool hasDependences() { return HasInRegionDeps; }
560 };
561 
562 namespace polly {
563 /// Find all loops referenced in SCEVAddRecExprs.
564 class SCEVFindLoops {
565   SetVector<const Loop *> &Loops;
566 
567 public:
SCEVFindLoops(SetVector<const Loop * > & Loops)568   SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
569 
follow(const SCEV * S)570   bool follow(const SCEV *S) {
571     if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
572       Loops.insert(AddRec->getLoop());
573     return true;
574   }
isDone()575   bool isDone() { return false; }
576 };
577 
findLoops(const SCEV * Expr,SetVector<const Loop * > & Loops)578 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
579   SCEVFindLoops FindLoops(Loops);
580   SCEVTraversal<SCEVFindLoops> ST(FindLoops);
581   ST.visitAll(Expr);
582 }
583 
584 /// Find all values referenced in SCEVUnknowns.
585 class SCEVFindValues {
586   ScalarEvolution &SE;
587   SetVector<Value *> &Values;
588 
589 public:
SCEVFindValues(ScalarEvolution & SE,SetVector<Value * > & Values)590   SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values)
591       : SE(SE), Values(Values) {}
592 
follow(const SCEV * S)593   bool follow(const SCEV *S) {
594     const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
595     if (!Unknown)
596       return true;
597 
598     Values.insert(Unknown->getValue());
599     Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
600     if (!Inst || (Inst->getOpcode() != Instruction::SRem &&
601                   Inst->getOpcode() != Instruction::SDiv))
602       return false;
603 
604     auto *Dividend = SE.getSCEV(Inst->getOperand(1));
605     if (!isa<SCEVConstant>(Dividend))
606       return false;
607 
608     auto *Divisor = SE.getSCEV(Inst->getOperand(0));
609     SCEVFindValues FindValues(SE, Values);
610     SCEVTraversal<SCEVFindValues> ST(FindValues);
611     ST.visitAll(Dividend);
612     ST.visitAll(Divisor);
613 
614     return false;
615   }
isDone()616   bool isDone() { return false; }
617 };
618 
findValues(const SCEV * Expr,ScalarEvolution & SE,SetVector<Value * > & Values)619 void findValues(const SCEV *Expr, ScalarEvolution &SE,
620                 SetVector<Value *> &Values) {
621   SCEVFindValues FindValues(SE, Values);
622   SCEVTraversal<SCEVFindValues> ST(FindValues);
623   ST.visitAll(Expr);
624 }
625 
hasIVParams(const SCEV * Expr)626 bool hasIVParams(const SCEV *Expr) {
627   SCEVHasIVParams HasIVParams;
628   SCEVTraversal<SCEVHasIVParams> ST(HasIVParams);
629   ST.visitAll(Expr);
630   return HasIVParams.hasIVParams();
631 }
632 
hasScalarDepsInsideRegion(const SCEV * Expr,const Region * R,llvm::Loop * Scope,bool AllowLoops,const InvariantLoadsSetTy & ILS)633 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R,
634                                llvm::Loop *Scope, bool AllowLoops,
635                                const InvariantLoadsSetTy &ILS) {
636   SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops, ILS);
637   SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps);
638   ST.visitAll(Expr);
639   return InRegionDeps.hasDependences();
640 }
641 
isAffineExpr(const Region * R,llvm::Loop * Scope,const SCEV * Expr,ScalarEvolution & SE,InvariantLoadsSetTy * ILS)642 bool isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr,
643                   ScalarEvolution &SE, InvariantLoadsSetTy *ILS) {
644   if (isa<SCEVCouldNotCompute>(Expr))
645     return false;
646 
647   SCEVValidator Validator(R, Scope, SE, ILS);
648   LLVM_DEBUG({
649     dbgs() << "\n";
650     dbgs() << "Expr: " << *Expr << "\n";
651     dbgs() << "Region: " << R->getNameStr() << "\n";
652     dbgs() << " -> ";
653   });
654 
655   ValidatorResult Result = Validator.visit(Expr);
656 
657   LLVM_DEBUG({
658     if (Result.isValid())
659       dbgs() << "VALID\n";
660     dbgs() << "\n";
661   });
662 
663   return Result.isValid();
664 }
665 
isAffineExpr(Value * V,const Region * R,Loop * Scope,ScalarEvolution & SE,ParameterSetTy & Params)666 static bool isAffineExpr(Value *V, const Region *R, Loop *Scope,
667                          ScalarEvolution &SE, ParameterSetTy &Params) {
668   auto *E = SE.getSCEV(V);
669   if (isa<SCEVCouldNotCompute>(E))
670     return false;
671 
672   SCEVValidator Validator(R, Scope, SE, nullptr);
673   ValidatorResult Result = Validator.visit(E);
674   if (!Result.isValid())
675     return false;
676 
677   auto ResultParams = Result.getParameters();
678   Params.insert(ResultParams.begin(), ResultParams.end());
679 
680   return true;
681 }
682 
isAffineConstraint(Value * V,const Region * R,llvm::Loop * Scope,ScalarEvolution & SE,ParameterSetTy & Params,bool OrExpr)683 bool isAffineConstraint(Value *V, const Region *R, llvm::Loop *Scope,
684                         ScalarEvolution &SE, ParameterSetTy &Params,
685                         bool OrExpr) {
686   if (auto *ICmp = dyn_cast<ICmpInst>(V)) {
687     return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params,
688                               true) &&
689            isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true);
690   } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) {
691     auto Opcode = BinOp->getOpcode();
692     if (Opcode == Instruction::And || Opcode == Instruction::Or)
693       return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params,
694                                 false) &&
695              isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params,
696                                 false);
697     /* Fall through */
698   }
699 
700   if (!OrExpr)
701     return false;
702 
703   return isAffineExpr(V, R, Scope, SE, Params);
704 }
705 
getParamsInAffineExpr(const Region * R,Loop * Scope,const SCEV * Expr,ScalarEvolution & SE)706 ParameterSetTy getParamsInAffineExpr(const Region *R, Loop *Scope,
707                                      const SCEV *Expr, ScalarEvolution &SE) {
708   if (isa<SCEVCouldNotCompute>(Expr))
709     return ParameterSetTy();
710 
711   InvariantLoadsSetTy ILS;
712   SCEVValidator Validator(R, Scope, SE, &ILS);
713   ValidatorResult Result = Validator.visit(Expr);
714   assert(Result.isValid() && "Requested parameters for an invalid SCEV!");
715 
716   return Result.getParameters();
717 }
718 
719 std::pair<const SCEVConstant *, const SCEV *>
extractConstantFactor(const SCEV * S,ScalarEvolution & SE)720 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
721   auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1));
722 
723   if (auto *Constant = dyn_cast<SCEVConstant>(S))
724     return std::make_pair(Constant, SE.getConstant(S->getType(), 1));
725 
726   auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
727   if (AddRec) {
728     auto *StartExpr = AddRec->getStart();
729     if (StartExpr->isZero()) {
730       auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE);
731       auto *LeftOverAddRec =
732           SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(),
733                            AddRec->getNoWrapFlags());
734       return std::make_pair(StepPair.first, LeftOverAddRec);
735     }
736     return std::make_pair(ConstPart, S);
737   }
738 
739   if (auto *Add = dyn_cast<SCEVAddExpr>(S)) {
740     SmallVector<const SCEV *, 4> LeftOvers;
741     auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE);
742     auto *Factor = Op0Pair.first;
743     if (SE.isKnownNegative(Factor)) {
744       Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor));
745       LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second));
746     } else {
747       LeftOvers.push_back(Op0Pair.second);
748     }
749 
750     for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) {
751       auto OpUPair = extractConstantFactor(Add->getOperand(u), SE);
752       // TODO: Use something smarter than equality here, e.g., gcd.
753       if (Factor == OpUPair.first)
754         LeftOvers.push_back(OpUPair.second);
755       else if (Factor == SE.getNegativeSCEV(OpUPair.first))
756         LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second));
757       else
758         return std::make_pair(ConstPart, S);
759     }
760 
761     auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags());
762     return std::make_pair(Factor, NewAdd);
763   }
764 
765   auto *Mul = dyn_cast<SCEVMulExpr>(S);
766   if (!Mul)
767     return std::make_pair(ConstPart, S);
768 
769   SmallVector<const SCEV *, 4> LeftOvers;
770   for (auto *Op : Mul->operands())
771     if (isa<SCEVConstant>(Op))
772       ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op));
773     else
774       LeftOvers.push_back(Op);
775 
776   return std::make_pair(ConstPart, SE.getMulExpr(LeftOvers));
777 }
778 
tryForwardThroughPHI(const SCEV * Expr,Region & R,ScalarEvolution & SE,LoopInfo & LI,const DominatorTree & DT)779 const SCEV *tryForwardThroughPHI(const SCEV *Expr, Region &R,
780                                  ScalarEvolution &SE, LoopInfo &LI,
781                                  const DominatorTree &DT) {
782   if (auto *Unknown = dyn_cast<SCEVUnknown>(Expr)) {
783     Value *V = Unknown->getValue();
784     auto *PHI = dyn_cast<PHINode>(V);
785     if (!PHI)
786       return Expr;
787 
788     Value *Final = nullptr;
789 
790     for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) {
791       BasicBlock *Incoming = PHI->getIncomingBlock(i);
792       if (isErrorBlock(*Incoming, R, LI, DT) && R.contains(Incoming))
793         continue;
794       if (Final)
795         return Expr;
796       Final = PHI->getIncomingValue(i);
797     }
798 
799     if (Final)
800       return SE.getSCEV(Final);
801   }
802   return Expr;
803 }
804 
getUniqueNonErrorValue(PHINode * PHI,Region * R,LoopInfo & LI,const DominatorTree & DT)805 Value *getUniqueNonErrorValue(PHINode *PHI, Region *R, LoopInfo &LI,
806                               const DominatorTree &DT) {
807   Value *V = nullptr;
808   for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) {
809     BasicBlock *BB = PHI->getIncomingBlock(i);
810     if (!isErrorBlock(*BB, *R, LI, DT)) {
811       if (V)
812         return nullptr;
813       V = PHI->getIncomingValue(i);
814     }
815   }
816 
817   return V;
818 }
819 } // namespace polly
820