1 //===- AffineOps.h - MLIR Affine Operations -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines convenience types for working with Affine operations
10 // in the MLIR operation set.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H
15 #define MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H
16 
17 #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Dialect.h"
22 #include "mlir/IR/OpDefinition.h"
23 #include "mlir/IR/StandardTypes.h"
24 #include "mlir/Interfaces/LoopLikeInterface.h"
25 #include "mlir/Interfaces/SideEffectInterfaces.h"
26 
27 namespace mlir {
// Forward declarations of classes referenced by declarations in this header.
// NOTE(review): some of them (e.g. AffineValueMap, FlatAffineConstraints,
// AffineYieldOp) are not referenced in the visible part of this header;
// presumably they are needed by the generated AffineOps.h.inc -- confirm
// before removing any.
class AffineApplyOp;
class AffineBound;
class AffineDimExpr;
class AffineValueMap;
class AffineYieldOp;
class FlatAffineConstraints;
class OpBuilder;

/// Returns true if `value` is defined at the top level of an op with trait
/// `AffineScope`, or is a region argument of such an op. A value of index type
/// defined at the top level is always a valid symbol for all its uses.
bool isTopLevelValue(Value value);
41 
42 /// AffineDmaStartOp starts a non-blocking DMA operation that transfers data
43 /// from a source memref to a destination memref. The source and destination
/// memrefs need not be of the same dimensionality, but need to have the same
/// elemental type. The operands include the source and destination memrefs,
/// each followed by its indices, size of the data transfer in terms of the
47 /// number of elements (of the elemental type of the memref), a tag memref with
48 /// its indices, and optionally at the end, a stride and a
49 /// number_of_elements_per_stride arguments. The tag location is used by an
50 /// AffineDmaWaitOp to check for completion. The indices of the source memref,
51 /// destination memref, and the tag memref have the same restrictions as any
52 /// affine.load/store. In particular, index for each memref dimension must be an
53 /// affine expression of loop induction variables and symbols.
54 /// The optional stride arguments should be of 'index' type, and specify a
55 /// stride for the slower memory space (memory space with a lower memory space
56 /// id), transferring chunks of number_of_elements_per_stride every stride until
57 /// %num_elements are transferred. Either both or no stride arguments should be
58 /// specified. The value of 'num_elements' must be a multiple of
59 /// 'number_of_elements_per_stride'.
60 //
// For example, an affine.dma_start operation that transfers 256 elements of a
// memref '%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in
// memory space 1 at indices [%k + 7, %l], would be specified as follows:
//
//   %num_elements = constant 256 : index
//   %idx = constant 0 : index
//   %tag = alloc() : memref<1xi32, 2>
//   affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx],
//     %num_elements :
//       memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2>
71 //
72 //   If %stride and %num_elt_per_stride are specified, the DMA is expected to
73 //   transfer %num_elt_per_stride elements every %stride elements apart from
74 //   memory space 0 until %num_elements are transferred.
75 //
76 //   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%idx], %num_elements,
77 //     %stride, %num_elt_per_stride : ...
78 //
79 // TODO: add additional operands to allow source and destination striding, and
80 // multiple stride levels (possibly using AffineMaps to specify multiple levels
81 // of striding).
82 // TODO: Consider replacing src/dst memref indices with view memrefs.
83 class AffineDmaStartOp : public Op<AffineDmaStartOp, OpTrait::VariadicOperands,
84                                    OpTrait::ZeroResult> {
85 public:
86   using Op::Op;
87 
88   static void build(OpBuilder &builder, OperationState &result, Value srcMemRef,
89                     AffineMap srcMap, ValueRange srcIndices, Value destMemRef,
90                     AffineMap dstMap, ValueRange destIndices, Value tagMemRef,
91                     AffineMap tagMap, ValueRange tagIndices, Value numElements,
92                     Value stride = nullptr, Value elementsPerStride = nullptr);
93 
94   /// Returns the operand index of the src memref.
getSrcMemRefOperandIndex()95   unsigned getSrcMemRefOperandIndex() { return 0; }
96 
97   /// Returns the source MemRefType for this DMA operation.
getSrcMemRef()98   Value getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); }
getSrcMemRefType()99   MemRefType getSrcMemRefType() {
100     return getSrcMemRef().getType().cast<MemRefType>();
101   }
102 
103   /// Returns the rank (number of indices) of the source MemRefType.
getSrcMemRefRank()104   unsigned getSrcMemRefRank() { return getSrcMemRefType().getRank(); }
105 
106   /// Returns the affine map used to access the src memref.
getSrcMap()107   AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
getSrcMapAttr()108   AffineMapAttr getSrcMapAttr() {
109     return getAttr(getSrcMapAttrName()).cast<AffineMapAttr>();
110   }
111 
112   /// Returns the source memref affine map indices for this DMA operation.
getSrcIndices()113   operand_range getSrcIndices() {
114     return {operand_begin() + getSrcMemRefOperandIndex() + 1,
115             operand_begin() + getSrcMemRefOperandIndex() + 1 +
116                 getSrcMap().getNumInputs()};
117   }
118 
119   /// Returns the memory space of the src memref.
getSrcMemorySpace()120   unsigned getSrcMemorySpace() {
121     return getSrcMemRef().getType().cast<MemRefType>().getMemorySpace();
122   }
123 
124   /// Returns the operand index of the dst memref.
getDstMemRefOperandIndex()125   unsigned getDstMemRefOperandIndex() {
126     return getSrcMemRefOperandIndex() + 1 + getSrcMap().getNumInputs();
127   }
128 
129   /// Returns the destination MemRefType for this DMA operations.
getDstMemRef()130   Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }
getDstMemRefType()131   MemRefType getDstMemRefType() {
132     return getDstMemRef().getType().cast<MemRefType>();
133   }
134 
135   /// Returns the rank (number of indices) of the destination MemRefType.
getDstMemRefRank()136   unsigned getDstMemRefRank() {
137     return getDstMemRef().getType().cast<MemRefType>().getRank();
138   }
139 
140   /// Returns the memory space of the src memref.
getDstMemorySpace()141   unsigned getDstMemorySpace() {
142     return getDstMemRef().getType().cast<MemRefType>().getMemorySpace();
143   }
144 
145   /// Returns the affine map used to access the dst memref.
getDstMap()146   AffineMap getDstMap() { return getDstMapAttr().getValue(); }
getDstMapAttr()147   AffineMapAttr getDstMapAttr() {
148     return getAttr(getDstMapAttrName()).cast<AffineMapAttr>();
149   }
150 
151   /// Returns the destination memref indices for this DMA operation.
getDstIndices()152   operand_range getDstIndices() {
153     return {operand_begin() + getDstMemRefOperandIndex() + 1,
154             operand_begin() + getDstMemRefOperandIndex() + 1 +
155                 getDstMap().getNumInputs()};
156   }
157 
158   /// Returns the operand index of the tag memref.
getTagMemRefOperandIndex()159   unsigned getTagMemRefOperandIndex() {
160     return getDstMemRefOperandIndex() + 1 + getDstMap().getNumInputs();
161   }
162 
163   /// Returns the Tag MemRef for this DMA operation.
getTagMemRef()164   Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }
getTagMemRefType()165   MemRefType getTagMemRefType() {
166     return getTagMemRef().getType().cast<MemRefType>();
167   }
168 
169   /// Returns the rank (number of indices) of the tag MemRefType.
getTagMemRefRank()170   unsigned getTagMemRefRank() {
171     return getTagMemRef().getType().cast<MemRefType>().getRank();
172   }
173 
174   /// Returns the affine map used to access the tag memref.
getTagMap()175   AffineMap getTagMap() { return getTagMapAttr().getValue(); }
getTagMapAttr()176   AffineMapAttr getTagMapAttr() {
177     return getAttr(getTagMapAttrName()).cast<AffineMapAttr>();
178   }
179 
180   /// Returns the tag memref indices for this DMA operation.
getTagIndices()181   operand_range getTagIndices() {
182     return {operand_begin() + getTagMemRefOperandIndex() + 1,
183             operand_begin() + getTagMemRefOperandIndex() + 1 +
184                 getTagMap().getNumInputs()};
185   }
186 
187   /// Returns the number of elements being transferred by this DMA operation.
getNumElements()188   Value getNumElements() {
189     return getOperand(getTagMemRefOperandIndex() + 1 +
190                       getTagMap().getNumInputs());
191   }
192 
193   /// Returns the AffineMapAttr associated with 'memref'.
getAffineMapAttrForMemRef(Value memref)194   NamedAttribute getAffineMapAttrForMemRef(Value memref) {
195     if (memref == getSrcMemRef())
196       return {Identifier::get(getSrcMapAttrName(), getContext()),
197               getSrcMapAttr()};
198     else if (memref == getDstMemRef())
199       return {Identifier::get(getDstMapAttrName(), getContext()),
200               getDstMapAttr()};
201     assert(memref == getTagMemRef() &&
202            "DmaStartOp expected source, destination or tag memref");
203     return {Identifier::get(getTagMapAttrName(), getContext()),
204             getTagMapAttr()};
205   }
206 
207   /// Returns true if this is a DMA from a faster memory space to a slower one.
isDestMemorySpaceFaster()208   bool isDestMemorySpaceFaster() {
209     return (getSrcMemorySpace() < getDstMemorySpace());
210   }
211 
212   /// Returns true if this is a DMA from a slower memory space to a faster one.
isSrcMemorySpaceFaster()213   bool isSrcMemorySpaceFaster() {
214     // Assumes that a lower number is for a slower memory space.
215     return (getDstMemorySpace() < getSrcMemorySpace());
216   }
217 
218   /// Given a DMA start operation, returns the operand position of either the
219   /// source or destination memref depending on the one that is at the higher
220   /// level of the memory hierarchy. Asserts failure if neither is true.
getFasterMemPos()221   unsigned getFasterMemPos() {
222     assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
223     return isSrcMemorySpaceFaster() ? 0 : getDstMemRefOperandIndex();
224   }
225 
getSrcMapAttrName()226   static StringRef getSrcMapAttrName() { return "src_map"; }
getDstMapAttrName()227   static StringRef getDstMapAttrName() { return "dst_map"; }
getTagMapAttrName()228   static StringRef getTagMapAttrName() { return "tag_map"; }
229 
getOperationName()230   static StringRef getOperationName() { return "affine.dma_start"; }
231   static ParseResult parse(OpAsmParser &parser, OperationState &result);
232   void print(OpAsmPrinter &p);
233   LogicalResult verify();
234   LogicalResult fold(ArrayRef<Attribute> cstOperands,
235                      SmallVectorImpl<OpFoldResult> &results);
236 
237   /// Returns true if this DMA operation is strided, returns false otherwise.
isStrided()238   bool isStrided() {
239     return getNumOperands() !=
240            getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs() + 1;
241   }
242 
243   /// Returns the stride value for this DMA operation.
getStride()244   Value getStride() {
245     if (!isStrided())
246       return nullptr;
247     return getOperand(getNumOperands() - 1 - 1);
248   }
249 
250   /// Returns the number of elements to transfer per stride for this DMA op.
getNumElementsPerStride()251   Value getNumElementsPerStride() {
252     if (!isStrided())
253       return nullptr;
254     return getOperand(getNumOperands() - 1);
255   }
256 };
257 
258 /// AffineDmaWaitOp blocks until the completion of a DMA operation associated
259 /// with the tag element '%tag[%index]'. %tag is a memref, and %index has to be
260 /// an index with the same restrictions as any load/store index. In particular,
261 /// index for each memref dimension must be an affine expression of loop
262 /// induction variables and symbols. %num_elements is the number of elements
263 /// associated with the DMA operation. For example:
264 //
//   affine.dma_start %src[%i], %dst[%k], %tag[%index], %num_elements :
//     memref<2048xf32, 0>, memref<256xf32, 1>, memref<1xi32, 2>
267 //   ...
268 //   ...
269 //   affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2>
270 //
271 class AffineDmaWaitOp : public Op<AffineDmaWaitOp, OpTrait::VariadicOperands,
272                                   OpTrait::ZeroResult> {
273 public:
274   using Op::Op;
275 
276   static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
277                     AffineMap tagMap, ValueRange tagIndices, Value numElements);
278 
getOperationName()279   static StringRef getOperationName() { return "affine.dma_wait"; }
280 
281   // Returns the Tag MemRef associated with the DMA operation being waited on.
getTagMemRef()282   Value getTagMemRef() { return getOperand(0); }
getTagMemRefType()283   MemRefType getTagMemRefType() {
284     return getTagMemRef().getType().cast<MemRefType>();
285   }
286 
287   /// Returns the affine map used to access the tag memref.
getTagMap()288   AffineMap getTagMap() { return getTagMapAttr().getValue(); }
getTagMapAttr()289   AffineMapAttr getTagMapAttr() {
290     return getAttr(getTagMapAttrName()).cast<AffineMapAttr>();
291   }
292 
293   // Returns the tag memref index for this DMA operation.
getTagIndices()294   operand_range getTagIndices() {
295     return {operand_begin() + 1,
296             operand_begin() + 1 + getTagMap().getNumInputs()};
297   }
298 
299   // Returns the rank (number of indices) of the tag memref.
getTagMemRefRank()300   unsigned getTagMemRefRank() {
301     return getTagMemRef().getType().cast<MemRefType>().getRank();
302   }
303 
304   /// Returns the AffineMapAttr associated with 'memref'.
getAffineMapAttrForMemRef(Value memref)305   NamedAttribute getAffineMapAttrForMemRef(Value memref) {
306     assert(memref == getTagMemRef());
307     return {Identifier::get(getTagMapAttrName(), getContext()),
308             getTagMapAttr()};
309   }
310 
311   /// Returns the number of elements transferred in the associated DMA op.
getNumElements()312   Value getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); }
313 
getTagMapAttrName()314   static StringRef getTagMapAttrName() { return "tag_map"; }
315   static ParseResult parse(OpAsmParser &parser, OperationState &result);
316   void print(OpAsmPrinter &p);
317   LogicalResult verify();
318   LogicalResult fold(ArrayRef<Attribute> cstOperands,
319                      SmallVectorImpl<OpFoldResult> &results);
320 };
321 
/// Returns true if the given Value can be used as a dimension id in the region
/// of the closest surrounding op that has the trait `AffineScope`.
bool isValidDim(Value value);

