1 //===- AffineExpr.h - MLIR Affine Expr Class --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // An affine expression is an affine combination of dimension identifiers and
10 // symbols, including ceildiv/floordiv/mod by a constant integer.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_IR_AFFINE_EXPR_H
15 #define MLIR_IR_AFFINE_EXPR_H
16 
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/DenseMapInfo.h"
19 #include "llvm/Support/Casting.h"
20 #include <functional>
21 #include <type_traits>
22 
23 namespace mlir {
24 
// Opaque forward declarations: this header only handles these types through
// pointers/handles, so their definitions are not needed here.
class AffineMap;
class IntegerSet;
class MLIRContext;

namespace detail {
// Storage structs; none of them is defined in this header.
struct AffineBinaryOpExprStorage;
struct AffineConstantExprStorage;
struct AffineDimExprStorage;
struct AffineExprStorage;
struct AffineSymbolExprStorage;
} // namespace detail
38 
/// Classification of an affine expression node. The binary operations are
/// numbered first so that `kind <= LAST_AFFINE_BINARY_OP` identifies them; the
/// leaf kinds (constant, dim, symbol) are numbered after them.
enum class AffineExprKind {
  // --- Binary operations -------------------------------------------------
  Add = 0,
  /// RHS of mul is always a constant or a symbolic expression.
  Mul = 1,
  /// RHS of mod is always a constant or a symbolic expression with a positive
  /// value.
  Mod = 2,
  /// RHS of floordiv is always a constant or a symbolic expression.
  FloorDiv = 3,
  /// RHS of ceildiv is always a constant or a symbolic expression.
  CeilDiv = 4,

  /// Marker for the last affine binary op: every binary-op kind compares less
  /// than or equal to this value, every leaf kind compares greater.
  LAST_AFFINE_BINARY_OP = CeilDiv,

  // --- Leaf expressions --------------------------------------------------
  /// Constant integer.
  Constant = 5,
  /// Dimensional identifier.
  DimId = 6,
  /// Symbolic identifier.
  SymbolId = 7,
};
62 
63 /// Base type for affine expression.
64 /// AffineExpr's are immutable value types with intuitive operators to
65 /// operate on chainable, lightweight compositions.
66 /// An AffineExpr is an interface to the underlying storage type pointer.
67 class AffineExpr {
68 public:
69   using ImplType = detail::AffineExprStorage;
70 
AffineExpr()71   constexpr AffineExpr() : expr(nullptr) {}
AffineExpr(const ImplType * expr)72   /* implicit */ AffineExpr(const ImplType *expr)
73       : expr(const_cast<ImplType *>(expr)) {}
74 
75   bool operator==(AffineExpr other) const { return expr == other.expr; }
76   bool operator!=(AffineExpr other) const { return !(*this == other); }
77   bool operator==(int64_t v) const;
78   bool operator!=(int64_t v) const { return !(*this == v); }
79   explicit operator bool() const { return expr; }
80 
81   bool operator!() const { return expr == nullptr; }
82 
83   template <typename U>
84   bool isa() const;
85   template <typename U>
86   U dyn_cast() const;
87   template <typename U>
88   U dyn_cast_or_null() const;
89   template <typename U>
90   U cast() const;
91 
92   MLIRContext *getContext() const;
93 
94   /// Return the classification for this type.
95   AffineExprKind getKind() const;
96 
97   void print(raw_ostream &os) const;
98   void dump() const;
99 
100   /// Returns true if this expression is made out of only symbols and
101   /// constants, i.e., it does not involve dimensional identifiers.
102   bool isSymbolicOrConstant() const;
103 
104   /// Returns true if this is a pure affine expression, i.e., multiplication,
105   /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
106   bool isPureAffine() const;
107 
108   /// Returns the greatest known integral divisor of this affine expression. The
109   /// result is always positive.
110   int64_t getLargestKnownDivisor() const;
111 
112   /// Return true if the affine expression is a multiple of 'factor'.
113   bool isMultipleOf(int64_t factor) const;
114 
115   /// Return true if the affine expression involves AffineDimExpr `position`.
116   bool isFunctionOfDim(unsigned position) const;
117 
118   /// Return true if the affine expression involves AffineSymbolExpr `position`.
119   bool isFunctionOfSymbol(unsigned position) const;
120 
121   /// Walk all of the AffineExpr's in this expression in postorder.
122   void walk(std::function<void(AffineExpr)> callback) const;
123 
124   /// This method substitutes any uses of dimensions and symbols (e.g.
125   /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
126   /// This is a dense replacement method: a replacement must be specified for
127   /// every single dim and symbol.
128   AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
129                                    ArrayRef<AffineExpr> symReplacements) const;
130 
131   /// Dim-only version of replaceDimsAndSymbols.
132   AffineExpr replaceDims(ArrayRef<AffineExpr> dimReplacements) const;
133 
134   /// Symbol-only version of replaceDimsAndSymbols.
135   AffineExpr replaceSymbols(ArrayRef<AffineExpr> symReplacements) const;
136 
137   /// Sparse replace method. Replace `expr` by `replacement` and return the
138   /// modified expression tree.
139   AffineExpr replace(AffineExpr expr, AffineExpr replacement) const;
140 
141   /// Sparse replace method. If `*this` appears in `map` replaces it by
142   /// `map[*this]` and return the modified expression tree. Otherwise traverse
143   /// `*this` and apply replace with `map` on its subexpressions.
144   AffineExpr replace(const DenseMap<AffineExpr, AffineExpr> &map) const;
145 
146   /// Replace dims[0 .. numDims - 1] by dims[shift .. shift + numDims - 1].
147   AffineExpr shiftDims(unsigned numDims, unsigned shift) const;
148 
149   /// Replace symbols[0 .. numSymbols - 1] by
150   ///         symbols[shift .. shift + numSymbols - 1].
151   AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift) const;
152 
153   AffineExpr operator+(int64_t v) const;
154   AffineExpr operator+(AffineExpr other) const;
155   AffineExpr operator-() const;
156   AffineExpr operator-(int64_t v) const;
157   AffineExpr operator-(AffineExpr other) const;
158   AffineExpr operator*(int64_t v) const;
159   AffineExpr operator*(AffineExpr other) const;
160   AffineExpr floorDiv(uint64_t v) const;
161   AffineExpr floorDiv(AffineExpr other) const;
162   AffineExpr ceilDiv(uint64_t v) const;
163   AffineExpr ceilDiv(AffineExpr other) const;
164   AffineExpr operator%(uint64_t v) const;
165   AffineExpr operator%(AffineExpr other) const;
166 
167   /// Compose with an AffineMap.
168   /// Returns the composition of this AffineExpr with `map`.
169   ///
170   /// Prerequisites:
171   /// `this` and `map` are composable, i.e. that the number of AffineDimExpr of
172   /// `this` is smaller than the number of results of `map`. If a result of a
173   /// map does not have a corresponding AffineDimExpr, that result simply does
174   /// not appear in the produced AffineExpr.
175   ///
176   /// Example:
177   ///   expr: `d0 + d2`
178   ///   map:  `(d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)`
179   ///   returned expr: `d0 * 2 + d1 + d2 + s1`
180   AffineExpr compose(AffineMap map) const;
181 
182   friend ::llvm::hash_code hash_value(AffineExpr arg);
183 
184   /// Methods supporting C API.
getAsOpaquePointer()185   const void *getAsOpaquePointer() const {
186     return static_cast<const void *>(expr);
187   }
getFromOpaquePointer(const void * pointer)188   static AffineExpr getFromOpaquePointer(const void *pointer) {
189     return AffineExpr(
190         reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
191   }
192 
193 protected:
194   ImplType *expr;
195 };
196 
197 /// Affine binary operation expression. An affine binary operation could be an
198 /// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is
199 /// represented through a multiply by -1 and add.) These expressions are always
200 /// constructed in a simplified form. For eg., the LHS and RHS operands can't
201 /// both be constants. There are additional canonicalizing rules depending on
202 /// the op type: see checks in the constructor.
203 class AffineBinaryOpExpr : public AffineExpr {
204 public:
205   using ImplType = detail::AffineBinaryOpExprStorage;
206   /* implicit */ AffineBinaryOpExpr(AffineExpr::ImplType *ptr);
207   AffineExpr getLHS() const;
208   AffineExpr getRHS() const;
209 };
210 
211 /// A dimensional identifier appearing in an affine expression.
212 class AffineDimExpr : public AffineExpr {
213 public:
214   using ImplType = detail::AffineDimExprStorage;
215   /* implicit */ AffineDimExpr(AffineExpr::ImplType *ptr);
216   unsigned getPosition() const;
217 };
218 
219 /// A symbolic identifier appearing in an affine expression.
220 class AffineSymbolExpr : public AffineExpr {
221 public:
222   using ImplType = detail::AffineDimExprStorage;
223   /* implicit */ AffineSymbolExpr(AffineExpr::ImplType *ptr);
224   unsigned getPosition() const;
225 };
226 
227 /// An integer constant appearing in affine expression.
228 class AffineConstantExpr : public AffineExpr {
229 public:
230   using ImplType = detail::AffineConstantExprStorage;
231   /* implicit */ AffineConstantExpr(AffineExpr::ImplType *ptr = nullptr);
232   int64_t getValue() const;
233 };
234 
235 /// Make AffineExpr hashable.
hash_value(AffineExpr arg)236 inline ::llvm::hash_code hash_value(AffineExpr arg) {
237   return ::llvm::hash_value(arg.expr);
238 }
239 
240 inline AffineExpr operator+(int64_t val, AffineExpr expr) { return expr + val; }
241 inline AffineExpr operator*(int64_t val, AffineExpr expr) { return expr * val; }
242 inline AffineExpr operator-(int64_t val, AffineExpr expr) {
243   return expr * (-1) + val;
244 }
245 
246 /// These free functions allow clients of the API to not use classes in detail.
247 AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context);
248 AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context);
249 AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context);
250 AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
251                                  AffineExpr rhs);
252 
253 /// Constructs an affine expression from a flat ArrayRef. If there are local
254 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
255 /// products expression, 'localExprs' is expected to have the AffineExpr
256 /// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
257 /// format [dims, symbols, locals, constant term].
258 AffineExpr getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
259                                      unsigned numDims, unsigned numSymbols,
260                                      ArrayRef<AffineExpr> localExprs,
261                                      MLIRContext *context);
262 
263 raw_ostream &operator<<(raw_ostream &os, AffineExpr expr);
264 
265 template <typename U>
isa()266 bool AffineExpr::isa() const {
267   if (std::is_same<U, AffineBinaryOpExpr>::value)
268     return getKind() <= AffineExprKind::LAST_AFFINE_BINARY_OP;
269   if (std::is_same<U, AffineDimExpr>::value)
270     return getKind() == AffineExprKind::DimId;
271   if (std::is_same<U, AffineSymbolExpr>::value)
272     return getKind() == AffineExprKind::SymbolId;
273   if (std::is_same<U, AffineConstantExpr>::value)
274     return getKind() == AffineExprKind::Constant;
275 }
276 template <typename U>
dyn_cast()277 U AffineExpr::dyn_cast() const {
278   if (isa<U>())
279     return U(expr);
280   return U(nullptr);
281 }
282 template <typename U>
dyn_cast_or_null()283 U AffineExpr::dyn_cast_or_null() const {
284   return (!*this || !isa<U>()) ? U(nullptr) : U(expr);
285 }
286 template <typename U>
cast()287 U AffineExpr::cast() const {
288   assert(isa<U>());
289   return U(expr);
290 }
291 
292 /// Simplify an affine expression by flattening and some amount of
293 /// simple analysis. This has complexity linear in the number of nodes in
294 /// 'expr'. Returns the simplified expression, which is the same as the input
295 ///  expression if it can't be simplified.
296 AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims,
297                               unsigned numSymbols);
298 
299 namespace detail {
300 template <int N>
bindDims(MLIRContext * ctx)301 void bindDims(MLIRContext *ctx) {}
302 
303 template <int N, typename AffineExprTy, typename... AffineExprTy2>
bindDims(MLIRContext * ctx,AffineExprTy & e,AffineExprTy2 &...exprs)304 void bindDims(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) {
305   e = getAffineDimExpr(N, ctx);
306   bindDims<N + 1, AffineExprTy2 &...>(ctx, exprs...);
307 }
308 
309 template <int N>
bindSymbols(MLIRContext * ctx)310 void bindSymbols(MLIRContext *ctx) {}
311 
312 template <int N, typename AffineExprTy, typename... AffineExprTy2>
bindSymbols(MLIRContext * ctx,AffineExprTy & e,AffineExprTy2 &...exprs)313 void bindSymbols(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) {
314   e = getAffineSymbolExpr(N, ctx);
315   bindSymbols<N + 1, AffineExprTy2 &...>(ctx, exprs...);
316 }
317 } // namespace detail
318 
319 /// Bind a list of AffineExpr references to DimExpr at positions:
320 ///   [0 .. sizeof...(exprs)]
321 template <typename... AffineExprTy>
bindDims(MLIRContext * ctx,AffineExprTy &...exprs)322 void bindDims(MLIRContext *ctx, AffineExprTy &...exprs) {
323   detail::bindDims<0>(ctx, exprs...);
324 }
325 
326 /// Bind a list of AffineExpr references to SymbolExpr at positions:
327 ///   [0 .. sizeof...(exprs)]
328 template <typename... AffineExprTy>
bindSymbols(MLIRContext * ctx,AffineExprTy &...exprs)329 void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs) {
330   detail::bindSymbols<0>(ctx, exprs...);
331 }
332 
333 } // namespace mlir
334 
335 namespace llvm {
336 
337 // AffineExpr hash just like pointers
338 template <>
339 struct DenseMapInfo<mlir::AffineExpr> {
340   static mlir::AffineExpr getEmptyKey() {
341     auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
342     return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
343   }
344   static mlir::AffineExpr getTombstoneKey() {
345     auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
346     return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
347   }
348   static unsigned getHashValue(mlir::AffineExpr val) {
349     return mlir::hash_value(val);
350   }
351   static bool isEqual(mlir::AffineExpr LHS, mlir::AffineExpr RHS) {
352     return LHS == RHS;
353   }
354 };
355 
356 } // namespace llvm
357 
358 #endif // MLIR_IR_AFFINE_EXPR_H
359