1 //===- AffineOps.h - MLIR Affine Operations -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines convenience types for working with Affine operations
10 // in the MLIR operation set.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H
15 #define MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H
16 
17 #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/Interfaces/LoopLikeInterface.h"
21 
22 namespace mlir {
23 class AffineApplyOp;
24 class AffineBound;
25 class AffineValueMap;
26 
27 /// A utility function to check if a value is defined at the top level of an
28 /// op with trait `AffineScope` or is a region argument for such an op. A value
29 /// of index type defined at the top level is always a valid symbol for all its
30 /// uses.
31 bool isTopLevelValue(Value value);
32 
33 /// AffineDmaStartOp starts a non-blocking DMA operation that transfers data
34 /// from a source memref to a destination memref. The source and destination
35 /// memref need not be of the same dimensionality, but need to have the same
36 /// elemental type. The operands include the source and destination memref's
37 /// each followed by its indices, size of the data transfer in terms of the
38 /// number of elements (of the elemental type of the memref), a tag memref with
39 /// its indices, and optionally at the end, a stride and a
40 /// number_of_elements_per_stride arguments. The tag location is used by an
41 /// AffineDmaWaitOp to check for completion. The indices of the source memref,
42 /// destination memref, and the tag memref have the same restrictions as any
43 /// affine.load/store. In particular, index for each memref dimension must be an
44 /// affine expression of loop induction variables and symbols.
45 /// The optional stride arguments should be of 'index' type, and specify a
46 /// stride for the slower memory space (memory space with a lower memory space
47 /// id), transferring chunks of number_of_elements_per_stride every stride until
48 /// %num_elements are transferred. Either both or no stride arguments should be
49 /// specified. The value of 'num_elements' must be a multiple of
50 /// 'number_of_elements_per_stride'. If the source and destination locations
51 /// overlap the behavior of this operation is not defined.
52 //
53 // For example, an AffineDmaStartOp operation that transfers 256 elements of a
54 // memref '%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in
55 // memory space 1 at indices [%k + 7, %l], would be specified as follows:
56 //
57 //   %num_elements = constant 256
58 //   %idx = constant 0 : index
59 //   %tag = alloc() : memref<1xi32, 4>
60 //   affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx],
61 //     %num_elements :
62 //       memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2>
63 //
64 //   If %stride and %num_elt_per_stride are specified, the DMA is expected to
65 //   transfer %num_elt_per_stride elements every %stride elements apart from
66 //   memory space 0 until %num_elements are transferred.
67 //
68 //   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%idx], %num_elements,
69 //     %stride, %num_elt_per_stride : ...
70 //
71 // TODO: add additional operands to allow source and destination striding, and
72 // multiple stride levels (possibly using AffineMaps to specify multiple levels
73 // of striding).
74 class AffineDmaStartOp
75     : public Op<AffineDmaStartOp, OpTrait::MemRefsNormalizable,
76                 OpTrait::VariadicOperands, OpTrait::ZeroResult,
77                 AffineMapAccessInterface::Trait> {
78 public:
79   using Op::Op;
getAttributeNames()80   static ArrayRef<StringRef> getAttributeNames() { return {}; }
81 
82   static void build(OpBuilder &builder, OperationState &result, Value srcMemRef,
83                     AffineMap srcMap, ValueRange srcIndices, Value destMemRef,
84                     AffineMap dstMap, ValueRange destIndices, Value tagMemRef,
85                     AffineMap tagMap, ValueRange tagIndices, Value numElements,
86                     Value stride = nullptr, Value elementsPerStride = nullptr);
87 
88   /// Returns the operand index of the source memref.
getSrcMemRefOperandIndex()89   unsigned getSrcMemRefOperandIndex() { return 0; }
90 
91   /// Returns the source MemRefType for this DMA operation.
getSrcMemRef()92   Value getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); }
getSrcMemRefType()93   MemRefType getSrcMemRefType() {
94     return getSrcMemRef().getType().cast<MemRefType>();
95   }
96 
97   /// Returns the rank (number of indices) of the source MemRefType.
getSrcMemRefRank()98   unsigned getSrcMemRefRank() { return getSrcMemRefType().getRank(); }
99 
100   /// Returns the affine map used to access the source memref.
getSrcMap()101   AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
getSrcMapAttr()102   AffineMapAttr getSrcMapAttr() {
103     return (*this)->getAttr(getSrcMapAttrName()).cast<AffineMapAttr>();
104   }
105 
106   /// Returns the source memref affine map indices for this DMA operation.
getSrcIndices()107   operand_range getSrcIndices() {
108     return {operand_begin() + getSrcMemRefOperandIndex() + 1,
109             operand_begin() + getSrcMemRefOperandIndex() + 1 +
110                 getSrcMap().getNumInputs()};
111   }
112 
113   /// Returns the memory space of the source memref.
getSrcMemorySpace()114   unsigned getSrcMemorySpace() {
115     return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
116   }
117 
118   /// Returns the operand index of the destination memref.
getDstMemRefOperandIndex()119   unsigned getDstMemRefOperandIndex() {
120     return getSrcMemRefOperandIndex() + 1 + getSrcMap().getNumInputs();
121   }
122 
123   /// Returns the destination MemRefType for this DMA operation.
getDstMemRef()124   Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }
getDstMemRefType()125   MemRefType getDstMemRefType() {
126     return getDstMemRef().getType().cast<MemRefType>();
127   }
128 
129   /// Returns the rank (number of indices) of the destination MemRefType.
getDstMemRefRank()130   unsigned getDstMemRefRank() {
131     return getDstMemRef().getType().cast<MemRefType>().getRank();
132   }
133 
134   /// Returns the memory space of the source memref.
getDstMemorySpace()135   unsigned getDstMemorySpace() {
136     return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
137   }
138 
139   /// Returns the affine map used to access the destination memref.
getDstMap()140   AffineMap getDstMap() { return getDstMapAttr().getValue(); }
getDstMapAttr()141   AffineMapAttr getDstMapAttr() {
142     return (*this)->getAttr(getDstMapAttrName()).cast<AffineMapAttr>();
143   }
144 
145   /// Returns the destination memref indices for this DMA operation.
getDstIndices()146   operand_range getDstIndices() {
147     return {operand_begin() + getDstMemRefOperandIndex() + 1,
148             operand_begin() + getDstMemRefOperandIndex() + 1 +
149                 getDstMap().getNumInputs()};
150   }
151 
152   /// Returns the operand index of the tag memref.
getTagMemRefOperandIndex()153   unsigned getTagMemRefOperandIndex() {
154     return getDstMemRefOperandIndex() + 1 + getDstMap().getNumInputs();
155   }
156 
157   /// Returns the Tag MemRef for this DMA operation.
getTagMemRef()158   Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }
getTagMemRefType()159   MemRefType getTagMemRefType() {
160     return getTagMemRef().getType().cast<MemRefType>();
161   }
162 
163   /// Returns the rank (number of indices) of the tag MemRefType.
getTagMemRefRank()164   unsigned getTagMemRefRank() {
165     return getTagMemRef().getType().cast<MemRefType>().getRank();
166   }
167 
168   /// Returns the affine map used to access the tag memref.
getTagMap()169   AffineMap getTagMap() { return getTagMapAttr().getValue(); }
getTagMapAttr()170   AffineMapAttr getTagMapAttr() {
171     return (*this)->getAttr(getTagMapAttrName()).cast<AffineMapAttr>();
172   }
173 
174   /// Returns the tag memref indices for this DMA operation.
getTagIndices()175   operand_range getTagIndices() {
176     return {operand_begin() + getTagMemRefOperandIndex() + 1,
177             operand_begin() + getTagMemRefOperandIndex() + 1 +
178                 getTagMap().getNumInputs()};
179   }
180 
181   /// Returns the number of elements being transferred by this DMA operation.
getNumElements()182   Value getNumElements() {
183     return getOperand(getTagMemRefOperandIndex() + 1 +
184                       getTagMap().getNumInputs());
185   }
186 
187   /// Impelements the AffineMapAccessInterface.
188   /// Returns the AffineMapAttr associated with 'memref'.
getAffineMapAttrForMemRef(Value memref)189   NamedAttribute getAffineMapAttrForMemRef(Value memref) {
190     if (memref == getSrcMemRef())
191       return {Identifier::get(getSrcMapAttrName(), getContext()),
192               getSrcMapAttr()};
193     if (memref == getDstMemRef())
194       return {Identifier::get(getDstMapAttrName(), getContext()),
195               getDstMapAttr()};
196     assert(memref == getTagMemRef() &&
197            "DmaStartOp expected source, destination or tag memref");
198     return {Identifier::get(getTagMapAttrName(), getContext()),
199             getTagMapAttr()};
200   }
201 
202   /// Returns true if this is a DMA from a faster memory space to a slower one.
isDestMemorySpaceFaster()203   bool isDestMemorySpaceFaster() {
204     return (getSrcMemorySpace() < getDstMemorySpace());
205   }
206 
207   /// Returns true if this is a DMA from a slower memory space to a faster one.
isSrcMemorySpaceFaster()208   bool isSrcMemorySpaceFaster() {
209     // Assumes that a lower number is for a slower memory space.
210     return (getDstMemorySpace() < getSrcMemorySpace());
211   }
212 
213   /// Given a DMA start operation, returns the operand position of either the
214   /// source or destination memref depending on the one that is at the higher
215   /// level of the memory hierarchy. Asserts failure if neither is true.
getFasterMemPos()216   unsigned getFasterMemPos() {
217     assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
218     return isSrcMemorySpaceFaster() ? 0 : getDstMemRefOperandIndex();
219   }
220 
getSrcMapAttrName()221   static StringRef getSrcMapAttrName() { return "src_map"; }
getDstMapAttrName()222   static StringRef getDstMapAttrName() { return "dst_map"; }
getTagMapAttrName()223   static StringRef getTagMapAttrName() { return "tag_map"; }
224 
getOperationName()225   static StringRef getOperationName() { return "affine.dma_start"; }
226   static ParseResult parse(OpAsmParser &parser, OperationState &result);
227   void print(OpAsmPrinter &p);
228   LogicalResult verify();
229   LogicalResult fold(ArrayRef<Attribute> cstOperands,
230                      SmallVectorImpl<OpFoldResult> &results);
231 
232   /// Returns true if this DMA operation is strided, returns false otherwise.
isStrided()233   bool isStrided() {
234     return getNumOperands() !=
235            getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs() + 1;
236   }
237 
238   /// Returns the stride value for this DMA operation.
getStride()239   Value getStride() {
240     if (!isStrided())
241       return nullptr;
242     return getOperand(getNumOperands() - 1 - 1);
243   }
244 
245   /// Returns the number of elements to transfer per stride for this DMA op.
getNumElementsPerStride()246   Value getNumElementsPerStride() {
247     if (!isStrided())
248       return nullptr;
249     return getOperand(getNumOperands() - 1);
250   }
251 };
252 
253 /// AffineDmaWaitOp blocks until the completion of a DMA operation associated
254 /// with the tag element '%tag[%index]'. %tag is a memref, and %index has to be
255 /// an index with the same restrictions as any load/store index. In particular,
256 /// index for each memref dimension must be an affine expression of loop
257 /// induction variables and symbols. %num_elements is the number of elements
258 /// associated with the DMA operation. For example:
259 //
260 //   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %num_elements :
261 //     memref<2048xf32, 0>, memref<256xf32, 1>, memref<1xi32, 2>
262 //   ...
263 //   ...
264 //   affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2>
265 //
266 class AffineDmaWaitOp
267     : public Op<AffineDmaWaitOp, OpTrait::MemRefsNormalizable,
268                 OpTrait::VariadicOperands, OpTrait::ZeroResult,
269                 AffineMapAccessInterface::Trait> {
270 public:
271   using Op::Op;
getAttributeNames()272   static ArrayRef<StringRef> getAttributeNames() { return {}; }
273 
274   static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
275                     AffineMap tagMap, ValueRange tagIndices, Value numElements);
276 
getOperationName()277   static StringRef getOperationName() { return "affine.dma_wait"; }
278 
279   /// Returns the Tag MemRef associated with the DMA operation being waited on.
getTagMemRef()280   Value getTagMemRef() { return getOperand(0); }
getTagMemRefType()281   MemRefType getTagMemRefType() {
282     return getTagMemRef().getType().cast<MemRefType>();
283   }
284 
285   /// Returns the affine map used to access the tag memref.
getTagMap()286   AffineMap getTagMap() { return getTagMapAttr().getValue(); }
getTagMapAttr()287   AffineMapAttr getTagMapAttr() {
288     return (*this)->getAttr(getTagMapAttrName()).cast<AffineMapAttr>();
289   }
290 
291   /// Returns the tag memref index for this DMA operation.
getTagIndices()292   operand_range getTagIndices() {
293     return {operand_begin() + 1,
294             operand_begin() + 1 + getTagMap().getNumInputs()};
295   }
296 
297   /// Returns the rank (number of indices) of the tag memref.
getTagMemRefRank()298   unsigned getTagMemRefRank() {
299     return getTagMemRef().getType().cast<MemRefType>().getRank();
300   }
301 
302   /// Impelements the AffineMapAccessInterface. Returns the AffineMapAttr
303   /// associated with 'memref'.
getAffineMapAttrForMemRef(Value memref)304   NamedAttribute getAffineMapAttrForMemRef(Value memref) {
305     assert(memref == getTagMemRef());
306     return {Identifier::get(getTagMapAttrName(), getContext()),
307             getTagMapAttr()};
308   }
309 
310   /// Returns the number of elements transferred by the associated DMA op.
getNumElements()311   Value getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); }
312 
getTagMapAttrName()313   static StringRef getTagMapAttrName() { return "tag_map"; }
314   static ParseResult parse(OpAsmParser &parser, OperationState &result);
315   void print(OpAsmPrinter &p);
316   LogicalResult verify();
317   LogicalResult fold(ArrayRef<Attribute> cstOperands,
318                      SmallVectorImpl<OpFoldResult> &results);
319 };
320 
321 /// Returns true if the given Value can be used as a dimension id in the region
322 /// of the closest surrounding op that has the trait `AffineScope`.
323 bool isValidDim(Value value);
324 
325 /// Returns true if the given Value can be used as a dimension id in `region`,
326 /// i.e., for all its uses in `region`.
327 bool isValidDim(Value value, Region *region);
328 
329 /// Returns true if the given value can be used as a symbol in the region of the
330 /// closest surrounding op that has the trait `AffineScope`.
331 bool isValidSymbol(Value value);
332 
333 /// Returns true if the given Value can be used as a symbol for `region`, i.e.,
334 /// for all its uses in `region`.
335 bool isValidSymbol(Value value, Region *region);
336 
337 /// Parses dimension and symbol list. `numDims` is set to the number of
338 /// dimensions in the list parsed.
339 ParseResult parseDimAndSymbolList(OpAsmParser &parser,
340                                   SmallVectorImpl<Value> &operands,
341                                   unsigned &numDims);
342 
343 /// Modifies both `map` and `operands` in-place so as to:
344 /// 1. drop duplicate operands
345 /// 2. drop unused dims and symbols from map
346 /// 3. promote valid symbols to symbolic operands in case they appeared as
347 ///    dimensional operands
348 /// 4. propagate constant operands and drop them
349 void canonicalizeMapAndOperands(AffineMap *map,
350                                 SmallVectorImpl<Value> *operands);
351 
352 /// Canonicalizes an integer set the same way canonicalizeMapAndOperands does
353 /// for affine maps.
354 void canonicalizeSetAndOperands(IntegerSet *set,
355                                 SmallVectorImpl<Value> *operands);
356 
357 /// Returns a composed AffineApplyOp by composing `map` and `operands` with
358 /// other AffineApplyOps supplying those operands. The operands of the resulting
359 /// AffineApplyOp do not change the length of  AffineApplyOp chains.
360 AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
361                                       ValueRange operands);
362 /// Variant of `makeComposedAffineApply` which infers the AffineMap from `e`.
363 AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
364                                       ValueRange values);
365 
366 /// Given an affine map `map` and its input `operands`, this method composes
367 /// into `map`, maps of AffineApplyOps whose results are the values in
368 /// `operands`, iteratively until no more of `operands` are the result of an
369 /// AffineApplyOp. When this function returns, `map` becomes the composed affine
370 /// map, and each Value in `operands` is guaranteed to be either a loop IV or a
371 /// terminal symbol, i.e., a symbol defined at the top level or a block/function
372 /// argument.
373 void fullyComposeAffineMapAndOperands(AffineMap *map,
374                                       SmallVectorImpl<Value> *operands);
375 } // namespace mlir
376 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.h.inc"
377 
378 #define GET_OP_CLASSES
379 #include "mlir/Dialect/Affine/IR/AffineOps.h.inc"
380 
381 namespace mlir {
382 /// Returns true if the provided value is the induction variable of a
383 /// AffineForOp.
384 bool isForInductionVar(Value val);
385 
386 /// Returns the loop parent of an induction variable. If the provided value is
387 /// not an induction variable, then return nullptr.
388 AffineForOp getForInductionVarOwner(Value val);
389 
390 /// Extracts the induction variables from a list of AffineForOps and places them
391 /// in the output argument `ivs`.
392 void extractForInductionVars(ArrayRef<AffineForOp> forInsts,
393                              SmallVectorImpl<Value> *ivs);
394 
395 /// Builds a perfect nest of affine.for loops, i.e., each loop except the
396 /// innermost one contains only another loop and a terminator. The loops iterate
397 /// from "lbs" to "ubs" with "steps". The body of the innermost loop is
398 /// populated by calling "bodyBuilderFn" and providing it with an OpBuilder, a
399 /// Location and a list of loop induction variables.
400 void buildAffineLoopNest(OpBuilder &builder, Location loc,
401                          ArrayRef<int64_t> lbs, ArrayRef<int64_t> ubs,
402                          ArrayRef<int64_t> steps,
403                          function_ref<void(OpBuilder &, Location, ValueRange)>
404                              bodyBuilderFn = nullptr);
405 void buildAffineLoopNest(OpBuilder &builder, Location loc, ValueRange lbs,
406                          ValueRange ubs, ArrayRef<int64_t> steps,
407                          function_ref<void(OpBuilder &, Location, ValueRange)>
408                              bodyBuilderFn = nullptr);
409 
410 /// Replace `loop` with a new loop where `newIterOperands` are appended with
411 /// new initialization values and `newYieldedValues` are added as new yielded
412 /// values. The returned ForOp has `newYieldedValues.size()` new result values.
413 /// Additionally, if `replaceLoopResults` is true, all uses of
414 /// `loop.getResults()` are replaced with the first `loop.getNumResults()`
415 /// return values  of the original loop respectively. The original loop is
416 /// deleted and the new loop returned.
417 /// Prerequisite: `newIterOperands.size() == newYieldedValues.size()`.
418 AffineForOp replaceForOpWithNewYields(OpBuilder &b, AffineForOp loop,
419                                       ValueRange newIterOperands,
420                                       ValueRange newYieldedValues,
421                                       ValueRange newIterArgs,
422                                       bool replaceLoopResults = true);
423 
424 /// AffineBound represents a lower or upper bound in the for operation.
425 /// This class does not own the underlying operands. Instead, it refers
426 /// to the operands stored in the AffineForOp. Its life span should not exceed
427 /// that of the for operation it refers to.
428 class AffineBound {
429 public:
getAffineForOp()430   AffineForOp getAffineForOp() { return op; }
getMap()431   AffineMap getMap() { return map; }
432 
getNumOperands()433   unsigned getNumOperands() { return opEnd - opStart; }
getOperand(unsigned idx)434   Value getOperand(unsigned idx) { return op.getOperand(opStart + idx); }
435 
436   using operand_iterator = AffineForOp::operand_iterator;
437   using operand_range = AffineForOp::operand_range;
438 
operand_begin()439   operand_iterator operand_begin() { return op.operand_begin() + opStart; }
operand_end()440   operand_iterator operand_end() { return op.operand_begin() + opEnd; }
getOperands()441   operand_range getOperands() { return {operand_begin(), operand_end()}; }
442 
443 private:
444   // 'affine.for' operation that contains this bound.
445   AffineForOp op;
446   // Start and end positions of this affine bound operands in the list of
447   // the containing 'affine.for' operation operands.
448   unsigned opStart, opEnd;
449   // Affine map for this bound.
450   AffineMap map;
451 
AffineBound(AffineForOp op,unsigned opStart,unsigned opEnd,AffineMap map)452   AffineBound(AffineForOp op, unsigned opStart, unsigned opEnd, AffineMap map)
453       : op(op), opStart(opStart), opEnd(opEnd), map(map) {}
454 
455   friend class AffineForOp;
456 };
457 
458 } // end namespace mlir
459 
460 #endif
461