/// Returns true if the given Value can be used as a dimension id in `region`,
/// i.e., for all its uses in `region`.
bool isValidDim(Value value, Region *region);

/// Returns true if the given value can be used as a symbol in the region of the
/// closest surrounding op that has the trait `AffineScope`.
bool isValidSymbol(Value value);

/// Returns true if the given Value can be used as a symbol for `region`, i.e.,
/// for all its uses in `region`.
bool isValidSymbol(Value value, Region *region);

/// Modifies both `map` and `operands` in-place so as to:
/// 1. drop duplicate operands
/// 2. drop unused dims and symbols from map
/// 3. promote valid symbols to symbolic operands in case they appeared as
///    dimensional operands
/// 4. propagate constant operands and drop them
/// Both arguments are in/out parameters passed by pointer.
void canonicalizeMapAndOperands(AffineMap *map,
                                SmallVectorImpl<Value> *operands);

/// Canonicalizes an integer set and its operands in-place, the same way
/// canonicalizeMapAndOperands does for affine maps.
void canonicalizeSetAndOperands(IntegerSet *set,
                                SmallVectorImpl<Value> *operands);

/// Returns a composed AffineApplyOp by composing `map` and `operands` with
/// other AffineApplyOps supplying those operands. The operands of the resulting
/// AffineApplyOp do not change the length of AffineApplyOp chains.
/// `b` and `loc` are presumably the builder and location used to create the
/// new op -- confirm against the implementation.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
                                      ArrayRef<Value> operands);

