1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/AffineExpr.h"
10 #include "AffineExprDetail.h"
11 #include "mlir/IR/AffineExprVisitor.h"
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/IntegerSet.h"
14 #include "mlir/Support/MathExtras.h"
15 #include "mlir/Support/TypeID.h"
16 #include "llvm/ADT/STLExtras.h"
17 
18 using namespace mlir;
19 using namespace mlir::detail;
20 
getContext() const21 MLIRContext *AffineExpr::getContext() const { return expr->context; }
22 
getKind() const23 AffineExprKind AffineExpr::getKind() const { return expr->kind; }
24 
25 /// Walk all of the AffineExprs in this subgraph in postorder.
walk(std::function<void (AffineExpr)> callback) const26 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
27   struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
28     std::function<void(AffineExpr)> callback;
29 
30     AffineExprWalker(std::function<void(AffineExpr)> callback)
31         : callback(callback) {}
32 
33     void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
34     void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
35     void visitDimExpr(AffineDimExpr expr) { callback(expr); }
36     void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
37   };
38 
39   AffineExprWalker(callback).walkPostOrder(*this);
40 }
41 
42 // Dispatch affine expression construction based on kind.
getAffineBinaryOpExpr(AffineExprKind kind,AffineExpr lhs,AffineExpr rhs)43 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
44                                        AffineExpr rhs) {
45   if (kind == AffineExprKind::Add)
46     return lhs + rhs;
47   if (kind == AffineExprKind::Mul)
48     return lhs * rhs;
49   if (kind == AffineExprKind::FloorDiv)
50     return lhs.floorDiv(rhs);
51   if (kind == AffineExprKind::CeilDiv)
52     return lhs.ceilDiv(rhs);
53   if (kind == AffineExprKind::Mod)
54     return lhs % rhs;
55 
56   llvm_unreachable("unknown binary operation on affine expressions");
57 }
58 
59 /// This method substitutes any uses of dimensions and symbols (e.g.
60 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
61 AffineExpr
replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,ArrayRef<AffineExpr> symReplacements) const62 AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
63                                   ArrayRef<AffineExpr> symReplacements) const {
64   switch (getKind()) {
65   case AffineExprKind::Constant:
66     return *this;
67   case AffineExprKind::DimId: {
68     unsigned dimId = cast<AffineDimExpr>().getPosition();
69     if (dimId >= dimReplacements.size())
70       return *this;
71     return dimReplacements[dimId];
72   }
73   case AffineExprKind::SymbolId: {
74     unsigned symId = cast<AffineSymbolExpr>().getPosition();
75     if (symId >= symReplacements.size())
76       return *this;
77     return symReplacements[symId];
78   }
79   case AffineExprKind::Add:
80   case AffineExprKind::Mul:
81   case AffineExprKind::FloorDiv:
82   case AffineExprKind::CeilDiv:
83   case AffineExprKind::Mod:
84     auto binOp = cast<AffineBinaryOpExpr>();
85     auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
86     auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
87     auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
88     if (newLHS == lhs && newRHS == rhs)
89       return *this;
90     return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
91   }
92   llvm_unreachable("Unknown AffineExpr");
93 }
94 
replaceDims(ArrayRef<AffineExpr> dimReplacements) const95 AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const {
96   return replaceDimsAndSymbols(dimReplacements, {});
97 }
98 
99 AffineExpr
replaceSymbols(ArrayRef<AffineExpr> symReplacements) const100 AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const {
101   return replaceDimsAndSymbols({}, symReplacements);
102 }
103 
104 /// Replace dims[offset ... numDims)
105 /// by dims[offset + shift ... shift + numDims).
shiftDims(unsigned numDims,unsigned shift,unsigned offset) const106 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
107                                  unsigned offset) const {
108   SmallVector<AffineExpr, 4> dims;
109   for (unsigned idx = 0; idx < offset; ++idx)
110     dims.push_back(getAffineDimExpr(idx, getContext()));
111   for (unsigned idx = offset; idx < numDims; ++idx)
112     dims.push_back(getAffineDimExpr(idx + shift, getContext()));
113   return replaceDimsAndSymbols(dims, {});
114 }
115 
116 /// Replace symbols[offset ... numSymbols)
117 /// by symbols[offset + shift ... shift + numSymbols).
shiftSymbols(unsigned numSymbols,unsigned shift,unsigned offset) const118 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
119                                     unsigned offset) const {
120   SmallVector<AffineExpr, 4> symbols;
121   for (unsigned idx = 0; idx < offset; ++idx)
122     symbols.push_back(getAffineSymbolExpr(idx, getContext()));
123   for (unsigned idx = offset; idx < numSymbols; ++idx)
124     symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
125   return replaceDimsAndSymbols({}, symbols);
126 }
127 
128 /// Sparse replace method. Return the modified expression tree.
129 AffineExpr
replace(const DenseMap<AffineExpr,AffineExpr> & map) const130 AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
131   auto it = map.find(*this);
132   if (it != map.end())
133     return it->second;
134   switch (getKind()) {
135   default:
136     return *this;
137   case AffineExprKind::Add:
138   case AffineExprKind::Mul:
139   case AffineExprKind::FloorDiv:
140   case AffineExprKind::CeilDiv:
141   case AffineExprKind::Mod:
142     auto binOp = cast<AffineBinaryOpExpr>();
143     auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
144     auto newLHS = lhs.replace(map);
145     auto newRHS = rhs.replace(map);
146     if (newLHS == lhs && newRHS == rhs)
147       return *this;
148     return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
149   }
150   llvm_unreachable("Unknown AffineExpr");
151 }
152 
153 /// Sparse replace method. Return the modified expression tree.
replace(AffineExpr expr,AffineExpr replacement) const154 AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const {
155   DenseMap<AffineExpr, AffineExpr> map;
156   map.insert(std::make_pair(expr, replacement));
157   return replace(map);
158 }
159 /// Returns true if this expression is made out of only symbols and
160 /// constants (no dimensional identifiers).
isSymbolicOrConstant() const161 bool AffineExpr::isSymbolicOrConstant() const {
162   switch (getKind()) {
163   case AffineExprKind::Constant:
164     return true;
165   case AffineExprKind::DimId:
166     return false;
167   case AffineExprKind::SymbolId:
168     return true;
169 
170   case AffineExprKind::Add:
171   case AffineExprKind::Mul:
172   case AffineExprKind::FloorDiv:
173   case AffineExprKind::CeilDiv:
174   case AffineExprKind::Mod: {
175     auto expr = this->cast<AffineBinaryOpExpr>();
176     return expr.getLHS().isSymbolicOrConstant() &&
177            expr.getRHS().isSymbolicOrConstant();
178   }
179   }
180   llvm_unreachable("Unknown AffineExpr");
181 }
182 
183 /// Returns true if this is a pure affine expression, i.e., multiplication,
184 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
isPureAffine() const185 bool AffineExpr::isPureAffine() const {
186   switch (getKind()) {
187   case AffineExprKind::SymbolId:
188   case AffineExprKind::DimId:
189   case AffineExprKind::Constant:
190     return true;
191   case AffineExprKind::Add: {
192     auto op = cast<AffineBinaryOpExpr>();
193     return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
194   }
195 
196   case AffineExprKind::Mul: {
197     // TODO: Canonicalize the constants in binary operators to the RHS when
198     // possible, allowing this to merge into the next case.
199     auto op = cast<AffineBinaryOpExpr>();
200     return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
201            (op.getLHS().template isa<AffineConstantExpr>() ||
202             op.getRHS().template isa<AffineConstantExpr>());
203   }
204   case AffineExprKind::FloorDiv:
205   case AffineExprKind::CeilDiv:
206   case AffineExprKind::Mod: {
207     auto op = cast<AffineBinaryOpExpr>();
208     return op.getLHS().isPureAffine() &&
209            op.getRHS().template isa<AffineConstantExpr>();
210   }
211   }
212   llvm_unreachable("Unknown AffineExpr");
213 }
214 
215 // Returns the greatest known integral divisor of this affine expression.
getLargestKnownDivisor() const216 int64_t AffineExpr::getLargestKnownDivisor() const {
217   AffineBinaryOpExpr binExpr(nullptr);
218   switch (getKind()) {
219   case AffineExprKind::SymbolId:
220     LLVM_FALLTHROUGH;
221   case AffineExprKind::DimId:
222     return 1;
223   case AffineExprKind::Constant:
224     return std::abs(this->cast<AffineConstantExpr>().getValue());
225   case AffineExprKind::Mul: {
226     binExpr = this->cast<AffineBinaryOpExpr>();
227     return binExpr.getLHS().getLargestKnownDivisor() *
228            binExpr.getRHS().getLargestKnownDivisor();
229   }
230   case AffineExprKind::Add:
231     LLVM_FALLTHROUGH;
232   case AffineExprKind::FloorDiv:
233   case AffineExprKind::CeilDiv:
234   case AffineExprKind::Mod: {
235     binExpr = cast<AffineBinaryOpExpr>();
236     return llvm::GreatestCommonDivisor64(
237         binExpr.getLHS().getLargestKnownDivisor(),
238         binExpr.getRHS().getLargestKnownDivisor());
239   }
240   }
241   llvm_unreachable("Unknown AffineExpr");
242 }
243 
isMultipleOf(int64_t factor) const244 bool AffineExpr::isMultipleOf(int64_t factor) const {
245   AffineBinaryOpExpr binExpr(nullptr);
246   uint64_t l, u;
247   switch (getKind()) {
248   case AffineExprKind::SymbolId:
249     LLVM_FALLTHROUGH;
250   case AffineExprKind::DimId:
251     return factor * factor == 1;
252   case AffineExprKind::Constant:
253     return cast<AffineConstantExpr>().getValue() % factor == 0;
254   case AffineExprKind::Mul: {
255     binExpr = cast<AffineBinaryOpExpr>();
256     // It's probably not worth optimizing this further (to not traverse the
257     // whole sub-tree under - it that would require a version of isMultipleOf
258     // that on a 'false' return also returns the largest known divisor).
259     return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
260            (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
261            (l * u) % factor == 0;
262   }
263   case AffineExprKind::Add:
264   case AffineExprKind::FloorDiv:
265   case AffineExprKind::CeilDiv:
266   case AffineExprKind::Mod: {
267     binExpr = cast<AffineBinaryOpExpr>();
268     return llvm::GreatestCommonDivisor64(
269                binExpr.getLHS().getLargestKnownDivisor(),
270                binExpr.getRHS().getLargestKnownDivisor()) %
271                factor ==
272            0;
273   }
274   }
275   llvm_unreachable("Unknown AffineExpr");
276 }
277 
isFunctionOfDim(unsigned position) const278 bool AffineExpr::isFunctionOfDim(unsigned position) const {
279   if (getKind() == AffineExprKind::DimId) {
280     return *this == mlir::getAffineDimExpr(position, getContext());
281   }
282   if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
283     return expr.getLHS().isFunctionOfDim(position) ||
284            expr.getRHS().isFunctionOfDim(position);
285   }
286   return false;
287 }
288 
isFunctionOfSymbol(unsigned position) const289 bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
290   if (getKind() == AffineExprKind::SymbolId) {
291     return *this == mlir::getAffineSymbolExpr(position, getContext());
292   }
293   if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
294     return expr.getLHS().isFunctionOfSymbol(position) ||
295            expr.getRHS().isFunctionOfSymbol(position);
296   }
297   return false;
298 }
299 
AffineBinaryOpExpr(AffineExpr::ImplType * ptr)300 AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
301     : AffineExpr(ptr) {}
getLHS() const302 AffineExpr AffineBinaryOpExpr::getLHS() const {
303   return static_cast<ImplType *>(expr)->lhs;
304 }
getRHS() const305 AffineExpr AffineBinaryOpExpr::getRHS() const {
306   return static_cast<ImplType *>(expr)->rhs;
307 }
308 
AffineDimExpr(AffineExpr::ImplType * ptr)309 AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {}
getPosition() const310 unsigned AffineDimExpr::getPosition() const {
311   return static_cast<ImplType *>(expr)->position;
312 }
313 
314 /// Returns true if the expression is divisible by the given symbol with
315 /// position `symbolPos`. The argument `opKind` specifies here what kind of
316 /// division or mod operation called this division. It helps in implementing the
317 /// commutative property of the floordiv and ceildiv operations. If the argument
318 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
319 /// operation, then the commutative property can be used otherwise, the floordiv
320 /// operation is not divisible. The same argument holds for ceildiv operation.
isDivisibleBySymbol(AffineExpr expr,unsigned symbolPos,AffineExprKind opKind)321 static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
322                                 AffineExprKind opKind) {
323   // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
324   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
325           opKind == AffineExprKind::CeilDiv) &&
326          "unexpected opKind");
327   switch (expr.getKind()) {
328   case AffineExprKind::Constant:
329     if (expr.cast<AffineConstantExpr>().getValue())
330       return false;
331     return true;
332   case AffineExprKind::DimId:
333     return false;
334   case AffineExprKind::SymbolId:
335     return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos);
336   // Checks divisibility by the given symbol for both operands.
337   case AffineExprKind::Add: {
338     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
339     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
340            isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
341   }
342   // Checks divisibility by the given symbol for both operands. Consider the
343   // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
344   // this is a division by s1 and both the operands of modulo are divisible by
345   // s1 but it is not divisible by s1 always. The third argument is
346   // `AffineExprKind::Mod` for this reason.
347   case AffineExprKind::Mod: {
348     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
349     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
350                                AffineExprKind::Mod) &&
351            isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
352                                AffineExprKind::Mod);
353   }
354   // Checks if any of the operand divisible by the given symbol.
355   case AffineExprKind::Mul: {
356     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
357     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
358            isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
359   }
360   // Floordiv and ceildiv are divisible by the given symbol when the first
361   // operand is divisible, and the affine expression kind of the argument expr
362   // is same as the argument `opKind`. This can be inferred from commutative
363   // property of floordiv and ceildiv operations and are as follow:
364   // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
365   // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
366   // It will fail if operations are not same. For example:
367   // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
368   case AffineExprKind::FloorDiv:
369   case AffineExprKind::CeilDiv: {
370     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
371     if (opKind != expr.getKind())
372       return false;
373     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
374   }
375   }
376   llvm_unreachable("Unknown AffineExpr");
377 }
378 
379 /// Divides the given expression by the given symbol at position `symbolPos`. It
380 /// considers the divisibility condition is checked before calling itself. A
381 /// null expression is returned whenever the divisibility condition fails.
symbolicDivide(AffineExpr expr,unsigned symbolPos,AffineExprKind opKind)382 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
383                                  AffineExprKind opKind) {
384   // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
385   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
386           opKind == AffineExprKind::CeilDiv) &&
387          "unexpected opKind");
388   switch (expr.getKind()) {
389   case AffineExprKind::Constant:
390     if (expr.cast<AffineConstantExpr>().getValue() != 0)
391       return nullptr;
392     return getAffineConstantExpr(0, expr.getContext());
393   case AffineExprKind::DimId:
394     return nullptr;
395   case AffineExprKind::SymbolId:
396     return getAffineConstantExpr(1, expr.getContext());
397   // Dividing both operands by the given symbol.
398   case AffineExprKind::Add: {
399     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
400     return getAffineBinaryOpExpr(
401         expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
402         symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
403   }
404   // Dividing both operands by the given symbol.
405   case AffineExprKind::Mod: {
406     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
407     return getAffineBinaryOpExpr(
408         expr.getKind(),
409         symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
410         symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
411   }
412   // Dividing any of the operand by the given symbol.
413   case AffineExprKind::Mul: {
414     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
415     if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
416       return binaryExpr.getLHS() *
417              symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
418     return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
419            binaryExpr.getRHS();
420   }
421   // Dividing first operand only by the given symbol.
422   case AffineExprKind::FloorDiv:
423   case AffineExprKind::CeilDiv: {
424     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
425     return getAffineBinaryOpExpr(
426         expr.getKind(),
427         symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
428         binaryExpr.getRHS());
429   }
430   }
431   llvm_unreachable("Unknown AffineExpr");
432 }
433 
434 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
435 /// operations when the second operand simplifies to a symbol and the first
436 /// operand is divisible by that symbol. It can be applied to any semi-affine
437 /// expression. Returned expression can either be a semi-affine or pure affine
438 /// expression.
simplifySemiAffine(AffineExpr expr)439 static AffineExpr simplifySemiAffine(AffineExpr expr) {
440   switch (expr.getKind()) {
441   case AffineExprKind::Constant:
442   case AffineExprKind::DimId:
443   case AffineExprKind::SymbolId:
444     return expr;
445   case AffineExprKind::Add:
446   case AffineExprKind::Mul: {
447     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
448     return getAffineBinaryOpExpr(expr.getKind(),
449                                  simplifySemiAffine(binaryExpr.getLHS()),
450                                  simplifySemiAffine(binaryExpr.getRHS()));
451   }
452   // Check if the simplification of the second operand is a symbol, and the
453   // first operand is divisible by it. If the operation is a modulo, a constant
454   // zero expression is returned. In the case of floordiv and ceildiv, the
455   // symbol from the simplification of the second operand divides the first
456   // operand. Otherwise, simplification is not possible.
457   case AffineExprKind::FloorDiv:
458   case AffineExprKind::CeilDiv:
459   case AffineExprKind::Mod: {
460     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
461     AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
462     AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
463     AffineSymbolExpr symbolExpr =
464         simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>();
465     if (!symbolExpr)
466       return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
467     unsigned symbolPos = symbolExpr.getPosition();
468     if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
469       return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
470     if (expr.getKind() == AffineExprKind::Mod)
471       return getAffineConstantExpr(0, expr.getContext());
472     return symbolicDivide(sLHS, symbolPos, expr.getKind());
473   }
474   }
475   llvm_unreachable("Unknown AffineExpr");
476 }
477 
getAffineDimOrSymbol(AffineExprKind kind,unsigned position,MLIRContext * context)478 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
479                                        MLIRContext *context) {
480   auto assignCtx = [context](AffineDimExprStorage *storage) {
481     storage->context = context;
482   };
483 
484   StorageUniquer &uniquer = context->getAffineUniquer();
485   return uniquer.get<AffineDimExprStorage>(
486       assignCtx, static_cast<unsigned>(kind), position);
487 }
488 
getAffineDimExpr(unsigned position,MLIRContext * context)489 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
490   return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
491 }
492 
AffineSymbolExpr(AffineExpr::ImplType * ptr)493 AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr)
494     : AffineExpr(ptr) {}
getPosition() const495 unsigned AffineSymbolExpr::getPosition() const {
496   return static_cast<ImplType *>(expr)->position;
497 }
498 
getAffineSymbolExpr(unsigned position,MLIRContext * context)499 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
500   return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
501   ;
502 }
503 
AffineConstantExpr(AffineExpr::ImplType * ptr)504 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
505     : AffineExpr(ptr) {}
getValue() const506 int64_t AffineConstantExpr::getValue() const {
507   return static_cast<ImplType *>(expr)->constant;
508 }
509 
operator ==(int64_t v) const510 bool AffineExpr::operator==(int64_t v) const {
511   return *this == getAffineConstantExpr(v, getContext());
512 }
513 
getAffineConstantExpr(int64_t constant,MLIRContext * context)514 AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
515   auto assignCtx = [context](AffineConstantExprStorage *storage) {
516     storage->context = context;
517   };
518 
519   StorageUniquer &uniquer = context->getAffineUniquer();
520   return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
521 }
522 
523 /// Simplify add expression. Return nullptr if it can't be simplified.
simplifyAdd(AffineExpr lhs,AffineExpr rhs)524 static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
525   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
526   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
527   // Fold if both LHS, RHS are a constant.
528   if (lhsConst && rhsConst)
529     return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
530                                  lhs.getContext());
531 
532   // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
533   // If only one of them is a symbolic expressions, make it the RHS.
534   if (lhs.isa<AffineConstantExpr>() ||
535       (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
536     return rhs + lhs;
537   }
538 
539   // At this point, if there was a constant, it would be on the right.
540 
541   // Addition with a zero is a noop, return the other input.
542   if (rhsConst) {
543     if (rhsConst.getValue() == 0)
544       return lhs;
545   }
546   // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
547   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
548   if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
549     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
550       return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
551   }
552 
553   // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
554   // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
555   // respective multiplicands.
556   Optional<int64_t> rLhsConst, rRhsConst;
557   AffineExpr firstExpr, secondExpr;
558   AffineConstantExpr rLhsConstExpr;
559   auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>();
560   if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
561       (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
562     rLhsConst = rLhsConstExpr.getValue();
563     firstExpr = lBinOpExpr.getLHS();
564   } else {
565     rLhsConst = 1;
566     firstExpr = lhs;
567   }
568 
569   auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
570   AffineConstantExpr rRhsConstExpr;
571   if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
572       (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
573     rRhsConst = rRhsConstExpr.getValue();
574     secondExpr = rBinOpExpr.getLHS();
575   } else {
576     rRhsConst = 1;
577     secondExpr = rhs;
578   }
579 
580   if (rLhsConst && rRhsConst && firstExpr == secondExpr)
581     return getAffineBinaryOpExpr(
582         AffineExprKind::Mul, firstExpr,
583         getAffineConstantExpr(rLhsConst.getValue() + rRhsConst.getValue(),
584                               lhs.getContext()));
585 
586   // When doing successive additions, bring constant to the right: turn (d0 + 2)
587   // + d1 into (d0 + d1) + 2.
588   if (lBin && lBin.getKind() == AffineExprKind::Add) {
589     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
590       return lBin.getLHS() + rhs + lrhs;
591     }
592   }
593 
594   // Detect and transform "expr - c * (expr floordiv c)" to "expr mod c". This
595   // leads to a much more efficient form when 'c' is a power of two, and in
596   // general a more compact and readable form.
597 
598   // Process '(expr floordiv c) * (-c)'.
599   if (!rBinOpExpr)
600     return nullptr;
601 
602   auto lrhs = rBinOpExpr.getLHS();
603   auto rrhs = rBinOpExpr.getRHS();
604 
605   // Process lrhs, which is 'expr floordiv c'.
606   AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
607   if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
608     return nullptr;
609 
610   auto llrhs = lrBinOpExpr.getLHS();
611   auto rlrhs = lrBinOpExpr.getRHS();
612 
613   if (lhs == llrhs && rlrhs == -rrhs) {
614     return lhs % rlrhs;
615   }
616   return nullptr;
617 }
618 
operator +(int64_t v) const619 AffineExpr AffineExpr::operator+(int64_t v) const {
620   return *this + getAffineConstantExpr(v, getContext());
621 }
operator +(AffineExpr other) const622 AffineExpr AffineExpr::operator+(AffineExpr other) const {
623   if (auto simplified = simplifyAdd(*this, other))
624     return simplified;
625 
626   StorageUniquer &uniquer = getContext()->getAffineUniquer();
627   return uniquer.get<AffineBinaryOpExprStorage>(
628       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
629 }
630 
631 /// Simplify a multiply expression. Return nullptr if it can't be simplified.
simplifyMul(AffineExpr lhs,AffineExpr rhs)632 static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
633   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
634   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
635 
636   if (lhsConst && rhsConst)
637     return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
638                                  lhs.getContext());
639 
640   assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());
641 
642   // Canonicalize the mul expression so that the constant/symbolic term is the
643   // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
644   // constant. (Note that a constant is trivially symbolic).
645   if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
646     // At least one of them has to be symbolic.
647     return rhs * lhs;
648   }
649 
650   // At this point, if there was a constant, it would be on the right.
651 
652   // Multiplication with a one is a noop, return the other input.
653   if (rhsConst) {
654     if (rhsConst.getValue() == 1)
655       return lhs;
656     // Multiplication with zero.
657     if (rhsConst.getValue() == 0)
658       return rhsConst;
659   }
660 
661   // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
662   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
663   if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
664     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
665       return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
666   }
667 
668   // When doing successive multiplication, bring constant to the right: turn (d0
669   // * 2) * d1 into (d0 * d1) * 2.
670   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
671     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
672       return (lBin.getLHS() * rhs) * lrhs;
673     }
674   }
675 
676   return nullptr;
677 }
678 
operator *(int64_t v) const679 AffineExpr AffineExpr::operator*(int64_t v) const {
680   return *this * getAffineConstantExpr(v, getContext());
681 }
operator *(AffineExpr other) const682 AffineExpr AffineExpr::operator*(AffineExpr other) const {
683   if (auto simplified = simplifyMul(*this, other))
684     return simplified;
685 
686   StorageUniquer &uniquer = getContext()->getAffineUniquer();
687   return uniquer.get<AffineBinaryOpExprStorage>(
688       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
689 }
690 
691 // Unary minus, delegate to operator*.
operator -() const692 AffineExpr AffineExpr::operator-() const {
693   return *this * getAffineConstantExpr(-1, getContext());
694 }
695 
696 // Delegate to operator+.
operator -(int64_t v) const697 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
operator -(AffineExpr other) const698 AffineExpr AffineExpr::operator-(AffineExpr other) const {
699   return *this + (-other);
700 }
701 
simplifyFloorDiv(AffineExpr lhs,AffineExpr rhs)702 static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
703   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
704   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
705 
706   // mlir floordiv by zero or negative numbers is undefined and preserved as is.
707   if (!rhsConst || rhsConst.getValue() < 1)
708     return nullptr;
709 
710   if (lhsConst)
711     return getAffineConstantExpr(
712         floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
713 
714   // Fold floordiv of a multiply with a constant that is a multiple of the
715   // divisor. Eg: (i * 128) floordiv 64 = i * 2.
716   if (rhsConst == 1)
717     return lhs;
718 
719   // Simplify (expr * const) floordiv divConst when expr is known to be a
720   // multiple of divConst.
721   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
722   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
723     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
724       // rhsConst is known to be a positive constant.
725       if (lrhs.getValue() % rhsConst.getValue() == 0)
726         return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
727     }
728   }
729 
730   // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
731   // known to be a multiple of divConst.
732   if (lBin && lBin.getKind() == AffineExprKind::Add) {
733     int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
734     int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
735     // rhsConst is known to be a positive constant.
736     if (llhsDiv % rhsConst.getValue() == 0 ||
737         lrhsDiv % rhsConst.getValue() == 0)
738       return lBin.getLHS().floorDiv(rhsConst.getValue()) +
739              lBin.getRHS().floorDiv(rhsConst.getValue());
740   }
741 
742   return nullptr;
743 }
744 
floorDiv(uint64_t v) const745 AffineExpr AffineExpr::floorDiv(uint64_t v) const {
746   return floorDiv(getAffineConstantExpr(v, getContext()));
747 }
floorDiv(AffineExpr other) const748 AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
749   if (auto simplified = simplifyFloorDiv(*this, other))
750     return simplified;
751 
752   StorageUniquer &uniquer = getContext()->getAffineUniquer();
753   return uniquer.get<AffineBinaryOpExprStorage>(
754       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
755       other);
756 }
757 
simplifyCeilDiv(AffineExpr lhs,AffineExpr rhs)758 static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
759   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
760   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
761 
762   if (!rhsConst || rhsConst.getValue() < 1)
763     return nullptr;
764 
765   if (lhsConst)
766     return getAffineConstantExpr(
767         ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
768 
769   // Fold ceildiv of a multiply with a constant that is a multiple of the
770   // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
771   if (rhsConst.getValue() == 1)
772     return lhs;
773 
774   // Simplify (expr * const) ceildiv divConst when const is known to be a
775   // multiple of divConst.
776   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
777   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
778     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
779       // rhsConst is known to be a positive constant.
780       if (lrhs.getValue() % rhsConst.getValue() == 0)
781         return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
782     }
783   }
784 
785   return nullptr;
786 }
787 
ceilDiv(uint64_t v) const788 AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
789   return ceilDiv(getAffineConstantExpr(v, getContext()));
790 }
ceilDiv(AffineExpr other) const791 AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
792   if (auto simplified = simplifyCeilDiv(*this, other))
793     return simplified;
794 
795   StorageUniquer &uniquer = getContext()->getAffineUniquer();
796   return uniquer.get<AffineBinaryOpExprStorage>(
797       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
798       other);
799 }
800 
simplifyMod(AffineExpr lhs,AffineExpr rhs)801 static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
802   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
803   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
804 
805   // mod w.r.t zero or negative numbers is undefined and preserved as is.
806   if (!rhsConst || rhsConst.getValue() < 1)
807     return nullptr;
808 
809   if (lhsConst)
810     return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
811                                  lhs.getContext());
812 
813   // Fold modulo of an expression that is known to be a multiple of a constant
814   // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
815   // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
816   if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
817     return getAffineConstantExpr(0, lhs.getContext());
818 
819   // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
820   // known to be a multiple of divConst.
821   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
822   if (lBin && lBin.getKind() == AffineExprKind::Add) {
823     int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
824     int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
825     // rhsConst is known to be a positive constant.
826     if (llhsDiv % rhsConst.getValue() == 0)
827       return lBin.getRHS() % rhsConst.getValue();
828     if (lrhsDiv % rhsConst.getValue() == 0)
829       return lBin.getLHS() % rhsConst.getValue();
830   }
831 
832   // Simplify (e % a) % b to e % b when b evenly divides a
833   if (lBin && lBin.getKind() == AffineExprKind::Mod) {
834     auto intermediate = lBin.getRHS().dyn_cast<AffineConstantExpr>();
835     if (intermediate && intermediate.getValue() >= 1 &&
836         mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
837       return lBin.getLHS() % rhsConst.getValue();
838     }
839   }
840 
841   return nullptr;
842 }
843 
operator %(uint64_t v) const844 AffineExpr AffineExpr::operator%(uint64_t v) const {
845   return *this % getAffineConstantExpr(v, getContext());
846 }
operator %(AffineExpr other) const847 AffineExpr AffineExpr::operator%(AffineExpr other) const {
848   if (auto simplified = simplifyMod(*this, other))
849     return simplified;
850 
851   StorageUniquer &uniquer = getContext()->getAffineUniquer();
852   return uniquer.get<AffineBinaryOpExprStorage>(
853       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
854 }
855 
compose(AffineMap map) const856 AffineExpr AffineExpr::compose(AffineMap map) const {
857   SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
858                                              map.getResults().end());
859   return replaceDimsAndSymbols(dimReplacements, {});
860 }
operator <<(raw_ostream & os,AffineExpr expr)861 raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
862   expr.print(os);
863   return os;
864 }
865 
866 /// Constructs an affine expression from a flat ArrayRef. If there are local
867 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
868 /// products expression, `localExprs` is expected to have the AffineExpr
869 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
870 /// in the format [dims, symbols, locals, constant term].
getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,unsigned numDims,unsigned numSymbols,ArrayRef<AffineExpr> localExprs,MLIRContext * context)871 AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
872                                            unsigned numDims,
873                                            unsigned numSymbols,
874                                            ArrayRef<AffineExpr> localExprs,
875                                            MLIRContext *context) {
876   // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
877   assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
878          "unexpected number of local expressions");
879 
880   auto expr = getAffineConstantExpr(0, context);
881   // Dimensions and symbols.
882   for (unsigned j = 0; j < numDims + numSymbols; j++) {
883     if (flatExprs[j] == 0)
884       continue;
885     auto id = j < numDims ? getAffineDimExpr(j, context)
886                           : getAffineSymbolExpr(j - numDims, context);
887     expr = expr + id * flatExprs[j];
888   }
889 
890   // Local identifiers.
891   for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
892        j++) {
893     if (flatExprs[j] == 0)
894       continue;
895     auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
896     expr = expr + term;
897   }
898 
899   // Constant term.
900   int64_t constTerm = flatExprs[flatExprs.size() - 1];
901   if (constTerm != 0)
902     expr = expr + constTerm;
903   return expr;
904 }
905 
SimpleAffineExprFlattener(unsigned numDims,unsigned numSymbols)906 SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
907                                                      unsigned numSymbols)
908     : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
909   operandExprStack.reserve(8);
910 }
911 
visitMulExpr(AffineBinaryOpExpr expr)912 void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
913   assert(operandExprStack.size() >= 2);
914   // This is a pure affine expr; the RHS will be a constant.
915   assert(expr.getRHS().isa<AffineConstantExpr>());
916   // Get the RHS constant.
917   auto rhsConst = operandExprStack.back()[getConstantIndex()];
918   operandExprStack.pop_back();
919   // Update the LHS in place instead of pop and push.
920   auto &lhs = operandExprStack.back();
921   for (unsigned i = 0, e = lhs.size(); i < e; i++) {
922     lhs[i] *= rhsConst;
923   }
924 }
925 
visitAddExpr(AffineBinaryOpExpr expr)926 void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
927   assert(operandExprStack.size() >= 2);
928   const auto &rhs = operandExprStack.back();
929   auto &lhs = operandExprStack[operandExprStack.size() - 2];
930   assert(lhs.size() == rhs.size());
931   // Update the LHS in place.
932   for (unsigned i = 0, e = rhs.size(); i < e; i++) {
933     lhs[i] += rhs[i];
934   }
935   // Pop off the RHS.
936   operandExprStack.pop_back();
937 }
938 
939 //
940 // t = expr mod c   <=>  t = expr - c*q and c*q <= expr <= c*q + c - 1
941 //
942 // A mod expression "expr mod c" is thus flattened by introducing a new local
943 // variable q (= expr floordiv c), such that expr mod c is replaced with
944 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
visitModExpr(AffineBinaryOpExpr expr)945 void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
946   assert(operandExprStack.size() >= 2);
947   // This is a pure affine expr; the RHS will be a constant.
948   assert(expr.getRHS().isa<AffineConstantExpr>());
949   auto rhsConst = operandExprStack.back()[getConstantIndex()];
950   operandExprStack.pop_back();
951   auto &lhs = operandExprStack.back();
952   // TODO: handle modulo by zero case when this issue is fixed
953   // at the other places in the IR.
954   assert(rhsConst > 0 && "RHS constant has to be positive");
955 
956   // Check if the LHS expression is a multiple of modulo factor.
957   unsigned i, e;
958   for (i = 0, e = lhs.size(); i < e; i++)
959     if (lhs[i] % rhsConst != 0)
960       break;
961   // If yes, modulo expression here simplifies to zero.
962   if (i == lhs.size()) {
963     std::fill(lhs.begin(), lhs.end(), 0);
964     return;
965   }
966 
967   // Add a local variable for the quotient, i.e., expr % c is replaced by
968   // (expr - q * c) where q = expr floordiv c. Do this while canceling out
969   // the GCD of expr and c.
970   SmallVector<int64_t, 8> floorDividend(lhs);
971   uint64_t gcd = rhsConst;
972   for (unsigned i = 0, e = lhs.size(); i < e; i++)
973     gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
974   // Simplify the numerator and the denominator.
975   if (gcd != 1) {
976     for (unsigned i = 0, e = floorDividend.size(); i < e; i++)
977       floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd);
978   }
979   int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
980 
981   // Construct the AffineExpr form of the floordiv to store in localExprs.
982   MLIRContext *context = expr.getContext();
983   auto dividendExpr = getAffineExprFromFlatForm(
984       floorDividend, numDims, numSymbols, localExprs, context);
985   auto divisorExpr = getAffineConstantExpr(floorDivisor, context);
986   auto floorDivExpr = dividendExpr.floorDiv(divisorExpr);
987   int loc;
988   if ((loc = findLocalId(floorDivExpr)) == -1) {
989     addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
990     // Set result at top of stack to "lhs - rhsConst * q".
991     lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
992   } else {
993     // Reuse the existing local id.
994     lhs[getLocalVarStartIndex() + loc] = -rhsConst;
995   }
996 }
997 
visitCeilDivExpr(AffineBinaryOpExpr expr)998 void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
999   visitDivExpr(expr, /*isCeil=*/true);
1000 }
visitFloorDivExpr(AffineBinaryOpExpr expr)1001 void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
1002   visitDivExpr(expr, /*isCeil=*/false);
1003 }
1004 
visitDimExpr(AffineDimExpr expr)1005 void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
1006   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1007   auto &eq = operandExprStack.back();
1008   assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1009   eq[getDimStartIndex() + expr.getPosition()] = 1;
1010 }
1011 
visitSymbolExpr(AffineSymbolExpr expr)1012 void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
1013   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1014   auto &eq = operandExprStack.back();
1015   assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1016   eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1017 }
1018 
visitConstantExpr(AffineConstantExpr expr)1019 void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
1020   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1021   auto &eq = operandExprStack.back();
1022   eq[getConstantIndex()] = expr.getValue();
1023 }
1024 
1025 // t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
1026 // A floordiv is thus flattened by introducing a new local variable q, and
1027 // replacing that expression with 'q' while adding the constraints
1028 // c * q <= expr <= c * q + c - 1 to localVarCst (done by
1029 // FlatAffineConstraints::addLocalFloorDiv).
1030 //
1031 // A ceildiv is similarly flattened:
1032 // t = expr ceildiv c   <=> t =  (expr + c - 1) floordiv c
visitDivExpr(AffineBinaryOpExpr expr,bool isCeil)1033 void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1034                                              bool isCeil) {
1035   assert(operandExprStack.size() >= 2);
1036   assert(expr.getRHS().isa<AffineConstantExpr>());
1037 
1038   // This is a pure affine expr; the RHS is a positive constant.
1039   int64_t rhsConst = operandExprStack.back()[getConstantIndex()];
1040   // TODO: handle division by zero at the same time the issue is
1041   // fixed at other places.
1042   assert(rhsConst > 0 && "RHS constant has to be positive");
1043   operandExprStack.pop_back();
1044   auto &lhs = operandExprStack.back();
1045 
1046   // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1047   // common divisors of the numerator and denominator.
1048   uint64_t gcd = std::abs(rhsConst);
1049   for (unsigned i = 0, e = lhs.size(); i < e; i++)
1050     gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
1051   // Simplify the numerator and the denominator.
1052   if (gcd != 1) {
1053     for (unsigned i = 0, e = lhs.size(); i < e; i++)
1054       lhs[i] = lhs[i] / static_cast<int64_t>(gcd);
1055   }
1056   int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1057   // If the divisor becomes 1, the updated LHS is the result. (The
1058   // divisor can't be negative since rhsConst is positive).
1059   if (divisor == 1)
1060     return;
1061 
1062   // If the divisor cannot be simplified to one, we will have to retain
1063   // the ceil/floor expr (simplified up until here). Add an existential
1064   // quantifier to express its result, i.e., expr1 div expr2 is replaced
1065   // by a new identifier, q.
1066   MLIRContext *context = expr.getContext();
1067   auto a =
1068       getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context);
1069   auto b = getAffineConstantExpr(divisor, context);
1070 
1071   int loc;
1072   auto divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1073   if ((loc = findLocalId(divExpr)) == -1) {
1074     if (!isCeil) {
1075       SmallVector<int64_t, 8> dividend(lhs);
1076       addLocalFloorDivId(dividend, divisor, divExpr);
1077     } else {
1078       // lhs ceildiv c <=>  (lhs + c - 1) floordiv c
1079       SmallVector<int64_t, 8> dividend(lhs);
1080       dividend.back() += divisor - 1;
1081       addLocalFloorDivId(dividend, divisor, divExpr);
1082     }
1083   }
1084   // Set the expression on stack to the local var introduced to capture the
1085   // result of the division (floor or ceil).
1086   std::fill(lhs.begin(), lhs.end(), 0);
1087   if (loc == -1)
1088     lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1089   else
1090     lhs[getLocalVarStartIndex() + loc] = 1;
1091 }
1092 
1093 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1094 // The local identifier added is always a floordiv of a pure add/mul affine
1095 // function of other identifiers, coefficients of which are specified in
1096 // dividend and with respect to a positive constant divisor. localExpr is the
1097 // simplified tree expression (AffineExpr) corresponding to the quantifier.
addLocalFloorDivId(ArrayRef<int64_t> dividend,int64_t divisor,AffineExpr localExpr)1098 void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
1099                                                    int64_t divisor,
1100                                                    AffineExpr localExpr) {
1101   assert(divisor > 0 && "positive constant divisor expected");
1102   for (auto &subExpr : operandExprStack)
1103     subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1104   localExprs.push_back(localExpr);
1105   numLocals++;
1106   // dividend and divisor are not used here; an override of this method uses it.
1107 }
1108 
findLocalId(AffineExpr localExpr)1109 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1110   SmallVectorImpl<AffineExpr>::iterator it;
1111   if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
1112     return -1;
1113   return it - localExprs.begin();
1114 }
1115 
1116 /// Simplify the affine expression by flattening it and reconstructing it.
simplifyAffineExpr(AffineExpr expr,unsigned numDims,unsigned numSymbols)1117 AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
1118                                     unsigned numSymbols) {
1119   // Simplify semi-affine expressions separately.
1120   if (!expr.isPureAffine())
1121     expr = simplifySemiAffine(expr);
1122   if (!expr.isPureAffine())
1123     return expr;
1124 
1125   SimpleAffineExprFlattener flattener(numDims, numSymbols);
1126   flattener.walkPostOrder(expr);
1127   ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1128   auto simplifiedExpr =
1129       getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1130                                 flattener.localExprs, expr.getContext());
1131   flattener.operandExprStack.pop_back();
1132   assert(flattener.operandExprStack.empty());
1133 
1134   return simplifiedExpr;
1135 }
1136