1 //===- AffineOps.h - MLIR Affine Operations -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines convenience types for working with Affine operations 10 // in the MLIR operation set. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H 15 #define MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H 16 17 #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h" 18 #include "mlir/Dialect/StandardOps/IR/Ops.h" 19 #include "mlir/IR/AffineMap.h" 20 #include "mlir/Interfaces/LoopLikeInterface.h" 21 22 namespace mlir { 23 class AffineApplyOp; 24 class AffineBound; 25 class AffineValueMap; 26 27 /// A utility function to check if a value is defined at the top level of an 28 /// op with trait `AffineScope` or is a region argument for such an op. A value 29 /// of index type defined at the top level is always a valid symbol for all its 30 /// uses. 31 bool isTopLevelValue(Value value); 32 33 /// AffineDmaStartOp starts a non-blocking DMA operation that transfers data 34 /// from a source memref to a destination memref. The source and destination 35 /// memref need not be of the same dimensionality, but need to have the same 36 /// elemental type. The operands include the source and destination memref's 37 /// each followed by its indices, size of the data transfer in terms of the 38 /// number of elements (of the elemental type of the memref), a tag memref with 39 /// its indices, and optionally at the end, a stride and a 40 /// number_of_elements_per_stride arguments. The tag location is used by an 41 /// AffineDmaWaitOp to check for completion. The indices of the source memref, 42 /// destination memref, and the tag memref have the same restrictions as any 43 /// affine.load/store. In particular, index for each memref dimension must be an 44 /// affine expression of loop induction variables and symbols. 45 /// The optional stride arguments should be of 'index' type, and specify a 46 /// stride for the slower memory space (memory space with a lower memory space 47 /// id), transferring chunks of number_of_elements_per_stride every stride until 48 /// %num_elements are transferred. Either both or no stride arguments should be 49 /// specified. The value of 'num_elements' must be a multiple of 50 /// 'number_of_elements_per_stride'. If the source and destination locations 51 /// overlap the behavior of this operation is not defined. 52 // 53 // For example, an AffineDmaStartOp operation that transfers 256 elements of a 54 // memref '%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in 55 // memory space 1 at indices [%k + 7, %l], would be specified as follows: 56 // 57 // %num_elements = constant 256 58 // %idx = constant 0 : index 59 // %tag = alloc() : memref<1xi32, 4> 60 // affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx], 61 // %num_elements : 62 // memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2> 63 // 64 // If %stride and %num_elt_per_stride are specified, the DMA is expected to 65 // transfer %num_elt_per_stride elements every %stride elements apart from 66 // memory space 0 until %num_elements are transferred. 67 // 68 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%idx], %num_elements, 69 // %stride, %num_elt_per_stride : ... 70 // 71 // TODO: add additional operands to allow source and destination striding, and 72 // multiple stride levels (possibly using AffineMaps to specify multiple levels 73 // of striding). 74 class AffineDmaStartOp 75 : public Op<AffineDmaStartOp, OpTrait::MemRefsNormalizable, 76 OpTrait::VariadicOperands, OpTrait::ZeroResult, 77 AffineMapAccessInterface::Trait> { 78 public: 79 using Op::Op; getAttributeNames()80 static ArrayRef<StringRef> getAttributeNames() { return {}; } 81 82 static void build(OpBuilder &builder, OperationState &result, Value srcMemRef, 83 AffineMap srcMap, ValueRange srcIndices, Value destMemRef, 84 AffineMap dstMap, ValueRange destIndices, Value tagMemRef, 85 AffineMap tagMap, ValueRange tagIndices, Value numElements, 86 Value stride = nullptr, Value elementsPerStride = nullptr); 87 88 /// Returns the operand index of the source memref. getSrcMemRefOperandIndex()89 unsigned getSrcMemRefOperandIndex() { return 0; } 90 91 /// Returns the source MemRefType for this DMA operation. getSrcMemRef()92 Value getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); } getSrcMemRefType()93 MemRefType getSrcMemRefType() { 94 return getSrcMemRef().getType().cast<MemRefType>(); 95 } 96 97 /// Returns the rank (number of indices) of the source MemRefType. getSrcMemRefRank()98 unsigned getSrcMemRefRank() { return getSrcMemRefType().getRank(); } 99 100 /// Returns the affine map used to access the source memref. getSrcMap()101 AffineMap getSrcMap() { return getSrcMapAttr().getValue(); } getSrcMapAttr()102 AffineMapAttr getSrcMapAttr() { 103 return (*this)->getAttr(getSrcMapAttrName()).cast<AffineMapAttr>(); 104 } 105 106 /// Returns the source memref affine map indices for this DMA operation. getSrcIndices()107 operand_range getSrcIndices() { 108 return {operand_begin() + getSrcMemRefOperandIndex() + 1, 109 operand_begin() + getSrcMemRefOperandIndex() + 1 + 110 getSrcMap().getNumInputs()}; 111 } 112 113 /// Returns the memory space of the source memref. getSrcMemorySpace()114 unsigned getSrcMemorySpace() { 115 return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt(); 116 } 117 118 /// Returns the operand index of the destination memref. getDstMemRefOperandIndex()119 unsigned getDstMemRefOperandIndex() { 120 return getSrcMemRefOperandIndex() + 1 + getSrcMap().getNumInputs(); 121 } 122 123 /// Returns the destination MemRefType for this DMA operation. getDstMemRef()124 Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); } getDstMemRefType()125 MemRefType getDstMemRefType() { 126 return getDstMemRef().getType().cast<MemRefType>(); 127 } 128 129 /// Returns the rank (number of indices) of the destination MemRefType. getDstMemRefRank()130 unsigned getDstMemRefRank() { 131 return getDstMemRef().getType().cast<MemRefType>().getRank(); 132 } 133 134 /// Returns the memory space of the source memref. getDstMemorySpace()135 unsigned getDstMemorySpace() { 136 return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt(); 137 } 138 139 /// Returns the affine map used to access the destination memref. getDstMap()140 AffineMap getDstMap() { return getDstMapAttr().getValue(); } getDstMapAttr()141 AffineMapAttr getDstMapAttr() { 142 return (*this)->getAttr(getDstMapAttrName()).cast<AffineMapAttr>(); 143 } 144 145 /// Returns the destination memref indices for this DMA operation. getDstIndices()146 operand_range getDstIndices() { 147 return {operand_begin() + getDstMemRefOperandIndex() + 1, 148 operand_begin() + getDstMemRefOperandIndex() + 1 + 149 getDstMap().getNumInputs()}; 150 } 151 152 /// Returns the operand index of the tag memref. getTagMemRefOperandIndex()153 unsigned getTagMemRefOperandIndex() { 154 return getDstMemRefOperandIndex() + 1 + getDstMap().getNumInputs(); 155 } 156 157 /// Returns the Tag MemRef for this DMA operation. getTagMemRef()158 Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); } getTagMemRefType()159 MemRefType getTagMemRefType() { 160 return getTagMemRef().getType().cast<MemRefType>(); 161 } 162 163 /// Returns the rank (number of indices) of the tag MemRefType. getTagMemRefRank()164 unsigned getTagMemRefRank() { 165 return getTagMemRef().getType().cast<MemRefType>().getRank(); 166 } 167 168 /// Returns the affine map used to access the tag memref. getTagMap()169 AffineMap getTagMap() { return getTagMapAttr().getValue(); } getTagMapAttr()170 AffineMapAttr getTagMapAttr() { 171 return (*this)->getAttr(getTagMapAttrName()).cast<AffineMapAttr>(); 172 } 173 174 /// Returns the tag memref indices for this DMA operation. getTagIndices()175 operand_range getTagIndices() { 176 return {operand_begin() + getTagMemRefOperandIndex() + 1, 177 operand_begin() + getTagMemRefOperandIndex() + 1 + 178 getTagMap().getNumInputs()}; 179 } 180 181 /// Returns the number of elements being transferred by this DMA operation. getNumElements()182 Value getNumElements() { 183 return getOperand(getTagMemRefOperandIndex() + 1 + 184 getTagMap().getNumInputs()); 185 } 186 187 /// Impelements the AffineMapAccessInterface. 188 /// Returns the AffineMapAttr associated with 'memref'. getAffineMapAttrForMemRef(Value memref)189 NamedAttribute getAffineMapAttrForMemRef(Value memref) { 190 if (memref == getSrcMemRef()) 191 return {Identifier::get(getSrcMapAttrName(), getContext()), 192 getSrcMapAttr()}; 193 if (memref == getDstMemRef()) 194 return {Identifier::get(getDstMapAttrName(), getContext()), 195 getDstMapAttr()}; 196 assert(memref == getTagMemRef() && 197 "DmaStartOp expected source, destination or tag memref"); 198 return {Identifier::get(getTagMapAttrName(), getContext()), 199 getTagMapAttr()}; 200 } 201 202 /// Returns true if this is a DMA from a faster memory space to a slower one. isDestMemorySpaceFaster()203 bool isDestMemorySpaceFaster() { 204 return (getSrcMemorySpace() < getDstMemorySpace()); 205 } 206 207 /// Returns true if this is a DMA from a slower memory space to a faster one. isSrcMemorySpaceFaster()208 bool isSrcMemorySpaceFaster() { 209 // Assumes that a lower number is for a slower memory space. 210 return (getDstMemorySpace() < getSrcMemorySpace()); 211 } 212 213 /// Given a DMA start operation, returns the operand position of either the 214 /// source or destination memref depending on the one that is at the higher 215 /// level of the memory hierarchy. Asserts failure if neither is true. getFasterMemPos()216 unsigned getFasterMemPos() { 217 assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); 218 return isSrcMemorySpaceFaster() ? 0 : getDstMemRefOperandIndex(); 219 } 220 getSrcMapAttrName()221 static StringRef getSrcMapAttrName() { return "src_map"; } getDstMapAttrName()222 static StringRef getDstMapAttrName() { return "dst_map"; } getTagMapAttrName()223 static StringRef getTagMapAttrName() { return "tag_map"; } 224 getOperationName()225 static StringRef getOperationName() { return "affine.dma_start"; } 226 static ParseResult parse(OpAsmParser &parser, OperationState &result); 227 void print(OpAsmPrinter &p); 228 LogicalResult verify(); 229 LogicalResult fold(ArrayRef<Attribute> cstOperands, 230 SmallVectorImpl<OpFoldResult> &results); 231 232 /// Returns true if this DMA operation is strided, returns false otherwise. isStrided()233 bool isStrided() { 234 return getNumOperands() != 235 getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs() + 1; 236 } 237 238 /// Returns the stride value for this DMA operation. getStride()239 Value getStride() { 240 if (!isStrided()) 241 return nullptr; 242 return getOperand(getNumOperands() - 1 - 1); 243 } 244 245 /// Returns the number of elements to transfer per stride for this DMA op. getNumElementsPerStride()246 Value getNumElementsPerStride() { 247 if (!isStrided()) 248 return nullptr; 249 return getOperand(getNumOperands() - 1); 250 } 251 }; 252 253 /// AffineDmaWaitOp blocks until the completion of a DMA operation associated 254 /// with the tag element '%tag[%index]'. %tag is a memref, and %index has to be 255 /// an index with the same restrictions as any load/store index. In particular, 256 /// index for each memref dimension must be an affine expression of loop 257 /// induction variables and symbols. %num_elements is the number of elements 258 /// associated with the DMA operation. For example: 259 // 260 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %num_elements : 261 // memref<2048xf32, 0>, memref<256xf32, 1>, memref<1xi32, 2> 262 // ... 263 // ... 264 // affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2> 265 // 266 class AffineDmaWaitOp 267 : public Op<AffineDmaWaitOp, OpTrait::MemRefsNormalizable, 268 OpTrait::VariadicOperands, OpTrait::ZeroResult, 269 AffineMapAccessInterface::Trait> { 270 public: 271 using Op::Op; getAttributeNames()272 static ArrayRef<StringRef> getAttributeNames() { return {}; } 273 274 static void build(OpBuilder &builder, OperationState &result, Value tagMemRef, 275 AffineMap tagMap, ValueRange tagIndices, Value numElements); 276 getOperationName()277 static StringRef getOperationName() { return "affine.dma_wait"; } 278 279 /// Returns the Tag MemRef associated with the DMA operation being waited on. getTagMemRef()280 Value getTagMemRef() { return getOperand(0); } getTagMemRefType()281 MemRefType getTagMemRefType() { 282 return getTagMemRef().getType().cast<MemRefType>(); 283 } 284 285 /// Returns the affine map used to access the tag memref. getTagMap()286 AffineMap getTagMap() { return getTagMapAttr().getValue(); } getTagMapAttr()287 AffineMapAttr getTagMapAttr() { 288 return (*this)->getAttr(getTagMapAttrName()).cast<AffineMapAttr>(); 289 } 290 291 /// Returns the tag memref index for this DMA operation. getTagIndices()292 operand_range getTagIndices() { 293 return {operand_begin() + 1, 294 operand_begin() + 1 + getTagMap().getNumInputs()}; 295 } 296 297 /// Returns the rank (number of indices) of the tag memref. getTagMemRefRank()298 unsigned getTagMemRefRank() { 299 return getTagMemRef().getType().cast<MemRefType>().getRank(); 300 } 301 302 /// Impelements the AffineMapAccessInterface. Returns the AffineMapAttr 303 /// associated with 'memref'. getAffineMapAttrForMemRef(Value memref)304 NamedAttribute getAffineMapAttrForMemRef(Value memref) { 305 assert(memref == getTagMemRef()); 306 return {Identifier::get(getTagMapAttrName(), getContext()), 307 getTagMapAttr()}; 308 } 309 310 /// Returns the number of elements transferred by the associated DMA op. getNumElements()311 Value getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); } 312 getTagMapAttrName()313 static StringRef getTagMapAttrName() { return "tag_map"; } 314 static ParseResult parse(OpAsmParser &parser, OperationState &result); 315 void print(OpAsmPrinter &p); 316 LogicalResult verify(); 317 LogicalResult fold(ArrayRef<Attribute> cstOperands, 318 SmallVectorImpl<OpFoldResult> &results); 319 }; 320 321 /// Returns true if the given Value can be used as a dimension id in the region 322 /// of the closest surrounding op that has the trait `AffineScope`. 323 bool isValidDim(Value value); 324 325 /// Returns true if the given Value can be used as a dimension id in `region`, 326 /// i.e., for all its uses in `region`. 327 bool isValidDim(Value value, Region *region); 328 329 /// Returns true if the given value can be used as a symbol in the region of the 330 /// closest surrounding op that has the trait `AffineScope`. 331 bool isValidSymbol(Value value); 332 333 /// Returns true if the given Value can be used as a symbol for `region`, i.e., 334 /// for all its uses in `region`. 335 bool isValidSymbol(Value value, Region *region); 336 337 /// Parses dimension and symbol list. `numDims` is set to the number of 338 /// dimensions in the list parsed. 339 ParseResult parseDimAndSymbolList(OpAsmParser &parser, 340 SmallVectorImpl<Value> &operands, 341 unsigned &numDims); 342 343 /// Modifies both `map` and `operands` in-place so as to: 344 /// 1. drop duplicate operands 345 /// 2. drop unused dims and symbols from map 346 /// 3. promote valid symbols to symbolic operands in case they appeared as 347 /// dimensional operands 348 /// 4. propagate constant operands and drop them 349 void canonicalizeMapAndOperands(AffineMap *map, 350 SmallVectorImpl<Value> *operands); 351 352 /// Canonicalizes an integer set the same way canonicalizeMapAndOperands does 353 /// for affine maps. 354 void canonicalizeSetAndOperands(IntegerSet *set, 355 SmallVectorImpl<Value> *operands); 356 357 /// Returns a composed AffineApplyOp by composing `map` and `operands` with 358 /// other AffineApplyOps supplying those operands. The operands of the resulting 359 /// AffineApplyOp do not change the length of AffineApplyOp chains. 360 AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, 361 ValueRange operands); 362 /// Variant of `makeComposedAffineApply` which infers the AffineMap from `e`. 363 AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e, 364 ValueRange values); 365 366 /// Given an affine map `map` and its input `operands`, this method composes 367 /// into `map`, maps of AffineApplyOps whose results are the values in 368 /// `operands`, iteratively until no more of `operands` are the result of an 369 /// AffineApplyOp. When this function returns, `map` becomes the composed affine 370 /// map, and each Value in `operands` is guaranteed to be either a loop IV or a 371 /// terminal symbol, i.e., a symbol defined at the top level or a block/function 372 /// argument. 373 void fullyComposeAffineMapAndOperands(AffineMap *map, 374 SmallVectorImpl<Value> *operands); 375 } // namespace mlir 376 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.h.inc" 377 378 #define GET_OP_CLASSES 379 #include "mlir/Dialect/Affine/IR/AffineOps.h.inc" 380 381 namespace mlir { 382 /// Returns true if the provided value is the induction variable of a 383 /// AffineForOp. 384 bool isForInductionVar(Value val); 385 386 /// Returns the loop parent of an induction variable. If the provided value is 387 /// not an induction variable, then return nullptr. 388 AffineForOp getForInductionVarOwner(Value val); 389 390 /// Extracts the induction variables from a list of AffineForOps and places them 391 /// in the output argument `ivs`. 392 void extractForInductionVars(ArrayRef<AffineForOp> forInsts, 393 SmallVectorImpl<Value> *ivs); 394 395 /// Builds a perfect nest of affine.for loops, i.e., each loop except the 396 /// innermost one contains only another loop and a terminator. The loops iterate 397 /// from "lbs" to "ubs" with "steps". The body of the innermost loop is 398 /// populated by calling "bodyBuilderFn" and providing it with an OpBuilder, a 399 /// Location and a list of loop induction variables. 400 void buildAffineLoopNest(OpBuilder &builder, Location loc, 401 ArrayRef<int64_t> lbs, ArrayRef<int64_t> ubs, 402 ArrayRef<int64_t> steps, 403 function_ref<void(OpBuilder &, Location, ValueRange)> 404 bodyBuilderFn = nullptr); 405 void buildAffineLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, 406 ValueRange ubs, ArrayRef<int64_t> steps, 407 function_ref<void(OpBuilder &, Location, ValueRange)> 408 bodyBuilderFn = nullptr); 409 410 /// Replace `loop` with a new loop where `newIterOperands` are appended with 411 /// new initialization values and `newYieldedValues` are added as new yielded 412 /// values. The returned ForOp has `newYieldedValues.size()` new result values. 413 /// Additionally, if `replaceLoopResults` is true, all uses of 414 /// `loop.getResults()` are replaced with the first `loop.getNumResults()` 415 /// return values of the original loop respectively. The original loop is 416 /// deleted and the new loop returned. 417 /// Prerequisite: `newIterOperands.size() == newYieldedValues.size()`. 418 AffineForOp replaceForOpWithNewYields(OpBuilder &b, AffineForOp loop, 419 ValueRange newIterOperands, 420 ValueRange newYieldedValues, 421 ValueRange newIterArgs, 422 bool replaceLoopResults = true); 423 424 /// AffineBound represents a lower or upper bound in the for operation. 425 /// This class does not own the underlying operands. Instead, it refers 426 /// to the operands stored in the AffineForOp. Its life span should not exceed 427 /// that of the for operation it refers to. 428 class AffineBound { 429 public: getAffineForOp()430 AffineForOp getAffineForOp() { return op; } getMap()431 AffineMap getMap() { return map; } 432 getNumOperands()433 unsigned getNumOperands() { return opEnd - opStart; } getOperand(unsigned idx)434 Value getOperand(unsigned idx) { return op.getOperand(opStart + idx); } 435 436 using operand_iterator = AffineForOp::operand_iterator; 437 using operand_range = AffineForOp::operand_range; 438 operand_begin()439 operand_iterator operand_begin() { return op.operand_begin() + opStart; } operand_end()440 operand_iterator operand_end() { return op.operand_begin() + opEnd; } getOperands()441 operand_range getOperands() { return {operand_begin(), operand_end()}; } 442 443 private: 444 // 'affine.for' operation that contains this bound. 445 AffineForOp op; 446 // Start and end positions of this affine bound operands in the list of 447 // the containing 'affine.for' operation operands. 448 unsigned opStart, opEnd; 449 // Affine map for this bound. 450 AffineMap map; 451 AffineBound(AffineForOp op,unsigned opStart,unsigned opEnd,AffineMap map)452 AffineBound(AffineForOp op, unsigned opStart, unsigned opEnd, AffineMap map) 453 : op(op), opStart(opStart), opEnd(opEnd), map(map) {} 454 455 friend class AffineForOp; 456 }; 457 458 } // end namespace mlir 459 460 #endif 461