1 //===- AffineExpr.h - MLIR Affine Expr Class --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // An affine expression is an affine combination of dimension identifiers and
10 // symbols, including ceildiv/floordiv/mod by a constant integer.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #ifndef MLIR_IR_AFFINE_EXPR_H
15 #define MLIR_IR_AFFINE_EXPR_H
16
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/DenseMapInfo.h"
19 #include "llvm/Support/Casting.h"
20 #include <functional>
21 #include <type_traits>
22
23 namespace mlir {
24
25 class MLIRContext;
26 class AffineMap;
27 class IntegerSet;
28
29 namespace detail {
30
31 struct AffineExprStorage;
32 struct AffineBinaryOpExprStorage;
33 struct AffineDimExprStorage;
34 struct AffineSymbolExprStorage;
35 struct AffineConstantExprStorage;
36
37 } // namespace detail
38
/// Discriminator for the concrete kind of an AffineExpr.
///
/// The binary-operation kinds are listed first and contiguously, so "is this a
/// binary op?" reduces to a single `<= LAST_AFFINE_BINARY_OP` comparison.
enum class AffineExprKind {
  // --- Binary operations (must remain first and contiguous) ---

  /// Addition.
  Add,
  /// Multiplication. The RHS is always a constant or a symbolic expression.
  Mul,
  /// Modulo. The RHS is always a constant or a symbolic expression with a
  /// positive value.
  Mod,
  /// Floor division. The RHS is always a constant or a symbolic expression.
  FloorDiv,
  /// Ceiling division. The RHS is always a constant or a symbolic expression.
  CeilDiv,

  /// Marker for the last binary operation: every kind up to and including
  /// this value is a binary op.
  LAST_AFFINE_BINARY_OP = CeilDiv,

  // --- Leaf expressions ---

  /// Constant integer.
  Constant,
  /// Dimensional identifier.
  DimId,
  /// Symbolic identifier.
  SymbolId,
};
62
63 /// Base type for affine expression.
64 /// AffineExpr's are immutable value types with intuitive operators to
65 /// operate on chainable, lightweight compositions.
66 /// An AffineExpr is an interface to the underlying storage type pointer.
67 class AffineExpr {
68 public:
69 using ImplType = detail::AffineExprStorage;
70
AffineExpr()71 constexpr AffineExpr() : expr(nullptr) {}
AffineExpr(const ImplType * expr)72 /* implicit */ AffineExpr(const ImplType *expr)
73 : expr(const_cast<ImplType *>(expr)) {}
74
75 bool operator==(AffineExpr other) const { return expr == other.expr; }
76 bool operator!=(AffineExpr other) const { return !(*this == other); }
77 bool operator==(int64_t v) const;
78 bool operator!=(int64_t v) const { return !(*this == v); }
79 explicit operator bool() const { return expr; }
80
81 bool operator!() const { return expr == nullptr; }
82
83 template <typename U>
84 bool isa() const;
85 template <typename U>
86 U dyn_cast() const;
87 template <typename U>
88 U dyn_cast_or_null() const;
89 template <typename U>
90 U cast() const;
91
92 MLIRContext *getContext() const;
93
94 /// Return the classification for this type.
95 AffineExprKind getKind() const;
96
97 void print(raw_ostream &os) const;
98 void dump() const;
99
100 /// Returns true if this expression is made out of only symbols and
101 /// constants, i.e., it does not involve dimensional identifiers.
102 bool isSymbolicOrConstant() const;
103
104 /// Returns true if this is a pure affine expression, i.e., multiplication,
105 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
106 bool isPureAffine() const;
107
108 /// Returns the greatest known integral divisor of this affine expression. The
109 /// result is always positive.
110 int64_t getLargestKnownDivisor() const;
111
112 /// Return true if the affine expression is a multiple of 'factor'.
113 bool isMultipleOf(int64_t factor) const;
114
115 /// Return true if the affine expression involves AffineDimExpr `position`.
116 bool isFunctionOfDim(unsigned position) const;
117
118 /// Return true if the affine expression involves AffineSymbolExpr `position`.
119 bool isFunctionOfSymbol(unsigned position) const;
120
121 /// Walk all of the AffineExpr's in this expression in postorder.
122 void walk(std::function<void(AffineExpr)> callback) const;
123
124 /// This method substitutes any uses of dimensions and symbols (e.g.
125 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
126 /// This is a dense replacement method: a replacement must be specified for
127 /// every single dim and symbol.
128 AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
129 ArrayRef<AffineExpr> symReplacements) const;
130
131 /// Dim-only version of replaceDimsAndSymbols.
132 AffineExpr replaceDims(ArrayRef<AffineExpr> dimReplacements) const;
133
134 /// Symbol-only version of replaceDimsAndSymbols.
135 AffineExpr replaceSymbols(ArrayRef<AffineExpr> symReplacements) const;
136
137 /// Sparse replace method. Replace `expr` by `replacement` and return the
138 /// modified expression tree.
139 AffineExpr replace(AffineExpr expr, AffineExpr replacement) const;
140
141 /// Sparse replace method. If `*this` appears in `map` replaces it by
142 /// `map[*this]` and return the modified expression tree. Otherwise traverse
143 /// `*this` and apply replace with `map` on its subexpressions.
144 AffineExpr replace(const DenseMap<AffineExpr, AffineExpr> &map) const;
145
146 /// Replace dims[0 .. numDims - 1] by dims[shift .. shift + numDims - 1].
147 AffineExpr shiftDims(unsigned numDims, unsigned shift) const;
148
149 /// Replace symbols[0 .. numSymbols - 1] by
150 /// symbols[shift .. shift + numSymbols - 1].
151 AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift) const;
152
153 AffineExpr operator+(int64_t v) const;
154 AffineExpr operator+(AffineExpr other) const;
155 AffineExpr operator-() const;
156 AffineExpr operator-(int64_t v) const;
157 AffineExpr operator-(AffineExpr other) const;
158 AffineExpr operator*(int64_t v) const;
159 AffineExpr operator*(AffineExpr other) const;
160 AffineExpr floorDiv(uint64_t v) const;
161 AffineExpr floorDiv(AffineExpr other) const;
162 AffineExpr ceilDiv(uint64_t v) const;
163 AffineExpr ceilDiv(AffineExpr other) const;
164 AffineExpr operator%(uint64_t v) const;
165 AffineExpr operator%(AffineExpr other) const;
166
167 /// Compose with an AffineMap.
168 /// Returns the composition of this AffineExpr with `map`.
169 ///
170 /// Prerequisites:
171 /// `this` and `map` are composable, i.e. that the number of AffineDimExpr of
172 /// `this` is smaller than the number of results of `map`. If a result of a
173 /// map does not have a corresponding AffineDimExpr, that result simply does
174 /// not appear in the produced AffineExpr.
175 ///
176 /// Example:
177 /// expr: `d0 + d2`
178 /// map: `(d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)`
179 /// returned expr: `d0 * 2 + d1 + d2 + s1`
180 AffineExpr compose(AffineMap map) const;
181
182 friend ::llvm::hash_code hash_value(AffineExpr arg);
183
184 /// Methods supporting C API.
getAsOpaquePointer()185 const void *getAsOpaquePointer() const {
186 return static_cast<const void *>(expr);
187 }
getFromOpaquePointer(const void * pointer)188 static AffineExpr getFromOpaquePointer(const void *pointer) {
189 return AffineExpr(
190 reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
191 }
192
193 protected:
194 ImplType *expr;
195 };
196
197 /// Affine binary operation expression. An affine binary operation could be an
198 /// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is
199 /// represented through a multiply by -1 and add.) These expressions are always
200 /// constructed in a simplified form. For eg., the LHS and RHS operands can't
201 /// both be constants. There are additional canonicalizing rules depending on
202 /// the op type: see checks in the constructor.
203 class AffineBinaryOpExpr : public AffineExpr {
204 public:
205 using ImplType = detail::AffineBinaryOpExprStorage;
206 /* implicit */ AffineBinaryOpExpr(AffineExpr::ImplType *ptr);
207 AffineExpr getLHS() const;
208 AffineExpr getRHS() const;
209 };
210
211 /// A dimensional identifier appearing in an affine expression.
212 class AffineDimExpr : public AffineExpr {
213 public:
214 using ImplType = detail::AffineDimExprStorage;
215 /* implicit */ AffineDimExpr(AffineExpr::ImplType *ptr);
216 unsigned getPosition() const;
217 };
218
219 /// A symbolic identifier appearing in an affine expression.
220 class AffineSymbolExpr : public AffineExpr {
221 public:
222 using ImplType = detail::AffineDimExprStorage;
223 /* implicit */ AffineSymbolExpr(AffineExpr::ImplType *ptr);
224 unsigned getPosition() const;
225 };
226
227 /// An integer constant appearing in affine expression.
228 class AffineConstantExpr : public AffineExpr {
229 public:
230 using ImplType = detail::AffineConstantExprStorage;
231 /* implicit */ AffineConstantExpr(AffineExpr::ImplType *ptr = nullptr);
232 int64_t getValue() const;
233 };
234
235 /// Make AffineExpr hashable.
hash_value(AffineExpr arg)236 inline ::llvm::hash_code hash_value(AffineExpr arg) {
237 return ::llvm::hash_value(arg.expr);
238 }
239
240 inline AffineExpr operator+(int64_t val, AffineExpr expr) { return expr + val; }
241 inline AffineExpr operator*(int64_t val, AffineExpr expr) { return expr * val; }
242 inline AffineExpr operator-(int64_t val, AffineExpr expr) {
243 return expr * (-1) + val;
244 }
245
246 /// These free functions allow clients of the API to not use classes in detail.
247 AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context);
248 AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context);
249 AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context);
250 AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
251 AffineExpr rhs);
252
253 /// Constructs an affine expression from a flat ArrayRef. If there are local
254 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
255 /// products expression, 'localExprs' is expected to have the AffineExpr
256 /// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
257 /// format [dims, symbols, locals, constant term].
258 AffineExpr getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
259 unsigned numDims, unsigned numSymbols,
260 ArrayRef<AffineExpr> localExprs,
261 MLIRContext *context);
262
263 raw_ostream &operator<<(raw_ostream &os, AffineExpr expr);
264
265 template <typename U>
isa()266 bool AffineExpr::isa() const {
267 if (std::is_same<U, AffineBinaryOpExpr>::value)
268 return getKind() <= AffineExprKind::LAST_AFFINE_BINARY_OP;
269 if (std::is_same<U, AffineDimExpr>::value)
270 return getKind() == AffineExprKind::DimId;
271 if (std::is_same<U, AffineSymbolExpr>::value)
272 return getKind() == AffineExprKind::SymbolId;
273 if (std::is_same<U, AffineConstantExpr>::value)
274 return getKind() == AffineExprKind::Constant;
275 }
276 template <typename U>
dyn_cast()277 U AffineExpr::dyn_cast() const {
278 if (isa<U>())
279 return U(expr);
280 return U(nullptr);
281 }
282 template <typename U>
dyn_cast_or_null()283 U AffineExpr::dyn_cast_or_null() const {
284 return (!*this || !isa<U>()) ? U(nullptr) : U(expr);
285 }
286 template <typename U>
cast()287 U AffineExpr::cast() const {
288 assert(isa<U>());
289 return U(expr);
290 }
291
292 /// Simplify an affine expression by flattening and some amount of
293 /// simple analysis. This has complexity linear in the number of nodes in
294 /// 'expr'. Returns the simplified expression, which is the same as the input
295 /// expression if it can't be simplified.
296 AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims,
297 unsigned numSymbols);
298
299 namespace detail {
300 template <int N>
bindDims(MLIRContext * ctx)301 void bindDims(MLIRContext *ctx) {}
302
303 template <int N, typename AffineExprTy, typename... AffineExprTy2>
bindDims(MLIRContext * ctx,AffineExprTy & e,AffineExprTy2 &...exprs)304 void bindDims(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) {
305 e = getAffineDimExpr(N, ctx);
306 bindDims<N + 1, AffineExprTy2 &...>(ctx, exprs...);
307 }
308
309 template <int N>
bindSymbols(MLIRContext * ctx)310 void bindSymbols(MLIRContext *ctx) {}
311
312 template <int N, typename AffineExprTy, typename... AffineExprTy2>
bindSymbols(MLIRContext * ctx,AffineExprTy & e,AffineExprTy2 &...exprs)313 void bindSymbols(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) {
314 e = getAffineSymbolExpr(N, ctx);
315 bindSymbols<N + 1, AffineExprTy2 &...>(ctx, exprs...);
316 }
317 } // namespace detail
318
319 /// Bind a list of AffineExpr references to DimExpr at positions:
320 /// [0 .. sizeof...(exprs)]
321 template <typename... AffineExprTy>
bindDims(MLIRContext * ctx,AffineExprTy &...exprs)322 void bindDims(MLIRContext *ctx, AffineExprTy &...exprs) {
323 detail::bindDims<0>(ctx, exprs...);
324 }
325
326 /// Bind a list of AffineExpr references to SymbolExpr at positions:
327 /// [0 .. sizeof...(exprs)]
328 template <typename... AffineExprTy>
bindSymbols(MLIRContext * ctx,AffineExprTy &...exprs)329 void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs) {
330 detail::bindSymbols<0>(ctx, exprs...);
331 }
332
333 } // namespace mlir
334
335 namespace llvm {
336
337 // AffineExpr hash just like pointers
338 template <>
339 struct DenseMapInfo<mlir::AffineExpr> {
340 static mlir::AffineExpr getEmptyKey() {
341 auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
342 return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
343 }
344 static mlir::AffineExpr getTombstoneKey() {
345 auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
346 return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
347 }
348 static unsigned getHashValue(mlir::AffineExpr val) {
349 return mlir::hash_value(val);
350 }
351 static bool isEqual(mlir::AffineExpr LHS, mlir::AffineExpr RHS) {
352 return LHS == RHS;
353 }
354 };
355
356 } // namespace llvm
357
358 #endif // MLIR_IR_AFFINE_EXPR_H
359