1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/AffineExpr.h"
10 #include "AffineExprDetail.h"
11 #include "mlir/IR/AffineExprVisitor.h"
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/IntegerSet.h"
14 #include "mlir/Support/MathExtras.h"
15 #include "mlir/Support/TypeID.h"
16 #include "llvm/ADT/STLExtras.h"
17 
18 using namespace mlir;
19 using namespace mlir::detail;
20 
getContext() const21 MLIRContext *AffineExpr::getContext() const { return expr->context; }
22 
getKind() const23 AffineExprKind AffineExpr::getKind() const { return expr->kind; }
24 
25 /// Walk all of the AffineExprs in this subgraph in postorder.
walk(std::function<void (AffineExpr)> callback) const26 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
27   struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
28     std::function<void(AffineExpr)> callback;
29 
30     AffineExprWalker(std::function<void(AffineExpr)> callback)
31         : callback(callback) {}
32 
33     void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
34     void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
35     void visitDimExpr(AffineDimExpr expr) { callback(expr); }
36     void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
37   };
38 
39   AffineExprWalker(callback).walkPostOrder(*this);
40 }
41 
42 // Dispatch affine expression construction based on kind.
getAffineBinaryOpExpr(AffineExprKind kind,AffineExpr lhs,AffineExpr rhs)43 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
44                                        AffineExpr rhs) {
45   if (kind == AffineExprKind::Add)
46     return lhs + rhs;
47   if (kind == AffineExprKind::Mul)
48     return lhs * rhs;
49   if (kind == AffineExprKind::FloorDiv)
50     return lhs.floorDiv(rhs);
51   if (kind == AffineExprKind::CeilDiv)
52     return lhs.ceilDiv(rhs);
53   if (kind == AffineExprKind::Mod)
54     return lhs % rhs;
55 
56   llvm_unreachable("unknown binary operation on affine expressions");
57 }
58 
59 /// This method substitutes any uses of dimensions and symbols (e.g.
60 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
61 AffineExpr
replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,ArrayRef<AffineExpr> symReplacements) const62 AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
63                                   ArrayRef<AffineExpr> symReplacements) const {
64   switch (getKind()) {
65   case AffineExprKind::Constant:
66     return *this;
67   case AffineExprKind::DimId: {
68     unsigned dimId = cast<AffineDimExpr>().getPosition();
69     if (dimId >= dimReplacements.size())
70       return *this;
71     return dimReplacements[dimId];
72   }
73   case AffineExprKind::SymbolId: {
74     unsigned symId = cast<AffineSymbolExpr>().getPosition();
75     if (symId >= symReplacements.size())
76       return *this;
77     return symReplacements[symId];
78   }
79   case AffineExprKind::Add:
80   case AffineExprKind::Mul:
81   case AffineExprKind::FloorDiv:
82   case AffineExprKind::CeilDiv:
83   case AffineExprKind::Mod:
84     auto binOp = cast<AffineBinaryOpExpr>();
85     auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
86     auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
87     auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
88     if (newLHS == lhs && newRHS == rhs)
89       return *this;
90     return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
91   }
92   llvm_unreachable("Unknown AffineExpr");
93 }
94 
replaceDims(ArrayRef<AffineExpr> dimReplacements) const95 AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const {
96   return replaceDimsAndSymbols(dimReplacements, {});
97 }
98 
99 AffineExpr
replaceSymbols(ArrayRef<AffineExpr> symReplacements) const100 AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const {
101   return replaceDimsAndSymbols({}, symReplacements);
102 }
103 
104 /// Replace symbols[0 .. numDims - 1] by symbols[shift .. shift + numDims - 1].
shiftDims(unsigned numDims,unsigned shift) const105 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift) const {
106   SmallVector<AffineExpr, 4> dims;
107   for (unsigned idx = 0; idx < numDims; ++idx)
108     dims.push_back(getAffineDimExpr(idx + shift, getContext()));
109   return replaceDimsAndSymbols(dims, {});
110 }
111 
112 /// Replace symbols[0 .. numSymbols - 1] by
113 /// symbols[shift .. shift + numSymbols - 1].
shiftSymbols(unsigned numSymbols,unsigned shift) const114 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift) const {
115   SmallVector<AffineExpr, 4> symbols;
116   for (unsigned idx = 0; idx < numSymbols; ++idx)
117     symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
118   return replaceDimsAndSymbols({}, symbols);
119 }
120 
121 /// Sparse replace method. Return the modified expression tree.
122 AffineExpr
replace(const DenseMap<AffineExpr,AffineExpr> & map) const123 AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
124   auto it = map.find(*this);
125   if (it != map.end())
126     return it->second;
127   switch (getKind()) {
128   default:
129     return *this;
130   case AffineExprKind::Add:
131   case AffineExprKind::Mul:
132   case AffineExprKind::FloorDiv:
133   case AffineExprKind::CeilDiv:
134   case AffineExprKind::Mod:
135     auto binOp = cast<AffineBinaryOpExpr>();
136     auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
137     auto newLHS = lhs.replace(map);
138     auto newRHS = rhs.replace(map);
139     if (newLHS == lhs && newRHS == rhs)
140       return *this;
141     return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
142   }
143   llvm_unreachable("Unknown AffineExpr");
144 }
145 
146 /// Sparse replace method. Return the modified expression tree.
replace(AffineExpr expr,AffineExpr replacement) const147 AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const {
148   DenseMap<AffineExpr, AffineExpr> map;
149   map.insert(std::make_pair(expr, replacement));
150   return replace(map);
151 }
152 /// Returns true if this expression is made out of only symbols and
153 /// constants (no dimensional identifiers).
isSymbolicOrConstant() const154 bool AffineExpr::isSymbolicOrConstant() const {
155   switch (getKind()) {
156   case AffineExprKind::Constant:
157     return true;
158   case AffineExprKind::DimId:
159     return false;
160   case AffineExprKind::SymbolId:
161     return true;
162 
163   case AffineExprKind::Add:
164   case AffineExprKind::Mul:
165   case AffineExprKind::FloorDiv:
166   case AffineExprKind::CeilDiv:
167   case AffineExprKind::Mod: {
168     auto expr = this->cast<AffineBinaryOpExpr>();
169     return expr.getLHS().isSymbolicOrConstant() &&
170            expr.getRHS().isSymbolicOrConstant();
171   }
172   }
173   llvm_unreachable("Unknown AffineExpr");
174 }
175 
176 /// Returns true if this is a pure affine expression, i.e., multiplication,
177 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
isPureAffine() const178 bool AffineExpr::isPureAffine() const {
179   switch (getKind()) {
180   case AffineExprKind::SymbolId:
181   case AffineExprKind::DimId:
182   case AffineExprKind::Constant:
183     return true;
184   case AffineExprKind::Add: {
185     auto op = cast<AffineBinaryOpExpr>();
186     return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
187   }
188 
189   case AffineExprKind::Mul: {
190     // TODO: Canonicalize the constants in binary operators to the RHS when
191     // possible, allowing this to merge into the next case.
192     auto op = cast<AffineBinaryOpExpr>();
193     return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
194            (op.getLHS().template isa<AffineConstantExpr>() ||
195             op.getRHS().template isa<AffineConstantExpr>());
196   }
197   case AffineExprKind::FloorDiv:
198   case AffineExprKind::CeilDiv:
199   case AffineExprKind::Mod: {
200     auto op = cast<AffineBinaryOpExpr>();
201     return op.getLHS().isPureAffine() &&
202            op.getRHS().template isa<AffineConstantExpr>();
203   }
204   }
205   llvm_unreachable("Unknown AffineExpr");
206 }
207 
208 // Returns the greatest known integral divisor of this affine expression.
getLargestKnownDivisor() const209 int64_t AffineExpr::getLargestKnownDivisor() const {
210   AffineBinaryOpExpr binExpr(nullptr);
211   switch (getKind()) {
212   case AffineExprKind::SymbolId:
213     LLVM_FALLTHROUGH;
214   case AffineExprKind::DimId:
215     return 1;
216   case AffineExprKind::Constant:
217     return std::abs(this->cast<AffineConstantExpr>().getValue());
218   case AffineExprKind::Mul: {
219     binExpr = this->cast<AffineBinaryOpExpr>();
220     return binExpr.getLHS().getLargestKnownDivisor() *
221            binExpr.getRHS().getLargestKnownDivisor();
222   }
223   case AffineExprKind::Add:
224     LLVM_FALLTHROUGH;
225   case AffineExprKind::FloorDiv:
226   case AffineExprKind::CeilDiv:
227   case AffineExprKind::Mod: {
228     binExpr = cast<AffineBinaryOpExpr>();
229     return llvm::GreatestCommonDivisor64(
230         binExpr.getLHS().getLargestKnownDivisor(),
231         binExpr.getRHS().getLargestKnownDivisor());
232   }
233   }
234   llvm_unreachable("Unknown AffineExpr");
235 }
236 
isMultipleOf(int64_t factor) const237 bool AffineExpr::isMultipleOf(int64_t factor) const {
238   AffineBinaryOpExpr binExpr(nullptr);
239   uint64_t l, u;
240   switch (getKind()) {
241   case AffineExprKind::SymbolId:
242     LLVM_FALLTHROUGH;
243   case AffineExprKind::DimId:
244     return factor * factor == 1;
245   case AffineExprKind::Constant:
246     return cast<AffineConstantExpr>().getValue() % factor == 0;
247   case AffineExprKind::Mul: {
248     binExpr = cast<AffineBinaryOpExpr>();
249     // It's probably not worth optimizing this further (to not traverse the
250     // whole sub-tree under - it that would require a version of isMultipleOf
251     // that on a 'false' return also returns the largest known divisor).
252     return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
253            (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
254            (l * u) % factor == 0;
255   }
256   case AffineExprKind::Add:
257   case AffineExprKind::FloorDiv:
258   case AffineExprKind::CeilDiv:
259   case AffineExprKind::Mod: {
260     binExpr = cast<AffineBinaryOpExpr>();
261     return llvm::GreatestCommonDivisor64(
262                binExpr.getLHS().getLargestKnownDivisor(),
263                binExpr.getRHS().getLargestKnownDivisor()) %
264                factor ==
265            0;
266   }
267   }
268   llvm_unreachable("Unknown AffineExpr");
269 }
270 
isFunctionOfDim(unsigned position) const271 bool AffineExpr::isFunctionOfDim(unsigned position) const {
272   if (getKind() == AffineExprKind::DimId) {
273     return *this == mlir::getAffineDimExpr(position, getContext());
274   }
275   if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
276     return expr.getLHS().isFunctionOfDim(position) ||
277            expr.getRHS().isFunctionOfDim(position);
278   }
279   return false;
280 }
281 
isFunctionOfSymbol(unsigned position) const282 bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
283   if (getKind() == AffineExprKind::SymbolId) {
284     return *this == mlir::getAffineSymbolExpr(position, getContext());
285   }
286   if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
287     return expr.getLHS().isFunctionOfSymbol(position) ||
288            expr.getRHS().isFunctionOfSymbol(position);
289   }
290   return false;
291 }
292 
AffineBinaryOpExpr(AffineExpr::ImplType * ptr)293 AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
294     : AffineExpr(ptr) {}
getLHS() const295 AffineExpr AffineBinaryOpExpr::getLHS() const {
296   return static_cast<ImplType *>(expr)->lhs;
297 }
getRHS() const298 AffineExpr AffineBinaryOpExpr::getRHS() const {
299   return static_cast<ImplType *>(expr)->rhs;
300 }
301 
AffineDimExpr(AffineExpr::ImplType * ptr)302 AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {}
getPosition() const303 unsigned AffineDimExpr::getPosition() const {
304   return static_cast<ImplType *>(expr)->position;
305 }
306 
307 /// Returns true if the expression is divisible by the given symbol with
308 /// position `symbolPos`. The argument `opKind` specifies here what kind of
309 /// division or mod operation called this division. It helps in implementing the
310 /// commutative property of the floordiv and ceildiv operations. If the argument
311 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
312 /// operation, then the commutative property can be used otherwise, the floordiv
313 /// operation is not divisible. The same argument holds for ceildiv operation.
isDivisibleBySymbol(AffineExpr expr,unsigned symbolPos,AffineExprKind opKind)314 static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
315                                 AffineExprKind opKind) {
316   // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
317   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
318           opKind == AffineExprKind::CeilDiv) &&
319          "unexpected opKind");
320   switch (expr.getKind()) {
321   case AffineExprKind::Constant:
322     if (expr.cast<AffineConstantExpr>().getValue())
323       return false;
324     return true;
325   case AffineExprKind::DimId:
326     return false;
327   case AffineExprKind::SymbolId:
328     return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos);
329   // Checks divisibility by the given symbol for both operands.
330   case AffineExprKind::Add: {
331     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
332     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
333            isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
334   }
335   // Checks divisibility by the given symbol for both operands. Consider the
336   // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
337   // this is a division by s1 and both the operands of modulo are divisible by
338   // s1 but it is not divisible by s1 always. The third argument is
339   // `AffineExprKind::Mod` for this reason.
340   case AffineExprKind::Mod: {
341     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
342     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
343                                AffineExprKind::Mod) &&
344            isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
345                                AffineExprKind::Mod);
346   }
347   // Checks if any of the operand divisible by the given symbol.
348   case AffineExprKind::Mul: {
349     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
350     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
351            isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
352   }
353   // Floordiv and ceildiv are divisible by the given symbol when the first
354   // operand is divisible, and the affine expression kind of the argument expr
355   // is same as the argument `opKind`. This can be inferred from commutative
356   // property of floordiv and ceildiv operations and are as follow:
357   // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
358   // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
359   // It will fail if operations are not same. For example:
360   // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
361   case AffineExprKind::FloorDiv:
362   case AffineExprKind::CeilDiv: {
363     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
364     if (opKind != expr.getKind())
365       return false;
366     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
367   }
368   }
369   llvm_unreachable("Unknown AffineExpr");
370 }
371 
372 /// Divides the given expression by the given symbol at position `symbolPos`. It
373 /// considers the divisibility condition is checked before calling itself. A
374 /// null expression is returned whenever the divisibility condition fails.
symbolicDivide(AffineExpr expr,unsigned symbolPos,AffineExprKind opKind)375 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
376                                  AffineExprKind opKind) {
377   // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
378   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
379           opKind == AffineExprKind::CeilDiv) &&
380          "unexpected opKind");
381   switch (expr.getKind()) {
382   case AffineExprKind::Constant:
383     if (expr.cast<AffineConstantExpr>().getValue() != 0)
384       return nullptr;
385     return getAffineConstantExpr(0, expr.getContext());
386   case AffineExprKind::DimId:
387     return nullptr;
388   case AffineExprKind::SymbolId:
389     return getAffineConstantExpr(1, expr.getContext());
390   // Dividing both operands by the given symbol.
391   case AffineExprKind::Add: {
392     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
393     return getAffineBinaryOpExpr(
394         expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
395         symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
396   }
397   // Dividing both operands by the given symbol.
398   case AffineExprKind::Mod: {
399     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
400     return getAffineBinaryOpExpr(
401         expr.getKind(),
402         symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
403         symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
404   }
405   // Dividing any of the operand by the given symbol.
406   case AffineExprKind::Mul: {
407     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
408     if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
409       return binaryExpr.getLHS() *
410              symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
411     return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
412            binaryExpr.getRHS();
413   }
414   // Dividing first operand only by the given symbol.
415   case AffineExprKind::FloorDiv:
416   case AffineExprKind::CeilDiv: {
417     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
418     return getAffineBinaryOpExpr(
419         expr.getKind(),
420         symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
421         binaryExpr.getRHS());
422   }
423   }
424   llvm_unreachable("Unknown AffineExpr");
425 }
426 
427 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
428 /// operations when the second operand simplifies to a symbol and the first
429 /// operand is divisible by that symbol. It can be applied to any semi-affine
430 /// expression. Returned expression can either be a semi-affine or pure affine
431 /// expression.
simplifySemiAffine(AffineExpr expr)432 static AffineExpr simplifySemiAffine(AffineExpr expr) {
433   switch (expr.getKind()) {
434   case AffineExprKind::Constant:
435   case AffineExprKind::DimId:
436   case AffineExprKind::SymbolId:
437     return expr;
438   case AffineExprKind::Add:
439   case AffineExprKind::Mul: {
440     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
441     return getAffineBinaryOpExpr(expr.getKind(),
442                                  simplifySemiAffine(binaryExpr.getLHS()),
443                                  simplifySemiAffine(binaryExpr.getRHS()));
444   }
445   // Check if the simplification of the second operand is a symbol, and the
446   // first operand is divisible by it. If the operation is a modulo, a constant
447   // zero expression is returned. In the case of floordiv and ceildiv, the
448   // symbol from the simplification of the second operand divides the first
449   // operand. Otherwise, simplification is not possible.
450   case AffineExprKind::FloorDiv:
451   case AffineExprKind::CeilDiv:
452   case AffineExprKind::Mod: {
453     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
454     AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
455     AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
456     AffineSymbolExpr symbolExpr =
457         simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>();
458     if (!symbolExpr)
459       return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
460     unsigned symbolPos = symbolExpr.getPosition();
461     if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
462       return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
463     if (expr.getKind() == AffineExprKind::Mod)
464       return getAffineConstantExpr(0, expr.getContext());
465     return symbolicDivide(sLHS, symbolPos, expr.getKind());
466   }
467   }
468   llvm_unreachable("Unknown AffineExpr");
469 }
470 
getAffineDimOrSymbol(AffineExprKind kind,unsigned position,MLIRContext * context)471 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
472                                        MLIRContext *context) {
473   auto assignCtx = [context](AffineDimExprStorage *storage) {
474     storage->context = context;
475   };
476 
477   StorageUniquer &uniquer = context->getAffineUniquer();
478   return uniquer.get<AffineDimExprStorage>(
479       assignCtx, static_cast<unsigned>(kind), position);
480 }
481 
getAffineDimExpr(unsigned position,MLIRContext * context)482 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
483   return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
484 }
485 
AffineSymbolExpr(AffineExpr::ImplType * ptr)486 AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr)
487     : AffineExpr(ptr) {}
getPosition() const488 unsigned AffineSymbolExpr::getPosition() const {
489   return static_cast<ImplType *>(expr)->position;
490 }
491 
getAffineSymbolExpr(unsigned position,MLIRContext * context)492 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
493   return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
494   ;
495 }
496 
AffineConstantExpr(AffineExpr::ImplType * ptr)497 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
498     : AffineExpr(ptr) {}
getValue() const499 int64_t AffineConstantExpr::getValue() const {
500   return static_cast<ImplType *>(expr)->constant;
501 }
502 
operator ==(int64_t v) const503 bool AffineExpr::operator==(int64_t v) const {
504   return *this == getAffineConstantExpr(v, getContext());
505 }
506 
getAffineConstantExpr(int64_t constant,MLIRContext * context)507 AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
508   auto assignCtx = [context](AffineConstantExprStorage *storage) {
509     storage->context = context;
510   };
511 
512   StorageUniquer &uniquer = context->getAffineUniquer();
513   return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
514 }
515 
516 /// Simplify add expression. Return nullptr if it can't be simplified.
simplifyAdd(AffineExpr lhs,AffineExpr rhs)517 static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
518   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
519   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
520   // Fold if both LHS, RHS are a constant.
521   if (lhsConst && rhsConst)
522     return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
523                                  lhs.getContext());
524 
525   // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
526   // If only one of them is a symbolic expressions, make it the RHS.
527   if (lhs.isa<AffineConstantExpr>() ||
528       (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
529     return rhs + lhs;
530   }
531 
532   // At this point, if there was a constant, it would be on the right.
533 
534   // Addition with a zero is a noop, return the other input.
535   if (rhsConst) {
536     if (rhsConst.getValue() == 0)
537       return lhs;
538   }
539   // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
540   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
541   if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
542     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
543       return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
544   }
545 
546   // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
547   // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
548   // respective multiplicands.
549   Optional<int64_t> rLhsConst, rRhsConst;
550   AffineExpr firstExpr, secondExpr;
551   AffineConstantExpr rLhsConstExpr;
552   auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>();
553   if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
554       (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
555     rLhsConst = rLhsConstExpr.getValue();
556     firstExpr = lBinOpExpr.getLHS();
557   } else {
558     rLhsConst = 1;
559     firstExpr = lhs;
560   }
561 
562   auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
563   AffineConstantExpr rRhsConstExpr;
564   if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
565       (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
566     rRhsConst = rRhsConstExpr.getValue();
567     secondExpr = rBinOpExpr.getLHS();
568   } else {
569     rRhsConst = 1;
570     secondExpr = rhs;
571   }
572 
573   if (rLhsConst && rRhsConst && firstExpr == secondExpr)
574     return getAffineBinaryOpExpr(
575         AffineExprKind::Mul, firstExpr,
576         getAffineConstantExpr(rLhsConst.getValue() + rRhsConst.getValue(),
577                               lhs.getContext()));
578 
579   // When doing successive additions, bring constant to the right: turn (d0 + 2)
580   // + d1 into (d0 + d1) + 2.
581   if (lBin && lBin.getKind() == AffineExprKind::Add) {
582     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
583       return lBin.getLHS() + rhs + lrhs;
584     }
585   }
586 
587   // Detect and transform "expr - c * (expr floordiv c)" to "expr mod c". This
588   // leads to a much more efficient form when 'c' is a power of two, and in
589   // general a more compact and readable form.
590 
591   // Process '(expr floordiv c) * (-c)'.
592   if (!rBinOpExpr)
593     return nullptr;
594 
595   auto lrhs = rBinOpExpr.getLHS();
596   auto rrhs = rBinOpExpr.getRHS();
597 
598   // Process lrhs, which is 'expr floordiv c'.
599   AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
600   if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
601     return nullptr;
602 
603   auto llrhs = lrBinOpExpr.getLHS();
604   auto rlrhs = lrBinOpExpr.getRHS();
605 
606   if (lhs == llrhs && rlrhs == -rrhs) {
607     return lhs % rlrhs;
608   }
609   return nullptr;
610 }
611 
operator +(int64_t v) const612 AffineExpr AffineExpr::operator+(int64_t v) const {
613   return *this + getAffineConstantExpr(v, getContext());
614 }
operator +(AffineExpr other) const615 AffineExpr AffineExpr::operator+(AffineExpr other) const {
616   if (auto simplified = simplifyAdd(*this, other))
617     return simplified;
618 
619   StorageUniquer &uniquer = getContext()->getAffineUniquer();
620   return uniquer.get<AffineBinaryOpExprStorage>(
621       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
622 }
623 
624 /// Simplify a multiply expression. Return nullptr if it can't be simplified.
simplifyMul(AffineExpr lhs,AffineExpr rhs)625 static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
626   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
627   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
628 
629   if (lhsConst && rhsConst)
630     return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
631                                  lhs.getContext());
632 
633   assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());
634 
635   // Canonicalize the mul expression so that the constant/symbolic term is the
636   // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
637   // constant. (Note that a constant is trivially symbolic).
638   if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
639     // At least one of them has to be symbolic.
640     return rhs * lhs;
641   }
642 
643   // At this point, if there was a constant, it would be on the right.
644 
645   // Multiplication with a one is a noop, return the other input.
646   if (rhsConst) {
647     if (rhsConst.getValue() == 1)
648       return lhs;
649     // Multiplication with zero.
650     if (rhsConst.getValue() == 0)
651       return rhsConst;
652   }
653 
654   // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
655   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
656   if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
657     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
658       return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
659   }
660 
661   // When doing successive multiplication, bring constant to the right: turn (d0
662   // * 2) * d1 into (d0 * d1) * 2.
663   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
664     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
665       return (lBin.getLHS() * rhs) * lrhs;
666     }
667   }
668 
669   return nullptr;
670 }
671 
operator *(int64_t v) const672 AffineExpr AffineExpr::operator*(int64_t v) const {
673   return *this * getAffineConstantExpr(v, getContext());
674 }
operator *(AffineExpr other) const675 AffineExpr AffineExpr::operator*(AffineExpr other) const {
676   if (auto simplified = simplifyMul(*this, other))
677     return simplified;
678 
679   StorageUniquer &uniquer = getContext()->getAffineUniquer();
680   return uniquer.get<AffineBinaryOpExprStorage>(
681       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
682 }
683 
684 // Unary minus, delegate to operator*.
operator -() const685 AffineExpr AffineExpr::operator-() const {
686   return *this * getAffineConstantExpr(-1, getContext());
687 }
688 
689 // Delegate to operator+.
operator -(int64_t v) const690 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
operator -(AffineExpr other) const691 AffineExpr AffineExpr::operator-(AffineExpr other) const {
692   return *this + (-other);
693 }
694 
simplifyFloorDiv(AffineExpr lhs,AffineExpr rhs)695 static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
696   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
697   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
698 
699   // mlir floordiv by zero or negative numbers is undefined and preserved as is.
700   if (!rhsConst || rhsConst.getValue() < 1)
701     return nullptr;
702 
703   if (lhsConst)
704     return getAffineConstantExpr(
705         floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
706 
707   // Fold floordiv of a multiply with a constant that is a multiple of the
708   // divisor. Eg: (i * 128) floordiv 64 = i * 2.
709   if (rhsConst == 1)
710     return lhs;
711 
712   // Simplify (expr * const) floordiv divConst when expr is known to be a
713   // multiple of divConst.
714   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
715   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
716     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
717       // rhsConst is known to be a positive constant.
718       if (lrhs.getValue() % rhsConst.getValue() == 0)
719         return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
720     }
721   }
722 
723   // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
724   // known to be a multiple of divConst.
725   if (lBin && lBin.getKind() == AffineExprKind::Add) {
726     int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
727     int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
728     // rhsConst is known to be a positive constant.
729     if (llhsDiv % rhsConst.getValue() == 0 ||
730         lrhsDiv % rhsConst.getValue() == 0)
731       return lBin.getLHS().floorDiv(rhsConst.getValue()) +
732              lBin.getRHS().floorDiv(rhsConst.getValue());
733   }
734 
735   return nullptr;
736 }
737 
floorDiv(uint64_t v) const738 AffineExpr AffineExpr::floorDiv(uint64_t v) const {
739   return floorDiv(getAffineConstantExpr(v, getContext()));
740 }
floorDiv(AffineExpr other) const741 AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
742   if (auto simplified = simplifyFloorDiv(*this, other))
743     return simplified;
744 
745   StorageUniquer &uniquer = getContext()->getAffineUniquer();
746   return uniquer.get<AffineBinaryOpExprStorage>(
747       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
748       other);
749 }
750 
simplifyCeilDiv(AffineExpr lhs,AffineExpr rhs)751 static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
752   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
753   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
754 
755   if (!rhsConst || rhsConst.getValue() < 1)
756     return nullptr;
757 
758   if (lhsConst)
759     return getAffineConstantExpr(
760         ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
761 
762   // Fold ceildiv of a multiply with a constant that is a multiple of the
763   // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
764   if (rhsConst.getValue() == 1)
765     return lhs;
766 
767   // Simplify (expr * const) ceildiv divConst when const is known to be a
768   // multiple of divConst.
769   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
770   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
771     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
772       // rhsConst is known to be a positive constant.
773       if (lrhs.getValue() % rhsConst.getValue() == 0)
774         return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
775     }
776   }
777 
778   return nullptr;
779 }
780 
ceilDiv(uint64_t v) const781 AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
782   return ceilDiv(getAffineConstantExpr(v, getContext()));
783 }
ceilDiv(AffineExpr other) const784 AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
785   if (auto simplified = simplifyCeilDiv(*this, other))
786     return simplified;
787 
788   StorageUniquer &uniquer = getContext()->getAffineUniquer();
789   return uniquer.get<AffineBinaryOpExprStorage>(
790       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
791       other);
792 }
793 
simplifyMod(AffineExpr lhs,AffineExpr rhs)794 static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
795   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
796   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
797 
798   // mod w.r.t zero or negative numbers is undefined and preserved as is.
799   if (!rhsConst || rhsConst.getValue() < 1)
800     return nullptr;
801 
802   if (lhsConst)
803     return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
804                                  lhs.getContext());
805 
806   // Fold modulo of an expression that is known to be a multiple of a constant
807   // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
808   // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
809   if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
810     return getAffineConstantExpr(0, lhs.getContext());
811 
812   // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
813   // known to be a multiple of divConst.
814   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
815   if (lBin && lBin.getKind() == AffineExprKind::Add) {
816     int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
817     int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
818     // rhsConst is known to be a positive constant.
819     if (llhsDiv % rhsConst.getValue() == 0)
820       return lBin.getRHS() % rhsConst.getValue();
821     if (lrhsDiv % rhsConst.getValue() == 0)
822       return lBin.getLHS() % rhsConst.getValue();
823   }
824 
825   return nullptr;
826 }
827 
operator %(uint64_t v) const828 AffineExpr AffineExpr::operator%(uint64_t v) const {
829   return *this % getAffineConstantExpr(v, getContext());
830 }
operator %(AffineExpr other) const831 AffineExpr AffineExpr::operator%(AffineExpr other) const {
832   if (auto simplified = simplifyMod(*this, other))
833     return simplified;
834 
835   StorageUniquer &uniquer = getContext()->getAffineUniquer();
836   return uniquer.get<AffineBinaryOpExprStorage>(
837       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
838 }
839 
compose(AffineMap map) const840 AffineExpr AffineExpr::compose(AffineMap map) const {
841   SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
842                                              map.getResults().end());
843   return replaceDimsAndSymbols(dimReplacements, {});
844 }
operator <<(raw_ostream & os,AffineExpr expr)845 raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
846   expr.print(os);
847   return os;
848 }
849 
850 /// Constructs an affine expression from a flat ArrayRef. If there are local
851 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
852 /// products expression, `localExprs` is expected to have the AffineExpr
853 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
854 /// in the format [dims, symbols, locals, constant term].
getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,unsigned numDims,unsigned numSymbols,ArrayRef<AffineExpr> localExprs,MLIRContext * context)855 AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
856                                            unsigned numDims,
857                                            unsigned numSymbols,
858                                            ArrayRef<AffineExpr> localExprs,
859                                            MLIRContext *context) {
860   // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
861   assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
862          "unexpected number of local expressions");
863 
864   auto expr = getAffineConstantExpr(0, context);
865   // Dimensions and symbols.
866   for (unsigned j = 0; j < numDims + numSymbols; j++) {
867     if (flatExprs[j] == 0)
868       continue;
869     auto id = j < numDims ? getAffineDimExpr(j, context)
870                           : getAffineSymbolExpr(j - numDims, context);
871     expr = expr + id * flatExprs[j];
872   }
873 
874   // Local identifiers.
875   for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
876        j++) {
877     if (flatExprs[j] == 0)
878       continue;
879     auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
880     expr = expr + term;
881   }
882 
883   // Constant term.
884   int64_t constTerm = flatExprs[flatExprs.size() - 1];
885   if (constTerm != 0)
886     expr = expr + constTerm;
887   return expr;
888 }
889 
SimpleAffineExprFlattener(unsigned numDims,unsigned numSymbols)890 SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
891                                                      unsigned numSymbols)
892     : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
893   operandExprStack.reserve(8);
894 }
895 
visitMulExpr(AffineBinaryOpExpr expr)896 void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
897   assert(operandExprStack.size() >= 2);
898   // This is a pure affine expr; the RHS will be a constant.
899   assert(expr.getRHS().isa<AffineConstantExpr>());
900   // Get the RHS constant.
901   auto rhsConst = operandExprStack.back()[getConstantIndex()];
902   operandExprStack.pop_back();
903   // Update the LHS in place instead of pop and push.
904   auto &lhs = operandExprStack.back();
905   for (unsigned i = 0, e = lhs.size(); i < e; i++) {
906     lhs[i] *= rhsConst;
907   }
908 }
909 
visitAddExpr(AffineBinaryOpExpr expr)910 void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
911   assert(operandExprStack.size() >= 2);
912   const auto &rhs = operandExprStack.back();
913   auto &lhs = operandExprStack[operandExprStack.size() - 2];
914   assert(lhs.size() == rhs.size());
915   // Update the LHS in place.
916   for (unsigned i = 0, e = rhs.size(); i < e; i++) {
917     lhs[i] += rhs[i];
918   }
919   // Pop off the RHS.
920   operandExprStack.pop_back();
921 }
922 
923 //
924 // t = expr mod c   <=>  t = expr - c*q and c*q <= expr <= c*q + c - 1
925 //
926 // A mod expression "expr mod c" is thus flattened by introducing a new local
927 // variable q (= expr floordiv c), such that expr mod c is replaced with
928 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
visitModExpr(AffineBinaryOpExpr expr)929 void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
930   assert(operandExprStack.size() >= 2);
931   // This is a pure affine expr; the RHS will be a constant.
932   assert(expr.getRHS().isa<AffineConstantExpr>());
933   auto rhsConst = operandExprStack.back()[getConstantIndex()];
934   operandExprStack.pop_back();
935   auto &lhs = operandExprStack.back();
936   // TODO: handle modulo by zero case when this issue is fixed
937   // at the other places in the IR.
938   assert(rhsConst > 0 && "RHS constant has to be positive");
939 
940   // Check if the LHS expression is a multiple of modulo factor.
941   unsigned i, e;
942   for (i = 0, e = lhs.size(); i < e; i++)
943     if (lhs[i] % rhsConst != 0)
944       break;
945   // If yes, modulo expression here simplifies to zero.
946   if (i == lhs.size()) {
947     std::fill(lhs.begin(), lhs.end(), 0);
948     return;
949   }
950 
951   // Add a local variable for the quotient, i.e., expr % c is replaced by
952   // (expr - q * c) where q = expr floordiv c. Do this while canceling out
953   // the GCD of expr and c.
954   SmallVector<int64_t, 8> floorDividend(lhs);
955   uint64_t gcd = rhsConst;
956   for (unsigned i = 0, e = lhs.size(); i < e; i++)
957     gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
958   // Simplify the numerator and the denominator.
959   if (gcd != 1) {
960     for (unsigned i = 0, e = floorDividend.size(); i < e; i++)
961       floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd);
962   }
963   int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
964 
965   // Construct the AffineExpr form of the floordiv to store in localExprs.
966   MLIRContext *context = expr.getContext();
967   auto dividendExpr = getAffineExprFromFlatForm(
968       floorDividend, numDims, numSymbols, localExprs, context);
969   auto divisorExpr = getAffineConstantExpr(floorDivisor, context);
970   auto floorDivExpr = dividendExpr.floorDiv(divisorExpr);
971   int loc;
972   if ((loc = findLocalId(floorDivExpr)) == -1) {
973     addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
974     // Set result at top of stack to "lhs - rhsConst * q".
975     lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
976   } else {
977     // Reuse the existing local id.
978     lhs[getLocalVarStartIndex() + loc] = -rhsConst;
979   }
980 }
981 
visitCeilDivExpr(AffineBinaryOpExpr expr)982 void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
983   visitDivExpr(expr, /*isCeil=*/true);
984 }
visitFloorDivExpr(AffineBinaryOpExpr expr)985 void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
986   visitDivExpr(expr, /*isCeil=*/false);
987 }
988 
visitDimExpr(AffineDimExpr expr)989 void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
990   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
991   auto &eq = operandExprStack.back();
992   assert(expr.getPosition() < numDims && "Inconsistent number of dims");
993   eq[getDimStartIndex() + expr.getPosition()] = 1;
994 }
995 
visitSymbolExpr(AffineSymbolExpr expr)996 void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
997   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
998   auto &eq = operandExprStack.back();
999   assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1000   eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1001 }
1002 
visitConstantExpr(AffineConstantExpr expr)1003 void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
1004   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1005   auto &eq = operandExprStack.back();
1006   eq[getConstantIndex()] = expr.getValue();
1007 }
1008 
1009 // t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
1010 // A floordiv is thus flattened by introducing a new local variable q, and
1011 // replacing that expression with 'q' while adding the constraints
1012 // c * q <= expr <= c * q + c - 1 to localVarCst (done by
1013 // FlatAffineConstraints::addLocalFloorDiv).
1014 //
1015 // A ceildiv is similarly flattened:
1016 // t = expr ceildiv c   <=> t =  (expr + c - 1) floordiv c
visitDivExpr(AffineBinaryOpExpr expr,bool isCeil)1017 void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1018                                              bool isCeil) {
1019   assert(operandExprStack.size() >= 2);
1020   assert(expr.getRHS().isa<AffineConstantExpr>());
1021 
1022   // This is a pure affine expr; the RHS is a positive constant.
1023   int64_t rhsConst = operandExprStack.back()[getConstantIndex()];
1024   // TODO: handle division by zero at the same time the issue is
1025   // fixed at other places.
1026   assert(rhsConst > 0 && "RHS constant has to be positive");
1027   operandExprStack.pop_back();
1028   auto &lhs = operandExprStack.back();
1029 
1030   // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1031   // common divisors of the numerator and denominator.
1032   uint64_t gcd = std::abs(rhsConst);
1033   for (unsigned i = 0, e = lhs.size(); i < e; i++)
1034     gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
1035   // Simplify the numerator and the denominator.
1036   if (gcd != 1) {
1037     for (unsigned i = 0, e = lhs.size(); i < e; i++)
1038       lhs[i] = lhs[i] / static_cast<int64_t>(gcd);
1039   }
1040   int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1041   // If the divisor becomes 1, the updated LHS is the result. (The
1042   // divisor can't be negative since rhsConst is positive).
1043   if (divisor == 1)
1044     return;
1045 
1046   // If the divisor cannot be simplified to one, we will have to retain
1047   // the ceil/floor expr (simplified up until here). Add an existential
1048   // quantifier to express its result, i.e., expr1 div expr2 is replaced
1049   // by a new identifier, q.
1050   MLIRContext *context = expr.getContext();
1051   auto a =
1052       getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context);
1053   auto b = getAffineConstantExpr(divisor, context);
1054 
1055   int loc;
1056   auto divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1057   if ((loc = findLocalId(divExpr)) == -1) {
1058     if (!isCeil) {
1059       SmallVector<int64_t, 8> dividend(lhs);
1060       addLocalFloorDivId(dividend, divisor, divExpr);
1061     } else {
1062       // lhs ceildiv c <=>  (lhs + c - 1) floordiv c
1063       SmallVector<int64_t, 8> dividend(lhs);
1064       dividend.back() += divisor - 1;
1065       addLocalFloorDivId(dividend, divisor, divExpr);
1066     }
1067   }
1068   // Set the expression on stack to the local var introduced to capture the
1069   // result of the division (floor or ceil).
1070   std::fill(lhs.begin(), lhs.end(), 0);
1071   if (loc == -1)
1072     lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1073   else
1074     lhs[getLocalVarStartIndex() + loc] = 1;
1075 }
1076 
1077 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1078 // The local identifier added is always a floordiv of a pure add/mul affine
1079 // function of other identifiers, coefficients of which are specified in
1080 // dividend and with respect to a positive constant divisor. localExpr is the
1081 // simplified tree expression (AffineExpr) corresponding to the quantifier.
addLocalFloorDivId(ArrayRef<int64_t> dividend,int64_t divisor,AffineExpr localExpr)1082 void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
1083                                                    int64_t divisor,
1084                                                    AffineExpr localExpr) {
1085   assert(divisor > 0 && "positive constant divisor expected");
1086   for (auto &subExpr : operandExprStack)
1087     subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1088   localExprs.push_back(localExpr);
1089   numLocals++;
1090   // dividend and divisor are not used here; an override of this method uses it.
1091 }
1092 
findLocalId(AffineExpr localExpr)1093 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1094   SmallVectorImpl<AffineExpr>::iterator it;
1095   if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
1096     return -1;
1097   return it - localExprs.begin();
1098 }
1099 
1100 /// Simplify the affine expression by flattening it and reconstructing it.
simplifyAffineExpr(AffineExpr expr,unsigned numDims,unsigned numSymbols)1101 AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
1102                                     unsigned numSymbols) {
1103   // Simplify semi-affine expressions separately.
1104   if (!expr.isPureAffine())
1105     expr = simplifySemiAffine(expr);
1106   if (!expr.isPureAffine())
1107     return expr;
1108 
1109   SimpleAffineExprFlattener flattener(numDims, numSymbols);
1110   flattener.walkPostOrder(expr);
1111   ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1112   auto simplifiedExpr =
1113       getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1114                                 flattener.localExprs, expr.getContext());
1115   flattener.operandExprStack.pop_back();
1116   assert(flattener.operandExprStack.empty());
1117 
1118   return simplifiedExpr;
1119 }
1120