/// Given an affine map `map` and its input `operands`, this method composes
/// into `map`, maps of AffineApplyOps whose results are the values in
/// `operands`, iteratively until no more of `operands` are the result of an
/// AffineApplyOp. When this function returns, `map` becomes the composed affine
/// map, and each Value in `operands` is guaranteed to be either a loop IV or a
/// terminal symbol, i.e., a symbol defined at the top level or a block/function
/// argument. Both arguments are updated in-place.
void fullyComposeAffineMapAndOperands(AffineMap *map,
                                      SmallVectorImpl<Value> *operands);
367 
368 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.h.inc"
369 
370 #define GET_OP_CLASSES
371 #include "mlir/Dialect/Affine/IR/AffineOps.h.inc"
372 
/// Returns true if the provided value is the induction variable of an
/// AffineForOp.
bool isForInductionVar(Value val);

/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, returns a null AffineForOp.
AffineForOp getForInductionVarOwner(Value val);

/// Extracts the induction variables from a list of AffineForOps and places them
/// in the output argument `ivs`.
void extractForInductionVars(ArrayRef<AffineForOp> forInsts,
                             SmallVectorImpl<Value> *ivs);

/// Builds a perfect nest of affine "for" loops, i.e. each loop except the
/// innermost only contains another loop and a terminator. The loops iterate
/// from "lbs" to "ubs" with "steps". The body of the innermost loop is
/// populated by calling "bodyBuilderFn" and providing it with an OpBuilder, a
/// Location and a list of loop induction variables. In this overload the
/// bounds are compile-time constants.
void buildAffineLoopNest(OpBuilder &builder, Location loc,
                         ArrayRef<int64_t> lbs, ArrayRef<int64_t> ubs,
                         ArrayRef<int64_t> steps,
                         function_ref<void(OpBuilder &, Location, ValueRange)>
                             bodyBuilderFn = nullptr);
