1
2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopDetection.h"
4 #include "llvm/Analysis/RegionInfo.h"
5 #include "llvm/Analysis/ScalarEvolution.h"
6 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
7 #include "llvm/Support/Debug.h"
8
9 using namespace llvm;
10 using namespace polly;
11
12 #define DEBUG_TYPE "polly-scev-validator"
13
namespace SCEVType {
/// The type of a SCEV
///
/// To check for the validity of a SCEV we assign to each SCEV a type. The
/// possible types are INT, PARAM, IV and INVALID. The order of the
/// enumerators is important: the subexpressions of a SCEV with type X can
/// only have a type that is less than or equal to X. ValidatorResult::merge
/// relies on this ordering when it combines two types with std::max.
enum TYPE {
  // An integer value.
  INT,

  // An expression that is constant during the execution of the SCoP,
  // but that may depend on parameters unknown at compile time.
  PARAM,

  // An expression that may change during the execution of the SCoP.
  IV,

  // An invalid expression.
  INVALID
};
} // namespace SCEVType
36
37 /// The result the validator returns for a SCEV expression.
38 class ValidatorResult {
39 /// The type of the expression
40 SCEVType::TYPE Type;
41
42 /// The set of Parameters in the expression.
43 ParameterSetTy Parameters;
44
45 public:
46 /// The copy constructor
ValidatorResult(const ValidatorResult & Source)47 ValidatorResult(const ValidatorResult &Source) {
48 Type = Source.Type;
49 Parameters = Source.Parameters;
50 }
51
52 /// Construct a result with a certain type and no parameters.
ValidatorResult(SCEVType::TYPE Type)53 ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
54 assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
55 }
56
57 /// Construct a result with a certain type and a single parameter.
ValidatorResult(SCEVType::TYPE Type,const SCEV * Expr)58 ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
59 Parameters.insert(Expr);
60 }
61
62 /// Get the type of the ValidatorResult.
getType()63 SCEVType::TYPE getType() { return Type; }
64
65 /// Is the analyzed SCEV constant during the execution of the SCoP.
isConstant()66 bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
67
68 /// Is the analyzed SCEV valid.
isValid()69 bool isValid() { return Type != SCEVType::INVALID; }
70
71 /// Is the analyzed SCEV of Type IV.
isIV()72 bool isIV() { return Type == SCEVType::IV; }
73
74 /// Is the analyzed SCEV of Type INT.
isINT()75 bool isINT() { return Type == SCEVType::INT; }
76
77 /// Is the analyzed SCEV of Type PARAM.
isPARAM()78 bool isPARAM() { return Type == SCEVType::PARAM; }
79
80 /// Get the parameters of this validator result.
getParameters()81 const ParameterSetTy &getParameters() { return Parameters; }
82
83 /// Add the parameters of Source to this result.
addParamsFrom(const ValidatorResult & Source)84 void addParamsFrom(const ValidatorResult &Source) {
85 Parameters.insert(Source.Parameters.begin(), Source.Parameters.end());
86 }
87
88 /// Merge a result.
89 ///
90 /// This means to merge the parameters and to set the Type to the most
91 /// specific Type that matches both.
merge(const ValidatorResult & ToMerge)92 void merge(const ValidatorResult &ToMerge) {
93 Type = std::max(Type, ToMerge.Type);
94 addParamsFrom(ToMerge);
95 }
96
print(raw_ostream & OS)97 void print(raw_ostream &OS) {
98 switch (Type) {
99 case SCEVType::INT:
100 OS << "SCEVType::INT";
101 break;
102 case SCEVType::PARAM:
103 OS << "SCEVType::PARAM";
104 break;
105 case SCEVType::IV:
106 OS << "SCEVType::IV";
107 break;
108 case SCEVType::INVALID:
109 OS << "SCEVType::INVALID";
110 break;
111 }
112 }
113 };
114
operator <<(raw_ostream & OS,class ValidatorResult & VR)115 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
116 VR.print(OS);
117 return OS;
118 }
119
isConstCall(llvm::CallInst * Call)120 bool polly::isConstCall(llvm::CallInst *Call) {
121 if (Call->mayReadOrWriteMemory())
122 return false;
123
124 for (auto &Operand : Call->arg_operands())
125 if (!isa<ConstantInt>(&Operand))
126 return false;
127
128 return true;
129 }
130
131 /// Check if a SCEV is valid in a SCoP.
132 struct SCEVValidator
133 : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
134 private:
135 const Region *R;
136 Loop *Scope;
137 ScalarEvolution &SE;
138 InvariantLoadsSetTy *ILS;
139
140 public:
SCEVValidatorSCEVValidator141 SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE,
142 InvariantLoadsSetTy *ILS)
143 : R(R), Scope(Scope), SE(SE), ILS(ILS) {}
144
visitConstantSCEVValidator145 class ValidatorResult visitConstant(const SCEVConstant *Constant) {
146 return ValidatorResult(SCEVType::INT);
147 }
148
visitZeroExtendOrTruncateExprSCEVValidator149 class ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr,
150 const SCEV *Operand) {
151 ValidatorResult Op = visit(Operand);
152 auto Type = Op.getType();
153
154 // If unsigned operations are allowed return the operand, otherwise
155 // check if we can model the expression without unsigned assumptions.
156 if (PollyAllowUnsignedOperations || Type == SCEVType::INVALID)
157 return Op;
158
159 if (Type == SCEVType::IV)
160 return ValidatorResult(SCEVType::INVALID);
161 return ValidatorResult(SCEVType::PARAM, Expr);
162 }
163
visitPtrToIntExprSCEVValidator164 class ValidatorResult visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
165 return visit(Expr->getOperand());
166 }
167
visitTruncateExprSCEVValidator168 class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
169 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
170 }
171
visitZeroExtendExprSCEVValidator172 class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
173 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
174 }
175
visitSignExtendExprSCEVValidator176 class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
177 return visit(Expr->getOperand());
178 }
179
visitAddExprSCEVValidator180 class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
181 ValidatorResult Return(SCEVType::INT);
182
183 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
184 ValidatorResult Op = visit(Expr->getOperand(i));
185 Return.merge(Op);
186
187 // Early exit.
188 if (!Return.isValid())
189 break;
190 }
191
192 return Return;
193 }
194
visitMulExprSCEVValidator195 class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
196 ValidatorResult Return(SCEVType::INT);
197
198 bool HasMultipleParams = false;
199
200 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
201 ValidatorResult Op = visit(Expr->getOperand(i));
202
203 if (Op.isINT())
204 continue;
205
206 if (Op.isPARAM() && Return.isPARAM()) {
207 HasMultipleParams = true;
208 continue;
209 }
210
211 if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
212 LLVM_DEBUG(
213 dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
214 << "\tExpr: " << *Expr << "\n"
215 << "\tPrevious expression type: " << Return << "\n"
216 << "\tNext operand (" << Op << "): " << *Expr->getOperand(i)
217 << "\n");
218
219 return ValidatorResult(SCEVType::INVALID);
220 }
221
222 Return.merge(Op);
223 }
224
225 if (HasMultipleParams && Return.isValid())
226 return ValidatorResult(SCEVType::PARAM, Expr);
227
228 return Return;
229 }
230
visitAddRecExprSCEVValidator231 class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
232 if (!Expr->isAffine()) {
233 LLVM_DEBUG(dbgs() << "INVALID: AddRec is not affine");
234 return ValidatorResult(SCEVType::INVALID);
235 }
236
237 ValidatorResult Start = visit(Expr->getStart());
238 ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
239
240 if (!Start.isValid())
241 return Start;
242
243 if (!Recurrence.isValid())
244 return Recurrence;
245
246 auto *L = Expr->getLoop();
247 if (R->contains(L) && (!Scope || !L->contains(Scope))) {
248 LLVM_DEBUG(
249 dbgs() << "INVALID: Loop of AddRec expression boxed in an a "
250 "non-affine subregion or has a non-synthesizable exit "
251 "value.");
252 return ValidatorResult(SCEVType::INVALID);
253 }
254
255 if (R->contains(L)) {
256 if (Recurrence.isINT()) {
257 ValidatorResult Result(SCEVType::IV);
258 Result.addParamsFrom(Start);
259 return Result;
260 }
261
262 LLVM_DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
263 "recurrence part");
264 return ValidatorResult(SCEVType::INVALID);
265 }
266
267 assert(Recurrence.isConstant() && "Expected 'Recurrence' to be constant");
268
269 // Directly generate ValidatorResult for Expr if 'start' is zero.
270 if (Expr->getStart()->isZero())
271 return ValidatorResult(SCEVType::PARAM, Expr);
272
273 // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
274 // if 'start' is not zero.
275 const SCEV *ZeroStartExpr = SE.getAddRecExpr(
276 SE.getConstant(Expr->getStart()->getType(), 0),
277 Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
278
279 ValidatorResult ZeroStartResult =
280 ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
281 ZeroStartResult.addParamsFrom(Start);
282
283 return ZeroStartResult;
284 }
285
visitSMaxExprSCEVValidator286 class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
287 ValidatorResult Return(SCEVType::INT);
288
289 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
290 ValidatorResult Op = visit(Expr->getOperand(i));
291
292 if (!Op.isValid())
293 return Op;
294
295 Return.merge(Op);
296 }
297
298 return Return;
299 }
300
visitSMinExprSCEVValidator301 class ValidatorResult visitSMinExpr(const SCEVSMinExpr *Expr) {
302 ValidatorResult Return(SCEVType::INT);
303
304 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
305 ValidatorResult Op = visit(Expr->getOperand(i));
306
307 if (!Op.isValid())
308 return Op;
309
310 Return.merge(Op);
311 }
312
313 return Return;
314 }
315
visitUMaxExprSCEVValidator316 class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
317 // We do not support unsigned max operations. If 'Expr' is constant during
318 // Scop execution we treat this as a parameter, otherwise we bail out.
319 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
320 ValidatorResult Op = visit(Expr->getOperand(i));
321
322 if (!Op.isConstant()) {
323 LLVM_DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
324 return ValidatorResult(SCEVType::INVALID);
325 }
326 }
327
328 return ValidatorResult(SCEVType::PARAM, Expr);
329 }
330
visitUMinExprSCEVValidator331 class ValidatorResult visitUMinExpr(const SCEVUMinExpr *Expr) {
332 // We do not support unsigned min operations. If 'Expr' is constant during
333 // Scop execution we treat this as a parameter, otherwise we bail out.
334 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
335 ValidatorResult Op = visit(Expr->getOperand(i));
336
337 if (!Op.isConstant()) {
338 LLVM_DEBUG(dbgs() << "INVALID: UMinExpr has a non-constant operand");
339 return ValidatorResult(SCEVType::INVALID);
340 }
341 }
342
343 return ValidatorResult(SCEVType::PARAM, Expr);
344 }
345
visitGenericInstSCEVValidator346 ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
347 if (R->contains(I)) {
348 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
349 "within the region\n");
350 return ValidatorResult(SCEVType::INVALID);
351 }
352
353 return ValidatorResult(SCEVType::PARAM, S);
354 }
355
visitCallInstructionSCEVValidator356 ValidatorResult visitCallInstruction(Instruction *I, const SCEV *S) {
357 assert(I->getOpcode() == Instruction::Call && "Call instruction expected");
358
359 if (R->contains(I)) {
360 auto Call = cast<CallInst>(I);
361
362 if (!isConstCall(Call))
363 return ValidatorResult(SCEVType::INVALID, S);
364 }
365 return ValidatorResult(SCEVType::PARAM, S);
366 }
367
visitLoadInstructionSCEVValidator368 ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) {
369 if (R->contains(I) && ILS) {
370 ILS->insert(cast<LoadInst>(I));
371 return ValidatorResult(SCEVType::PARAM, S);
372 }
373
374 return visitGenericInst(I, S);
375 }
376
visitDivisionSCEVValidator377 ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor,
378 const SCEV *DivExpr,
379 Instruction *SDiv = nullptr) {
380
381 // First check if we might be able to model the division, thus if the
382 // divisor is constant. If so, check the dividend, otherwise check if
383 // the whole division can be seen as a parameter.
384 if (isa<SCEVConstant>(Divisor) && !Divisor->isZero())
385 return visit(Dividend);
386
387 // For signed divisions use the SDiv instruction to check for a parameter
388 // division, for unsigned divisions check the operands.
389 if (SDiv)
390 return visitGenericInst(SDiv, DivExpr);
391
392 ValidatorResult LHS = visit(Dividend);
393 ValidatorResult RHS = visit(Divisor);
394 if (LHS.isConstant() && RHS.isConstant())
395 return ValidatorResult(SCEVType::PARAM, DivExpr);
396
397 LLVM_DEBUG(
398 dbgs() << "INVALID: unsigned division of non-constant expressions");
399 return ValidatorResult(SCEVType::INVALID);
400 }
401
visitUDivExprSCEVValidator402 ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
403 if (!PollyAllowUnsignedOperations)
404 return ValidatorResult(SCEVType::INVALID);
405
406 auto *Dividend = Expr->getLHS();
407 auto *Divisor = Expr->getRHS();
408 return visitDivision(Dividend, Divisor, Expr);
409 }
410
visitSDivInstructionSCEVValidator411 ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) {
412 assert(SDiv->getOpcode() == Instruction::SDiv &&
413 "Assumed SDiv instruction!");
414
415 auto *Dividend = SE.getSCEV(SDiv->getOperand(0));
416 auto *Divisor = SE.getSCEV(SDiv->getOperand(1));
417 return visitDivision(Dividend, Divisor, Expr, SDiv);
418 }
419
visitSRemInstructionSCEVValidator420 ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) {
421 assert(SRem->getOpcode() == Instruction::SRem &&
422 "Assumed SRem instruction!");
423
424 auto *Divisor = SRem->getOperand(1);
425 auto *CI = dyn_cast<ConstantInt>(Divisor);
426 if (!CI || CI->isZeroValue())
427 return visitGenericInst(SRem, S);
428
429 auto *Dividend = SRem->getOperand(0);
430 auto *DividendSCEV = SE.getSCEV(Dividend);
431 return visit(DividendSCEV);
432 }
433
visitUnknownSCEVValidator434 ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
435 Value *V = Expr->getValue();
436
437 if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) {
438 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer");
439 return ValidatorResult(SCEVType::INVALID);
440 }
441
442 if (isa<UndefValue>(V)) {
443 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
444 return ValidatorResult(SCEVType::INVALID);
445 }
446
447 if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
448 switch (I->getOpcode()) {
449 case Instruction::IntToPtr:
450 return visit(SE.getSCEVAtScope(I->getOperand(0), Scope));
451 case Instruction::Load:
452 return visitLoadInstruction(I, Expr);
453 case Instruction::SDiv:
454 return visitSDivInstruction(I, Expr);
455 case Instruction::SRem:
456 return visitSRemInstruction(I, Expr);
457 case Instruction::Call:
458 return visitCallInstruction(I, Expr);
459 default:
460 return visitGenericInst(I, Expr);
461 }
462 }
463
464 if (Expr->getType()->isPointerTy()) {
465 if (isa<ConstantPointerNull>(V))
466 return ValidatorResult(SCEVType::INT); // "int"
467 }
468
469 return ValidatorResult(SCEVType::PARAM, Expr);
470 }
471 };
472
473 class SCEVHasIVParams {
474 bool HasIVParams = false;
475
476 public:
SCEVHasIVParams()477 SCEVHasIVParams() {}
478
follow(const SCEV * S)479 bool follow(const SCEV *S) {
480 const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
481 if (!Unknown)
482 return true;
483
484 CallInst *Call = dyn_cast<CallInst>(Unknown->getValue());
485
486 if (!Call)
487 return true;
488
489 if (isConstCall(Call)) {
490 HasIVParams = true;
491 return false;
492 }
493
494 return true;
495 }
496
isDone()497 bool isDone() { return HasIVParams; }
hasIVParams()498 bool hasIVParams() { return HasIVParams; }
499 };
500
501 /// Check whether a SCEV refers to an SSA name defined inside a region.
502 class SCEVInRegionDependences {
503 const Region *R;
504 Loop *Scope;
505 const InvariantLoadsSetTy &ILS;
506 bool AllowLoops;
507 bool HasInRegionDeps = false;
508
509 public:
SCEVInRegionDependences(const Region * R,Loop * Scope,bool AllowLoops,const InvariantLoadsSetTy & ILS)510 SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops,
511 const InvariantLoadsSetTy &ILS)
512 : R(R), Scope(Scope), ILS(ILS), AllowLoops(AllowLoops) {}
513
follow(const SCEV * S)514 bool follow(const SCEV *S) {
515 if (auto Unknown = dyn_cast<SCEVUnknown>(S)) {
516 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
517
518 CallInst *Call = dyn_cast<CallInst>(Unknown->getValue());
519
520 if (Call && isConstCall(Call))
521 return false;
522
523 if (Inst) {
524 // When we invariant load hoist a load, we first make sure that there
525 // can be no dependences created by it in the Scop region. So, we should
526 // not consider scalar dependences to `LoadInst`s that are invariant
527 // load hoisted.
528 //
529 // If this check is not present, then we create data dependences which
530 // are strictly not necessary by tracking the invariant load as a
531 // scalar.
532 LoadInst *LI = dyn_cast<LoadInst>(Inst);
533 if (LI && ILS.count(LI) > 0)
534 return false;
535 }
536
537 // Return true when Inst is defined inside the region R.
538 if (!Inst || !R->contains(Inst))
539 return true;
540
541 HasInRegionDeps = true;
542 return false;
543 }
544
545 if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
546 if (AllowLoops)
547 return true;
548
549 auto *L = AddRec->getLoop();
550 if (R->contains(L) && !L->contains(Scope)) {
551 HasInRegionDeps = true;
552 return false;
553 }
554 }
555
556 return true;
557 }
isDone()558 bool isDone() { return false; }
hasDependences()559 bool hasDependences() { return HasInRegionDeps; }
560 };
561
562 namespace polly {
563 /// Find all loops referenced in SCEVAddRecExprs.
564 class SCEVFindLoops {
565 SetVector<const Loop *> &Loops;
566
567 public:
SCEVFindLoops(SetVector<const Loop * > & Loops)568 SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
569
follow(const SCEV * S)570 bool follow(const SCEV *S) {
571 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
572 Loops.insert(AddRec->getLoop());
573 return true;
574 }
isDone()575 bool isDone() { return false; }
576 };
577
findLoops(const SCEV * Expr,SetVector<const Loop * > & Loops)578 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
579 SCEVFindLoops FindLoops(Loops);
580 SCEVTraversal<SCEVFindLoops> ST(FindLoops);
581 ST.visitAll(Expr);
582 }
583
584 /// Find all values referenced in SCEVUnknowns.
585 class SCEVFindValues {
586 ScalarEvolution &SE;
587 SetVector<Value *> &Values;
588
589 public:
SCEVFindValues(ScalarEvolution & SE,SetVector<Value * > & Values)590 SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values)
591 : SE(SE), Values(Values) {}
592
follow(const SCEV * S)593 bool follow(const SCEV *S) {
594 const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
595 if (!Unknown)
596 return true;
597
598 Values.insert(Unknown->getValue());
599 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
600 if (!Inst || (Inst->getOpcode() != Instruction::SRem &&
601 Inst->getOpcode() != Instruction::SDiv))
602 return false;
603
604 auto *Dividend = SE.getSCEV(Inst->getOperand(1));
605 if (!isa<SCEVConstant>(Dividend))
606 return false;
607
608 auto *Divisor = SE.getSCEV(Inst->getOperand(0));
609 SCEVFindValues FindValues(SE, Values);
610 SCEVTraversal<SCEVFindValues> ST(FindValues);
611 ST.visitAll(Dividend);
612 ST.visitAll(Divisor);
613
614 return false;
615 }
isDone()616 bool isDone() { return false; }
617 };
618
findValues(const SCEV * Expr,ScalarEvolution & SE,SetVector<Value * > & Values)619 void findValues(const SCEV *Expr, ScalarEvolution &SE,
620 SetVector<Value *> &Values) {
621 SCEVFindValues FindValues(SE, Values);
622 SCEVTraversal<SCEVFindValues> ST(FindValues);
623 ST.visitAll(Expr);
624 }
625
hasIVParams(const SCEV * Expr)626 bool hasIVParams(const SCEV *Expr) {
627 SCEVHasIVParams HasIVParams;
628 SCEVTraversal<SCEVHasIVParams> ST(HasIVParams);
629 ST.visitAll(Expr);
630 return HasIVParams.hasIVParams();
631 }
632
hasScalarDepsInsideRegion(const SCEV * Expr,const Region * R,llvm::Loop * Scope,bool AllowLoops,const InvariantLoadsSetTy & ILS)633 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R,
634 llvm::Loop *Scope, bool AllowLoops,
635 const InvariantLoadsSetTy &ILS) {
636 SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops, ILS);
637 SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps);
638 ST.visitAll(Expr);
639 return InRegionDeps.hasDependences();
640 }
641
isAffineExpr(const Region * R,llvm::Loop * Scope,const SCEV * Expr,ScalarEvolution & SE,InvariantLoadsSetTy * ILS)642 bool isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr,
643 ScalarEvolution &SE, InvariantLoadsSetTy *ILS) {
644 if (isa<SCEVCouldNotCompute>(Expr))
645 return false;
646
647 SCEVValidator Validator(R, Scope, SE, ILS);
648 LLVM_DEBUG({
649 dbgs() << "\n";
650 dbgs() << "Expr: " << *Expr << "\n";
651 dbgs() << "Region: " << R->getNameStr() << "\n";
652 dbgs() << " -> ";
653 });
654
655 ValidatorResult Result = Validator.visit(Expr);
656
657 LLVM_DEBUG({
658 if (Result.isValid())
659 dbgs() << "VALID\n";
660 dbgs() << "\n";
661 });
662
663 return Result.isValid();
664 }
665
isAffineExpr(Value * V,const Region * R,Loop * Scope,ScalarEvolution & SE,ParameterSetTy & Params)666 static bool isAffineExpr(Value *V, const Region *R, Loop *Scope,
667 ScalarEvolution &SE, ParameterSetTy &Params) {
668 auto *E = SE.getSCEV(V);
669 if (isa<SCEVCouldNotCompute>(E))
670 return false;
671
672 SCEVValidator Validator(R, Scope, SE, nullptr);
673 ValidatorResult Result = Validator.visit(E);
674 if (!Result.isValid())
675 return false;
676
677 auto ResultParams = Result.getParameters();
678 Params.insert(ResultParams.begin(), ResultParams.end());
679
680 return true;
681 }
682
isAffineConstraint(Value * V,const Region * R,llvm::Loop * Scope,ScalarEvolution & SE,ParameterSetTy & Params,bool OrExpr)683 bool isAffineConstraint(Value *V, const Region *R, llvm::Loop *Scope,
684 ScalarEvolution &SE, ParameterSetTy &Params,
685 bool OrExpr) {
686 if (auto *ICmp = dyn_cast<ICmpInst>(V)) {
687 return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params,
688 true) &&
689 isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true);
690 } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) {
691 auto Opcode = BinOp->getOpcode();
692 if (Opcode == Instruction::And || Opcode == Instruction::Or)
693 return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params,
694 false) &&
695 isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params,
696 false);
697 /* Fall through */
698 }
699
700 if (!OrExpr)
701 return false;
702
703 return isAffineExpr(V, R, Scope, SE, Params);
704 }
705
getParamsInAffineExpr(const Region * R,Loop * Scope,const SCEV * Expr,ScalarEvolution & SE)706 ParameterSetTy getParamsInAffineExpr(const Region *R, Loop *Scope,
707 const SCEV *Expr, ScalarEvolution &SE) {
708 if (isa<SCEVCouldNotCompute>(Expr))
709 return ParameterSetTy();
710
711 InvariantLoadsSetTy ILS;
712 SCEVValidator Validator(R, Scope, SE, &ILS);
713 ValidatorResult Result = Validator.visit(Expr);
714 assert(Result.isValid() && "Requested parameters for an invalid SCEV!");
715
716 return Result.getParameters();
717 }
718
719 std::pair<const SCEVConstant *, const SCEV *>
extractConstantFactor(const SCEV * S,ScalarEvolution & SE)720 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
721 auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1));
722
723 if (auto *Constant = dyn_cast<SCEVConstant>(S))
724 return std::make_pair(Constant, SE.getConstant(S->getType(), 1));
725
726 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
727 if (AddRec) {
728 auto *StartExpr = AddRec->getStart();
729 if (StartExpr->isZero()) {
730 auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE);
731 auto *LeftOverAddRec =
732 SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(),
733 AddRec->getNoWrapFlags());
734 return std::make_pair(StepPair.first, LeftOverAddRec);
735 }
736 return std::make_pair(ConstPart, S);
737 }
738
739 if (auto *Add = dyn_cast<SCEVAddExpr>(S)) {
740 SmallVector<const SCEV *, 4> LeftOvers;
741 auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE);
742 auto *Factor = Op0Pair.first;
743 if (SE.isKnownNegative(Factor)) {
744 Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor));
745 LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second));
746 } else {
747 LeftOvers.push_back(Op0Pair.second);
748 }
749
750 for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) {
751 auto OpUPair = extractConstantFactor(Add->getOperand(u), SE);
752 // TODO: Use something smarter than equality here, e.g., gcd.
753 if (Factor == OpUPair.first)
754 LeftOvers.push_back(OpUPair.second);
755 else if (Factor == SE.getNegativeSCEV(OpUPair.first))
756 LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second));
757 else
758 return std::make_pair(ConstPart, S);
759 }
760
761 auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags());
762 return std::make_pair(Factor, NewAdd);
763 }
764
765 auto *Mul = dyn_cast<SCEVMulExpr>(S);
766 if (!Mul)
767 return std::make_pair(ConstPart, S);
768
769 SmallVector<const SCEV *, 4> LeftOvers;
770 for (auto *Op : Mul->operands())
771 if (isa<SCEVConstant>(Op))
772 ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op));
773 else
774 LeftOvers.push_back(Op);
775
776 return std::make_pair(ConstPart, SE.getMulExpr(LeftOvers));
777 }
778
/// Try to look through a PHI node behind \p Expr.
///
/// If \p Expr is a SCEVUnknown wrapping a PHI that has exactly one incoming
/// value not coming from an error block inside \p R, the SCEV of that value
/// is returned. In every other case \p Expr is returned unchanged.
const SCEV *tryForwardThroughPHI(const SCEV *Expr, Region &R,
                                 ScalarEvolution &SE, LoopInfo &LI,
                                 const DominatorTree &DT) {
  auto *Unknown = dyn_cast<SCEVUnknown>(Expr);
  if (!Unknown)
    return Expr;

  auto *PHI = dyn_cast<PHINode>(Unknown->getValue());
  if (!PHI)
    return Expr;

  Value *Final = nullptr;
  for (unsigned Idx = 0; Idx < PHI->getNumIncomingValues(); Idx++) {
    BasicBlock *Incoming = PHI->getIncomingBlock(Idx);

    // Incoming edges from error blocks inside the region are ignored.
    const bool FromErrorBlockInRegion =
        isErrorBlock(*Incoming, R, LI, DT) && R.contains(Incoming);
    if (FromErrorBlockInRegion)
      continue;

    // A second candidate means the value is not unique.
    if (Final)
      return Expr;

    Final = PHI->getIncomingValue(Idx);
  }

  return Final ? SE.getSCEV(Final) : Expr;
}
804
getUniqueNonErrorValue(PHINode * PHI,Region * R,LoopInfo & LI,const DominatorTree & DT)805 Value *getUniqueNonErrorValue(PHINode *PHI, Region *R, LoopInfo &LI,
806 const DominatorTree &DT) {
807 Value *V = nullptr;
808 for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) {
809 BasicBlock *BB = PHI->getIncomingBlock(i);
810 if (!isErrorBlock(*BB, *R, LI, DT)) {
811 if (V)
812 return nullptr;
813 V = PHI->getIncomingValue(i);
814 }
815 }
816
817 return V;
818 }
819 } // namespace polly
820