1 
2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopDetection.h"
4 #include "llvm/Analysis/RegionInfo.h"
5 #include "llvm/Analysis/ScalarEvolution.h"
6 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
7 #include "llvm/Support/Debug.h"
8 
9 using namespace llvm;
10 using namespace polly;
11 
12 #define DEBUG_TYPE "polly-scev-validator"
13 
namespace SCEVType {
/// Classification assigned to every SCEV the validator visits.
///
/// The enumerators are ordered from most to least restrictive:
/// INT < PARAM < IV < INVALID. Combining two results keeps the larger of the
/// two types, so a SCEV is always classified at least as high as any of its
/// subexpressions. Do not reorder the enumerators.
enum TYPE {
  /// A plain integer value.
  INT,

  /// Constant while the SCoP executes, but possibly dependent on parameters
  /// that are unknown at compile time.
  PARAM,

  /// May change while the SCoP executes (e.g. an induction variable).
  IV,

  /// The expression cannot be modeled.
  INVALID
};
} // namespace SCEVType
36 
37 /// The result the validator returns for a SCEV expression.
38 class ValidatorResult {
39   /// The type of the expression
40   SCEVType::TYPE Type;
41 
42   /// The set of Parameters in the expression.
43   ParameterSetTy Parameters;
44 
45 public:
46   /// The copy constructor
ValidatorResult(const ValidatorResult & Source)47   ValidatorResult(const ValidatorResult &Source) {
48     Type = Source.Type;
49     Parameters = Source.Parameters;
50   }
51 
52   /// Construct a result with a certain type and no parameters.
ValidatorResult(SCEVType::TYPE Type)53   ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
54     assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
55   }
56 
57   /// Construct a result with a certain type and a single parameter.
ValidatorResult(SCEVType::TYPE Type,const SCEV * Expr)58   ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
59     Parameters.insert(Expr);
60   }
61 
62   /// Get the type of the ValidatorResult.
getType()63   SCEVType::TYPE getType() { return Type; }
64 
65   /// Is the analyzed SCEV constant during the execution of the SCoP.
isConstant()66   bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
67 
68   /// Is the analyzed SCEV valid.
isValid()69   bool isValid() { return Type != SCEVType::INVALID; }
70 
71   /// Is the analyzed SCEV of Type IV.
isIV()72   bool isIV() { return Type == SCEVType::IV; }
73 
74   /// Is the analyzed SCEV of Type INT.
isINT()75   bool isINT() { return Type == SCEVType::INT; }
76 
77   /// Is the analyzed SCEV of Type PARAM.
isPARAM()78   bool isPARAM() { return Type == SCEVType::PARAM; }
79 
80   /// Get the parameters of this validator result.
getParameters()81   const ParameterSetTy &getParameters() { return Parameters; }
82 
83   /// Add the parameters of Source to this result.
addParamsFrom(const ValidatorResult & Source)84   void addParamsFrom(const ValidatorResult &Source) {
85     Parameters.insert(Source.Parameters.begin(), Source.Parameters.end());
86   }
87 
88   /// Merge a result.
89   ///
90   /// This means to merge the parameters and to set the Type to the most
91   /// specific Type that matches both.
merge(const ValidatorResult & ToMerge)92   void merge(const ValidatorResult &ToMerge) {
93     Type = std::max(Type, ToMerge.Type);
94     addParamsFrom(ToMerge);
95   }
96 
print(raw_ostream & OS)97   void print(raw_ostream &OS) {
98     switch (Type) {
99     case SCEVType::INT:
100       OS << "SCEVType::INT";
101       break;
102     case SCEVType::PARAM:
103       OS << "SCEVType::PARAM";
104       break;
105     case SCEVType::IV:
106       OS << "SCEVType::IV";
107       break;
108     case SCEVType::INVALID:
109       OS << "SCEVType::INVALID";
110       break;
111     }
112   }
113 };
114 
operator <<(raw_ostream & OS,class ValidatorResult & VR)115 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
116   VR.print(OS);
117   return OS;
118 }
119 
isConstCall(llvm::CallInst * Call)120 bool polly::isConstCall(llvm::CallInst *Call) {
121   if (Call->mayReadOrWriteMemory())
122     return false;
123 
124   for (auto &Operand : Call->arg_operands())
125     if (!isa<ConstantInt>(&Operand))
126       return false;
127 
128   return true;
129 }
130 
131 /// Check if a SCEV is valid in a SCoP.
132 struct SCEVValidator
133     : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
134 private:
135   const Region *R;
136   Loop *Scope;
137   ScalarEvolution &SE;
138   InvariantLoadsSetTy *ILS;
139 
140 public:
SCEVValidatorSCEVValidator141   SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE,
142                 InvariantLoadsSetTy *ILS)
143       : R(R), Scope(Scope), SE(SE), ILS(ILS) {}
144 
visitConstantSCEVValidator145   class ValidatorResult visitConstant(const SCEVConstant *Constant) {
146     return ValidatorResult(SCEVType::INT);
147   }
148 
visitZeroExtendOrTruncateExprSCEVValidator149   class ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr,
150                                                       const SCEV *Operand) {
151     ValidatorResult Op = visit(Operand);
152     auto Type = Op.getType();
153 
154     // If unsigned operations are allowed return the operand, otherwise
155     // check if we can model the expression without unsigned assumptions.
156     if (PollyAllowUnsignedOperations || Type == SCEVType::INVALID)
157       return Op;
158 
159     if (Type == SCEVType::IV)
160       return ValidatorResult(SCEVType::INVALID);
161     return ValidatorResult(SCEVType::PARAM, Expr);
162   }
163 
visitTruncateExprSCEVValidator164   class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
165     return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
166   }
167 
visitZeroExtendExprSCEVValidator168   class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
169     return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
170   }
171 
visitSignExtendExprSCEVValidator172   class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
173     return visit(Expr->getOperand());
174   }
175 
visitAddExprSCEVValidator176   class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
177     ValidatorResult Return(SCEVType::INT);
178 
179     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
180       ValidatorResult Op = visit(Expr->getOperand(i));
181       Return.merge(Op);
182 
183       // Early exit.
184       if (!Return.isValid())
185         break;
186     }
187 
188     return Return;
189   }
190 
visitMulExprSCEVValidator191   class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
192     ValidatorResult Return(SCEVType::INT);
193 
194     bool HasMultipleParams = false;
195 
196     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
197       ValidatorResult Op = visit(Expr->getOperand(i));
198 
199       if (Op.isINT())
200         continue;
201 
202       if (Op.isPARAM() && Return.isPARAM()) {
203         HasMultipleParams = true;
204         continue;
205       }
206 
207       if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
208         LLVM_DEBUG(
209             dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
210                    << "\tExpr: " << *Expr << "\n"
211                    << "\tPrevious expression type: " << Return << "\n"
212                    << "\tNext operand (" << Op << "): " << *Expr->getOperand(i)
213                    << "\n");
214 
215         return ValidatorResult(SCEVType::INVALID);
216       }
217 
218       Return.merge(Op);
219     }
220 
221     if (HasMultipleParams && Return.isValid())
222       return ValidatorResult(SCEVType::PARAM, Expr);
223 
224     return Return;
225   }
226 
visitAddRecExprSCEVValidator227   class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
228     if (!Expr->isAffine()) {
229       LLVM_DEBUG(dbgs() << "INVALID: AddRec is not affine");
230       return ValidatorResult(SCEVType::INVALID);
231     }
232 
233     ValidatorResult Start = visit(Expr->getStart());
234     ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
235 
236     if (!Start.isValid())
237       return Start;
238 
239     if (!Recurrence.isValid())
240       return Recurrence;
241 
242     auto *L = Expr->getLoop();
243     if (R->contains(L) && (!Scope || !L->contains(Scope))) {
244       LLVM_DEBUG(
245           dbgs() << "INVALID: Loop of AddRec expression boxed in an a "
246                     "non-affine subregion or has a non-synthesizable exit "
247                     "value.");
248       return ValidatorResult(SCEVType::INVALID);
249     }
250 
251     if (R->contains(L)) {
252       if (Recurrence.isINT()) {
253         ValidatorResult Result(SCEVType::IV);
254         Result.addParamsFrom(Start);
255         return Result;
256       }
257 
258       LLVM_DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
259                            "recurrence part");
260       return ValidatorResult(SCEVType::INVALID);
261     }
262 
263     assert(Recurrence.isConstant() && "Expected 'Recurrence' to be constant");
264 
265     // Directly generate ValidatorResult for Expr if 'start' is zero.
266     if (Expr->getStart()->isZero())
267       return ValidatorResult(SCEVType::PARAM, Expr);
268 
269     // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
270     // if 'start' is not zero.
271     const SCEV *ZeroStartExpr = SE.getAddRecExpr(
272         SE.getConstant(Expr->getStart()->getType(), 0),
273         Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
274 
275     ValidatorResult ZeroStartResult =
276         ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
277     ZeroStartResult.addParamsFrom(Start);
278 
279     return ZeroStartResult;
280   }
281 
visitSMaxExprSCEVValidator282   class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
283     ValidatorResult Return(SCEVType::INT);
284 
285     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
286       ValidatorResult Op = visit(Expr->getOperand(i));
287 
288       if (!Op.isValid())
289         return Op;
290 
291       Return.merge(Op);
292     }
293 
294     return Return;
295   }
296 
visitSMinExprSCEVValidator297   class ValidatorResult visitSMinExpr(const SCEVSMinExpr *Expr) {
298     ValidatorResult Return(SCEVType::INT);
299 
300     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
301       ValidatorResult Op = visit(Expr->getOperand(i));
302 
303       if (!Op.isValid())
304         return Op;
305 
306       Return.merge(Op);
307     }
308 
309     return Return;
310   }
311 
visitUMaxExprSCEVValidator312   class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
313     // We do not support unsigned max operations. If 'Expr' is constant during
314     // Scop execution we treat this as a parameter, otherwise we bail out.
315     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
316       ValidatorResult Op = visit(Expr->getOperand(i));
317 
318       if (!Op.isConstant()) {
319         LLVM_DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
320         return ValidatorResult(SCEVType::INVALID);
321       }
322     }
323 
324     return ValidatorResult(SCEVType::PARAM, Expr);
325   }
326 
visitUMinExprSCEVValidator327   class ValidatorResult visitUMinExpr(const SCEVUMinExpr *Expr) {
328     // We do not support unsigned min operations. If 'Expr' is constant during
329     // Scop execution we treat this as a parameter, otherwise we bail out.
330     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
331       ValidatorResult Op = visit(Expr->getOperand(i));
332 
333       if (!Op.isConstant()) {
334         LLVM_DEBUG(dbgs() << "INVALID: UMinExpr has a non-constant operand");
335         return ValidatorResult(SCEVType::INVALID);
336       }
337     }
338 
339     return ValidatorResult(SCEVType::PARAM, Expr);
340   }
341 
visitGenericInstSCEVValidator342   ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
343     if (R->contains(I)) {
344       LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
345                            "within the region\n");
346       return ValidatorResult(SCEVType::INVALID);
347     }
348 
349     return ValidatorResult(SCEVType::PARAM, S);
350   }
351 
visitCallInstructionSCEVValidator352   ValidatorResult visitCallInstruction(Instruction *I, const SCEV *S) {
353     assert(I->getOpcode() == Instruction::Call && "Call instruction expected");
354 
355     if (R->contains(I)) {
356       auto Call = cast<CallInst>(I);
357 
358       if (!isConstCall(Call))
359         return ValidatorResult(SCEVType::INVALID, S);
360     }
361     return ValidatorResult(SCEVType::PARAM, S);
362   }
363 
visitLoadInstructionSCEVValidator364   ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) {
365     if (R->contains(I) && ILS) {
366       ILS->insert(cast<LoadInst>(I));
367       return ValidatorResult(SCEVType::PARAM, S);
368     }
369 
370     return visitGenericInst(I, S);
371   }
372 
visitDivisionSCEVValidator373   ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor,
374                                 const SCEV *DivExpr,
375                                 Instruction *SDiv = nullptr) {
376 
377     // First check if we might be able to model the division, thus if the
378     // divisor is constant. If so, check the dividend, otherwise check if
379     // the whole division can be seen as a parameter.
380     if (isa<SCEVConstant>(Divisor) && !Divisor->isZero())
381       return visit(Dividend);
382 
383     // For signed divisions use the SDiv instruction to check for a parameter
384     // division, for unsigned divisions check the operands.
385     if (SDiv)
386       return visitGenericInst(SDiv, DivExpr);
387 
388     ValidatorResult LHS = visit(Dividend);
389     ValidatorResult RHS = visit(Divisor);
390     if (LHS.isConstant() && RHS.isConstant())
391       return ValidatorResult(SCEVType::PARAM, DivExpr);
392 
393     LLVM_DEBUG(
394         dbgs() << "INVALID: unsigned division of non-constant expressions");
395     return ValidatorResult(SCEVType::INVALID);
396   }
397 
visitUDivExprSCEVValidator398   ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
399     if (!PollyAllowUnsignedOperations)
400       return ValidatorResult(SCEVType::INVALID);
401 
402     auto *Dividend = Expr->getLHS();
403     auto *Divisor = Expr->getRHS();
404     return visitDivision(Dividend, Divisor, Expr);
405   }
406 
visitSDivInstructionSCEVValidator407   ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) {
408     assert(SDiv->getOpcode() == Instruction::SDiv &&
409            "Assumed SDiv instruction!");
410 
411     auto *Dividend = SE.getSCEV(SDiv->getOperand(0));
412     auto *Divisor = SE.getSCEV(SDiv->getOperand(1));
413     return visitDivision(Dividend, Divisor, Expr, SDiv);
414   }
415 
visitSRemInstructionSCEVValidator416   ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) {
417     assert(SRem->getOpcode() == Instruction::SRem &&
418            "Assumed SRem instruction!");
419 
420     auto *Divisor = SRem->getOperand(1);
421     auto *CI = dyn_cast<ConstantInt>(Divisor);
422     if (!CI || CI->isZeroValue())
423       return visitGenericInst(SRem, S);
424 
425     auto *Dividend = SRem->getOperand(0);
426     auto *DividendSCEV = SE.getSCEV(Dividend);
427     return visit(DividendSCEV);
428   }
429 
visitUnknownSCEVValidator430   ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
431     Value *V = Expr->getValue();
432 
433     if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) {
434       LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer");
435       return ValidatorResult(SCEVType::INVALID);
436     }
437 
438     if (isa<UndefValue>(V)) {
439       LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
440       return ValidatorResult(SCEVType::INVALID);
441     }
442 
443     if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
444       switch (I->getOpcode()) {
445       case Instruction::IntToPtr:
446         return visit(SE.getSCEVAtScope(I->getOperand(0), Scope));
447       case Instruction::PtrToInt:
448         return visit(SE.getSCEVAtScope(I->getOperand(0), Scope));
449       case Instruction::Load:
450         return visitLoadInstruction(I, Expr);
451       case Instruction::SDiv:
452         return visitSDivInstruction(I, Expr);
453       case Instruction::SRem:
454         return visitSRemInstruction(I, Expr);
455       case Instruction::Call:
456         return visitCallInstruction(I, Expr);
457       default:
458         return visitGenericInst(I, Expr);
459       }
460     }
461 
462     return ValidatorResult(SCEVType::PARAM, Expr);
463   }
464 };
465 
466 class SCEVHasIVParams {
467   bool HasIVParams = false;
468 
469 public:
SCEVHasIVParams()470   SCEVHasIVParams() {}
471 
follow(const SCEV * S)472   bool follow(const SCEV *S) {
473     const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
474     if (!Unknown)
475       return true;
476 
477     CallInst *Call = dyn_cast<CallInst>(Unknown->getValue());
478 
479     if (!Call)
480       return true;
481 
482     if (isConstCall(Call)) {
483       HasIVParams = true;
484       return false;
485     }
486 
487     return true;
488   }
489 
isDone()490   bool isDone() { return HasIVParams; }
hasIVParams()491   bool hasIVParams() { return HasIVParams; }
492 };
493 
494 /// Check whether a SCEV refers to an SSA name defined inside a region.
495 class SCEVInRegionDependences {
496   const Region *R;
497   Loop *Scope;
498   const InvariantLoadsSetTy &ILS;
499   bool AllowLoops;
500   bool HasInRegionDeps = false;
501 
502 public:
SCEVInRegionDependences(const Region * R,Loop * Scope,bool AllowLoops,const InvariantLoadsSetTy & ILS)503   SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops,
504                           const InvariantLoadsSetTy &ILS)
505       : R(R), Scope(Scope), ILS(ILS), AllowLoops(AllowLoops) {}
506 
follow(const SCEV * S)507   bool follow(const SCEV *S) {
508     if (auto Unknown = dyn_cast<SCEVUnknown>(S)) {
509       Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
510 
511       CallInst *Call = dyn_cast<CallInst>(Unknown->getValue());
512 
513       if (Call && isConstCall(Call))
514         return false;
515 
516       if (Inst) {
517         // When we invariant load hoist a load, we first make sure that there
518         // can be no dependences created by it in the Scop region. So, we should
519         // not consider scalar dependences to `LoadInst`s that are invariant
520         // load hoisted.
521         //
522         // If this check is not present, then we create data dependences which
523         // are strictly not necessary by tracking the invariant load as a
524         // scalar.
525         LoadInst *LI = dyn_cast<LoadInst>(Inst);
526         if (LI && ILS.count(LI) > 0)
527           return false;
528       }
529 
530       // Return true when Inst is defined inside the region R.
531       if (!Inst || !R->contains(Inst))
532         return true;
533 
534       HasInRegionDeps = true;
535       return false;
536     }
537 
538     if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
539       if (AllowLoops)
540         return true;
541 
542       auto *L = AddRec->getLoop();
543       if (R->contains(L) && !L->contains(Scope)) {
544         HasInRegionDeps = true;
545         return false;
546       }
547     }
548 
549     return true;
550   }
isDone()551   bool isDone() { return false; }
hasDependences()552   bool hasDependences() { return HasInRegionDeps; }
553 };
554 
555 namespace polly {
556 /// Find all loops referenced in SCEVAddRecExprs.
557 class SCEVFindLoops {
558   SetVector<const Loop *> &Loops;
559 
560 public:
SCEVFindLoops(SetVector<const Loop * > & Loops)561   SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
562 
follow(const SCEV * S)563   bool follow(const SCEV *S) {
564     if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
565       Loops.insert(AddRec->getLoop());
566     return true;
567   }
isDone()568   bool isDone() { return false; }
569 };
570 
findLoops(const SCEV * Expr,SetVector<const Loop * > & Loops)571 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
572   SCEVFindLoops FindLoops(Loops);
573   SCEVTraversal<SCEVFindLoops> ST(FindLoops);
574   ST.visitAll(Expr);
575 }
576 
577 /// Find all values referenced in SCEVUnknowns.
578 class SCEVFindValues {
579   ScalarEvolution &SE;
580   SetVector<Value *> &Values;
581 
582 public:
SCEVFindValues(ScalarEvolution & SE,SetVector<Value * > & Values)583   SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values)
584       : SE(SE), Values(Values) {}
585 
follow(const SCEV * S)586   bool follow(const SCEV *S) {
587     const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
588     if (!Unknown)
589       return true;
590 
591     Values.insert(Unknown->getValue());
592     Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
593     if (!Inst || (Inst->getOpcode() != Instruction::SRem &&
594                   Inst->getOpcode() != Instruction::SDiv))
595       return false;
596 
597     auto *Dividend = SE.getSCEV(Inst->getOperand(1));
598     if (!isa<SCEVConstant>(Dividend))
599       return false;
600 
601     auto *Divisor = SE.getSCEV(Inst->getOperand(0));
602     SCEVFindValues FindValues(SE, Values);
603     SCEVTraversal<SCEVFindValues> ST(FindValues);
604     ST.visitAll(Dividend);
605     ST.visitAll(Divisor);
606 
607     return false;
608   }
isDone()609   bool isDone() { return false; }
610 };
611 
findValues(const SCEV * Expr,ScalarEvolution & SE,SetVector<Value * > & Values)612 void findValues(const SCEV *Expr, ScalarEvolution &SE,
613                 SetVector<Value *> &Values) {
614   SCEVFindValues FindValues(SE, Values);
615   SCEVTraversal<SCEVFindValues> ST(FindValues);
616   ST.visitAll(Expr);
617 }
618 
hasIVParams(const SCEV * Expr)619 bool hasIVParams(const SCEV *Expr) {
620   SCEVHasIVParams HasIVParams;
621   SCEVTraversal<SCEVHasIVParams> ST(HasIVParams);
622   ST.visitAll(Expr);
623   return HasIVParams.hasIVParams();
624 }
625 
hasScalarDepsInsideRegion(const SCEV * Expr,const Region * R,llvm::Loop * Scope,bool AllowLoops,const InvariantLoadsSetTy & ILS)626 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R,
627                                llvm::Loop *Scope, bool AllowLoops,
628                                const InvariantLoadsSetTy &ILS) {
629   SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops, ILS);
630   SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps);
631   ST.visitAll(Expr);
632   return InRegionDeps.hasDependences();
633 }
634 
isAffineExpr(const Region * R,llvm::Loop * Scope,const SCEV * Expr,ScalarEvolution & SE,InvariantLoadsSetTy * ILS)635 bool isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr,
636                   ScalarEvolution &SE, InvariantLoadsSetTy *ILS) {
637   if (isa<SCEVCouldNotCompute>(Expr))
638     return false;
639 
640   SCEVValidator Validator(R, Scope, SE, ILS);
641   LLVM_DEBUG({
642     dbgs() << "\n";
643     dbgs() << "Expr: " << *Expr << "\n";
644     dbgs() << "Region: " << R->getNameStr() << "\n";
645     dbgs() << " -> ";
646   });
647 
648   ValidatorResult Result = Validator.visit(Expr);
649 
650   LLVM_DEBUG({
651     if (Result.isValid())
652       dbgs() << "VALID\n";
653     dbgs() << "\n";
654   });
655 
656   return Result.isValid();
657 }
658 
isAffineExpr(Value * V,const Region * R,Loop * Scope,ScalarEvolution & SE,ParameterSetTy & Params)659 static bool isAffineExpr(Value *V, const Region *R, Loop *Scope,
660                          ScalarEvolution &SE, ParameterSetTy &Params) {
661   auto *E = SE.getSCEV(V);
662   if (isa<SCEVCouldNotCompute>(E))
663     return false;
664 
665   SCEVValidator Validator(R, Scope, SE, nullptr);
666   ValidatorResult Result = Validator.visit(E);
667   if (!Result.isValid())
668     return false;
669 
670   auto ResultParams = Result.getParameters();
671   Params.insert(ResultParams.begin(), ResultParams.end());
672 
673   return true;
674 }
675 
isAffineConstraint(Value * V,const Region * R,llvm::Loop * Scope,ScalarEvolution & SE,ParameterSetTy & Params,bool OrExpr)676 bool isAffineConstraint(Value *V, const Region *R, llvm::Loop *Scope,
677                         ScalarEvolution &SE, ParameterSetTy &Params,
678                         bool OrExpr) {
679   if (auto *ICmp = dyn_cast<ICmpInst>(V)) {
680     return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params,
681                               true) &&
682            isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true);
683   } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) {
684     auto Opcode = BinOp->getOpcode();
685     if (Opcode == Instruction::And || Opcode == Instruction::Or)
686       return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params,
687                                 false) &&
688              isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params,
689                                 false);
690     /* Fall through */
691   }
692 
693   if (!OrExpr)
694     return false;
695 
696   return isAffineExpr(V, R, Scope, SE, Params);
697 }
698 
getParamsInAffineExpr(const Region * R,Loop * Scope,const SCEV * Expr,ScalarEvolution & SE)699 ParameterSetTy getParamsInAffineExpr(const Region *R, Loop *Scope,
700                                      const SCEV *Expr, ScalarEvolution &SE) {
701   if (isa<SCEVCouldNotCompute>(Expr))
702     return ParameterSetTy();
703 
704   InvariantLoadsSetTy ILS;
705   SCEVValidator Validator(R, Scope, SE, &ILS);
706   ValidatorResult Result = Validator.visit(Expr);
707   assert(Result.isValid() && "Requested parameters for an invalid SCEV!");
708 
709   return Result.getParameters();
710 }
711 
712 std::pair<const SCEVConstant *, const SCEV *>
extractConstantFactor(const SCEV * S,ScalarEvolution & SE)713 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
714   auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1));
715 
716   if (auto *Constant = dyn_cast<SCEVConstant>(S))
717     return std::make_pair(Constant, SE.getConstant(S->getType(), 1));
718 
719   auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
720   if (AddRec) {
721     auto *StartExpr = AddRec->getStart();
722     if (StartExpr->isZero()) {
723       auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE);
724       auto *LeftOverAddRec =
725           SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(),
726                            AddRec->getNoWrapFlags());
727       return std::make_pair(StepPair.first, LeftOverAddRec);
728     }
729     return std::make_pair(ConstPart, S);
730   }
731 
732   if (auto *Add = dyn_cast<SCEVAddExpr>(S)) {
733     SmallVector<const SCEV *, 4> LeftOvers;
734     auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE);
735     auto *Factor = Op0Pair.first;
736     if (SE.isKnownNegative(Factor)) {
737       Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor));
738       LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second));
739     } else {
740       LeftOvers.push_back(Op0Pair.second);
741     }
742 
743     for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) {
744       auto OpUPair = extractConstantFactor(Add->getOperand(u), SE);
745       // TODO: Use something smarter than equality here, e.g., gcd.
746       if (Factor == OpUPair.first)
747         LeftOvers.push_back(OpUPair.second);
748       else if (Factor == SE.getNegativeSCEV(OpUPair.first))
749         LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second));
750       else
751         return std::make_pair(ConstPart, S);
752     }
753 
754     auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags());
755     return std::make_pair(Factor, NewAdd);
756   }
757 
758   auto *Mul = dyn_cast<SCEVMulExpr>(S);
759   if (!Mul)
760     return std::make_pair(ConstPart, S);
761 
762   SmallVector<const SCEV *, 4> LeftOvers;
763   for (auto *Op : Mul->operands())
764     if (isa<SCEVConstant>(Op))
765       ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op));
766     else
767       LeftOvers.push_back(Op);
768 
769   return std::make_pair(ConstPart, SE.getMulExpr(LeftOvers));
770 }
771 
/// If @p Expr is a PHI node with exactly one incoming value that does not
/// come from an error block inside @p R, return the SCEV of that value;
/// otherwise return @p Expr unchanged.
const SCEV *tryForwardThroughPHI(const SCEV *Expr, Region &R,
                                 ScalarEvolution &SE, LoopInfo &LI,
                                 const DominatorTree &DT) {
  const auto *Unknown = dyn_cast<SCEVUnknown>(Expr);
  if (!Unknown)
    return Expr;

  auto *PHI = dyn_cast<PHINode>(Unknown->getValue());
  if (!PHI)
    return Expr;

  Value *Final = nullptr;
  for (unsigned Idx = 0, NumIn = PHI->getNumIncomingValues(); Idx < NumIn;
       ++Idx) {
    BasicBlock *Incoming = PHI->getIncomingBlock(Idx);
    // Values flowing in from error blocks inside R are ignored.
    if (isErrorBlock(*Incoming, R, LI, DT) && R.contains(Incoming))
      continue;
    // A second candidate makes forwarding ambiguous.
    if (Final)
      return Expr;
    Final = PHI->getIncomingValue(Idx);
  }

  return Final ? SE.getSCEV(Final) : Expr;
}
797 
getUniqueNonErrorValue(PHINode * PHI,Region * R,LoopInfo & LI,const DominatorTree & DT)798 Value *getUniqueNonErrorValue(PHINode *PHI, Region *R, LoopInfo &LI,
799                               const DominatorTree &DT) {
800   Value *V = nullptr;
801   for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) {
802     BasicBlock *BB = PHI->getIncomingBlock(i);
803     if (!isErrorBlock(*BB, *R, LI, DT)) {
804       if (V)
805         return nullptr;
806       V = PHI->getIncomingValue(i);
807     }
808   }
809 
810   return V;
811 }
812 } // namespace polly
813