/// Same as above, except that the lower and upper bounds are given as SSA
/// values; the steps remain compile-time constants.
void buildAffineLoopNest(OpBuilder &builder, Location loc, ValueRange lbs,
                         ValueRange ubs, ArrayRef<int64_t> steps,
                         function_ref<void(OpBuilder &, Location, ValueRange)>
                             bodyBuilderFn = nullptr);
399 
400 /// AffineBound represents a lower or upper bound in the for operation.
401 /// This class does not own the underlying operands. Instead, it refers
402 /// to the operands stored in the AffineForOp. Its life span should not exceed
403 /// that of the for operation it refers to.
404 class AffineBound {
405 public:
getAffineForOp()406   AffineForOp getAffineForOp() { return op; }
getMap()407   AffineMap getMap() { return map; }
408 
getNumOperands()409   unsigned getNumOperands() { return opEnd - opStart; }
getOperand(unsigned idx)410   Value getOperand(unsigned idx) { return op.getOperand(opStart + idx); }
411 
412   using operand_iterator = AffineForOp::operand_iterator;
413   using operand_range = AffineForOp::operand_range;
414 
operand_begin()415   operand_iterator operand_begin() { return op.operand_begin() + opStart; }
operand_end()416   operand_iterator operand_end() { return op.operand_begin() + opEnd; }
getOperands()417   operand_range getOperands() { return {operand_begin(), operand_end()}; }
418 
419 private:
420   // 'affine.for' operation that contains this bound.
421   AffineForOp op;
422   // Start and end positions of this affine bound operands in the list of
423   // the containing 'affine.for' operation operands.
424   unsigned opStart, opEnd;
425   // Affine map for this bound.
426   AffineMap map;
427 
AffineBound(AffineForOp op,unsigned opStart,unsigned opEnd,AffineMap map)428   AffineBound(AffineForOp op, unsigned opStart, unsigned opEnd, AffineMap map)
429       : op(op), opStart(opStart), opEnd(opEnd), map(map) {}
430 
431   friend class AffineForOp;
432 };
433 
434 /// An `AffineApplyNormalizer` is a helper class that supports renumbering
435 /// operands of AffineApplyOp. This acts as a reindexing map of Value to
436 /// positional dims or symbols and allows simplifications such as:
437 ///
438 /// ```mlir
439 ///    %1 = affine.apply (d0, d1) -> (d0 - d1) (%0, %0)
440 /// ```
441 ///
442 /// into:
443 ///
444 /// ```mlir
445 ///    %1 = affine.apply () -> (0)
446 /// ```
447 struct AffineApplyNormalizer {
448   AffineApplyNormalizer(AffineMap map, ArrayRef<Value> operands);
449 
450   /// Returns the AffineMap resulting from normalization.
getAffineMapAffineApplyNormalizer451   AffineMap getAffineMap() { return affineMap; }
452 
getOperandsAffineApplyNormalizer453   SmallVector<Value, 8> getOperands() {
454     SmallVector<Value, 8> res(reorderedDims);
455     res.append(concatenatedSymbols.begin(), concatenatedSymbols.end());
456     return res;
457   }
458 
getNumSymbolsAffineApplyNormalizer459   unsigned getNumSymbols() { return concatenatedSymbols.size(); }
getNumDimsAffineApplyNormalizer460   unsigned getNumDims() { return reorderedDims.size(); }
461 
462   /// Normalizes 'otherMap' and its operands 'otherOperands' to map to this
463   /// normalizer's coordinate space.
464   void normalize(AffineMap *otherMap, SmallVectorImpl<Value> *otherOperands);
465 
466 private:
467   /// Helper function to insert `v` into the coordinate system of the current
468   /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding
469   /// renumbered position.
470   AffineDimExpr renumberOneDim(Value v);
471 
472   /// Given an `other` normalizer, this rewrites `other.affineMap` in the
473   /// coordinate system of the current AffineApplyNormalizer.
474   /// Returns the rewritten AffineMap and updates the dims and symbols of
475   /// `this`.
476   AffineMap renumber(const AffineApplyNormalizer &other);
477 
478   /// Maps of Value to position in `affineMap`.
479   DenseMap<Value, unsigned> dimValueToPosition;
480 
481   /// Ordered dims and symbols matching positional dims and symbols in
482   /// `affineMap`.
483   SmallVector<Value, 8> reorderedDims;
484   SmallVector<Value, 8> concatenatedSymbols;
485 
486   /// The number of symbols in concatenated symbols that belong to the original
487   /// map as opposed to those concatendated during map composition.
488   unsigned numProperSymbols;
489 
490   AffineMap affineMap;
491 
492   /// Used with RAII to control the depth at which AffineApply are composed
493   /// recursively. Only accepts depth 1 for now to allow a behavior where a
494   /// newly composed AffineApplyOp does not increase the length of the chain of
495   /// AffineApplyOps. Full composition is implemented iteratively on top of
496   /// this behavior.
affineApplyDepthAffineApplyNormalizer497   static unsigned &affineApplyDepth() {
498     static thread_local unsigned depth = 0;
499     return depth;
500   }
501   static constexpr unsigned kMaxAffineApplyDepth = 1;
502 
AffineApplyNormalizerAffineApplyNormalizer503   AffineApplyNormalizer() : numProperSymbols(0) { affineApplyDepth()++; }
504 
505 public:
~AffineApplyNormalizerAffineApplyNormalizer506   ~AffineApplyNormalizer() { affineApplyDepth()--; }
507 };
508 
509 } // end namespace mlir
510 
511 #endif
512