1 //===- AffineOps.h - MLIR Affine Operations -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines convenience types for working with Affine operations 10 // in the MLIR operation set. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H 15 #define MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H 16 17 #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h" 18 #include "mlir/Dialect/StandardOps/IR/Ops.h" 19 #include "mlir/IR/AffineMap.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/Dialect.h" 22 #include "mlir/IR/OpDefinition.h" 23 #include "mlir/IR/StandardTypes.h" 24 #include "mlir/Interfaces/LoopLikeInterface.h" 25 #include "mlir/Interfaces/SideEffectInterfaces.h" 26 27 namespace mlir { 28 class AffineApplyOp; 29 class AffineBound; 30 class AffineDimExpr; 31 class AffineValueMap; 32 class AffineYieldOp; 33 class FlatAffineConstraints; 34 class OpBuilder; 35 36 /// A utility function to check if a value is defined at the top level of an 37 /// op with trait `AffineScope` or is a region argument for such an op. A value 38 /// of index type defined at the top level is always a valid symbol for all its 39 /// uses. 40 bool isTopLevelValue(Value value); 41 42 /// AffineDmaStartOp starts a non-blocking DMA operation that transfers data 43 /// from a source memref to a destination memref. The source and destination 44 /// memref need not be of the same dimensionality, but need to have the same 45 /// elemental type. The operands include the source and destination memref's 46 /// each followed by its indices, size of the data transfer in terms of the 47 /// number of elements (of the elemental type of the memref), a tag memref with 48 /// its indices, and optionally at the end, a stride and a 49 /// number_of_elements_per_stride arguments. The tag location is used by an 50 /// AffineDmaWaitOp to check for completion. The indices of the source memref, 51 /// destination memref, and the tag memref have the same restrictions as any 52 /// affine.load/store. In particular, index for each memref dimension must be an 53 /// affine expression of loop induction variables and symbols. 54 /// The optional stride arguments should be of 'index' type, and specify a 55 /// stride for the slower memory space (memory space with a lower memory space 56 /// id), transferring chunks of number_of_elements_per_stride every stride until 57 /// %num_elements are transferred. Either both or no stride arguments should be 58 /// specified. The value of 'num_elements' must be a multiple of 59 /// 'number_of_elements_per_stride'. 60 // 61 // For example, a DmaStartOp operation that transfers 256 elements of a memref 62 // '%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in memory 63 // space 1 at indices [%k + 7, %l], would be specified as follows: 64 // 65 // %num_elements = constant 256 66 // %idx = constant 0 : index 67 // %tag = alloc() : memref<1xi32, 4> 68 // affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx], 69 // %num_elements : 70 // memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2> 71 // 72 // If %stride and %num_elt_per_stride are specified, the DMA is expected to 73 // transfer %num_elt_per_stride elements every %stride elements apart from 74 // memory space 0 until %num_elements are transferred. 75 // 76 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%idx], %num_elements, 77 // %stride, %num_elt_per_stride : ... 78 // 79 // TODO: add additional operands to allow source and destination striding, and 80 // multiple stride levels (possibly using AffineMaps to specify multiple levels 81 // of striding). 82 // TODO: Consider replacing src/dst memref indices with view memrefs. 83 class AffineDmaStartOp : public Op<AffineDmaStartOp, OpTrait::VariadicOperands, 84 OpTrait::ZeroResult> { 85 public: 86 using Op::Op; 87 88 static void build(OpBuilder &builder, OperationState &result, Value srcMemRef, 89 AffineMap srcMap, ValueRange srcIndices, Value destMemRef, 90 AffineMap dstMap, ValueRange destIndices, Value tagMemRef, 91 AffineMap tagMap, ValueRange tagIndices, Value numElements, 92 Value stride = nullptr, Value elementsPerStride = nullptr); 93 94 /// Returns the operand index of the src memref. getSrcMemRefOperandIndex()95 unsigned getSrcMemRefOperandIndex() { return 0; } 96 97 /// Returns the source MemRefType for this DMA operation. getSrcMemRef()98 Value getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); } getSrcMemRefType()99 MemRefType getSrcMemRefType() { 100 return getSrcMemRef().getType().cast<MemRefType>(); 101 } 102 103 /// Returns the rank (number of indices) of the source MemRefType. getSrcMemRefRank()104 unsigned getSrcMemRefRank() { return getSrcMemRefType().getRank(); } 105 106 /// Returns the affine map used to access the src memref. getSrcMap()107 AffineMap getSrcMap() { return getSrcMapAttr().getValue(); } getSrcMapAttr()108 AffineMapAttr getSrcMapAttr() { 109 return getAttr(getSrcMapAttrName()).cast<AffineMapAttr>(); 110 } 111 112 /// Returns the source memref affine map indices for this DMA operation. getSrcIndices()113 operand_range getSrcIndices() { 114 return {operand_begin() + getSrcMemRefOperandIndex() + 1, 115 operand_begin() + getSrcMemRefOperandIndex() + 1 + 116 getSrcMap().getNumInputs()}; 117 } 118 119 /// Returns the memory space of the src memref. getSrcMemorySpace()120 unsigned getSrcMemorySpace() { 121 return getSrcMemRef().getType().cast<MemRefType>().getMemorySpace(); 122 } 123 124 /// Returns the operand index of the dst memref. getDstMemRefOperandIndex()125 unsigned getDstMemRefOperandIndex() { 126 return getSrcMemRefOperandIndex() + 1 + getSrcMap().getNumInputs(); 127 } 128 129 /// Returns the destination MemRefType for this DMA operations. getDstMemRef()130 Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); } getDstMemRefType()131 MemRefType getDstMemRefType() { 132 return getDstMemRef().getType().cast<MemRefType>(); 133 } 134 135 /// Returns the rank (number of indices) of the destination MemRefType. getDstMemRefRank()136 unsigned getDstMemRefRank() { 137 return getDstMemRef().getType().cast<MemRefType>().getRank(); 138 } 139 140 /// Returns the memory space of the src memref. getDstMemorySpace()141 unsigned getDstMemorySpace() { 142 return getDstMemRef().getType().cast<MemRefType>().getMemorySpace(); 143 } 144 145 /// Returns the affine map used to access the dst memref. getDstMap()146 AffineMap getDstMap() { return getDstMapAttr().getValue(); } getDstMapAttr()147 AffineMapAttr getDstMapAttr() { 148 return getAttr(getDstMapAttrName()).cast<AffineMapAttr>(); 149 } 150 151 /// Returns the destination memref indices for this DMA operation. getDstIndices()152 operand_range getDstIndices() { 153 return {operand_begin() + getDstMemRefOperandIndex() + 1, 154 operand_begin() + getDstMemRefOperandIndex() + 1 + 155 getDstMap().getNumInputs()}; 156 } 157 158 /// Returns the operand index of the tag memref. getTagMemRefOperandIndex()159 unsigned getTagMemRefOperandIndex() { 160 return getDstMemRefOperandIndex() + 1 + getDstMap().getNumInputs(); 161 } 162 163 /// Returns the Tag MemRef for this DMA operation. getTagMemRef()164 Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); } getTagMemRefType()165 MemRefType getTagMemRefType() { 166 return getTagMemRef().getType().cast<MemRefType>(); 167 } 168 169 /// Returns the rank (number of indices) of the tag MemRefType. getTagMemRefRank()170 unsigned getTagMemRefRank() { 171 return getTagMemRef().getType().cast<MemRefType>().getRank(); 172 } 173 174 /// Returns the affine map used to access the tag memref. getTagMap()175 AffineMap getTagMap() { return getTagMapAttr().getValue(); } getTagMapAttr()176 AffineMapAttr getTagMapAttr() { 177 return getAttr(getTagMapAttrName()).cast<AffineMapAttr>(); 178 } 179 180 /// Returns the tag memref indices for this DMA operation. getTagIndices()181 operand_range getTagIndices() { 182 return {operand_begin() + getTagMemRefOperandIndex() + 1, 183 operand_begin() + getTagMemRefOperandIndex() + 1 + 184 getTagMap().getNumInputs()}; 185 } 186 187 /// Returns the number of elements being transferred by this DMA operation. getNumElements()188 Value getNumElements() { 189 return getOperand(getTagMemRefOperandIndex() + 1 + 190 getTagMap().getNumInputs()); 191 } 192 193 /// Returns the AffineMapAttr associated with 'memref'. getAffineMapAttrForMemRef(Value memref)194 NamedAttribute getAffineMapAttrForMemRef(Value memref) { 195 if (memref == getSrcMemRef()) 196 return {Identifier::get(getSrcMapAttrName(), getContext()), 197 getSrcMapAttr()}; 198 else if (memref == getDstMemRef()) 199 return {Identifier::get(getDstMapAttrName(), getContext()), 200 getDstMapAttr()}; 201 assert(memref == getTagMemRef() && 202 "DmaStartOp expected source, destination or tag memref"); 203 return {Identifier::get(getTagMapAttrName(), getContext()), 204 getTagMapAttr()}; 205 } 206 207 /// Returns true if this is a DMA from a faster memory space to a slower one. isDestMemorySpaceFaster()208 bool isDestMemorySpaceFaster() { 209 return (getSrcMemorySpace() < getDstMemorySpace()); 210 } 211 212 /// Returns true if this is a DMA from a slower memory space to a faster one. isSrcMemorySpaceFaster()213 bool isSrcMemorySpaceFaster() { 214 // Assumes that a lower number is for a slower memory space. 215 return (getDstMemorySpace() < getSrcMemorySpace()); 216 } 217 218 /// Given a DMA start operation, returns the operand position of either the 219 /// source or destination memref depending on the one that is at the higher 220 /// level of the memory hierarchy. Asserts failure if neither is true. getFasterMemPos()221 unsigned getFasterMemPos() { 222 assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); 223 return isSrcMemorySpaceFaster() ? 0 : getDstMemRefOperandIndex(); 224 } 225 getSrcMapAttrName()226 static StringRef getSrcMapAttrName() { return "src_map"; } getDstMapAttrName()227 static StringRef getDstMapAttrName() { return "dst_map"; } getTagMapAttrName()228 static StringRef getTagMapAttrName() { return "tag_map"; } 229 getOperationName()230 static StringRef getOperationName() { return "affine.dma_start"; } 231 static ParseResult parse(OpAsmParser &parser, OperationState &result); 232 void print(OpAsmPrinter &p); 233 LogicalResult verify(); 234 LogicalResult fold(ArrayRef<Attribute> cstOperands, 235 SmallVectorImpl<OpFoldResult> &results); 236 237 /// Returns true if this DMA operation is strided, returns false otherwise. isStrided()238 bool isStrided() { 239 return getNumOperands() != 240 getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs() + 1; 241 } 242 243 /// Returns the stride value for this DMA operation. getStride()244 Value getStride() { 245 if (!isStrided()) 246 return nullptr; 247 return getOperand(getNumOperands() - 1 - 1); 248 } 249 250 /// Returns the number of elements to transfer per stride for this DMA op. getNumElementsPerStride()251 Value getNumElementsPerStride() { 252 if (!isStrided()) 253 return nullptr; 254 return getOperand(getNumOperands() - 1); 255 } 256 }; 257 258 /// AffineDmaWaitOp blocks until the completion of a DMA operation associated 259 /// with the tag element '%tag[%index]'. %tag is a memref, and %index has to be 260 /// an index with the same restrictions as any load/store index. In particular, 261 /// index for each memref dimension must be an affine expression of loop 262 /// induction variables and symbols. %num_elements is the number of elements 263 /// associated with the DMA operation. For example: 264 // 265 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %num_elements : 266 // memref<2048xf32, 0>, memref<256xf32, 1>, memref<1xi32, 2> 267 // ... 268 // ... 269 // affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2> 270 // 271 class AffineDmaWaitOp : public Op<AffineDmaWaitOp, OpTrait::VariadicOperands, 272 OpTrait::ZeroResult> { 273 public: 274 using Op::Op; 275 276 static void build(OpBuilder &builder, OperationState &result, Value tagMemRef, 277 AffineMap tagMap, ValueRange tagIndices, Value numElements); 278 getOperationName()279 static StringRef getOperationName() { return "affine.dma_wait"; } 280 281 // Returns the Tag MemRef associated with the DMA operation being waited on. getTagMemRef()282 Value getTagMemRef() { return getOperand(0); } getTagMemRefType()283 MemRefType getTagMemRefType() { 284 return getTagMemRef().getType().cast<MemRefType>(); 285 } 286 287 /// Returns the affine map used to access the tag memref. getTagMap()288 AffineMap getTagMap() { return getTagMapAttr().getValue(); } getTagMapAttr()289 AffineMapAttr getTagMapAttr() { 290 return getAttr(getTagMapAttrName()).cast<AffineMapAttr>(); 291 } 292 293 // Returns the tag memref index for this DMA operation. getTagIndices()294 operand_range getTagIndices() { 295 return {operand_begin() + 1, 296 operand_begin() + 1 + getTagMap().getNumInputs()}; 297 } 298 299 // Returns the rank (number of indices) of the tag memref. getTagMemRefRank()300 unsigned getTagMemRefRank() { 301 return getTagMemRef().getType().cast<MemRefType>().getRank(); 302 } 303 304 /// Returns the AffineMapAttr associated with 'memref'. getAffineMapAttrForMemRef(Value memref)305 NamedAttribute getAffineMapAttrForMemRef(Value memref) { 306 assert(memref == getTagMemRef()); 307 return {Identifier::get(getTagMapAttrName(), getContext()), 308 getTagMapAttr()}; 309 } 310 311 /// Returns the number of elements transferred in the associated DMA op. getNumElements()312 Value getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); } 313 getTagMapAttrName()314 static StringRef getTagMapAttrName() { return "tag_map"; } 315 static ParseResult parse(OpAsmParser &parser, OperationState &result); 316 void print(OpAsmPrinter &p); 317 LogicalResult verify(); 318 LogicalResult fold(ArrayRef<Attribute> cstOperands, 319 SmallVectorImpl<OpFoldResult> &results); 320 }; 321 322 /// Returns true if the given Value can be used as a dimension id in the region 323 /// of the closest surrounding op that has the trait `AffineScope`. 324 bool isValidDim(Value value); 325 326 /// Returns true if the given Value can be used as a dimension id in `region`, 327 /// i.e., for all its uses in `region`. 328 bool isValidDim(Value value, Region *region); 329 330 /// Returns true if the given value can be used as a symbol in the region of the 331 /// closest surrounding op that has the trait `AffineScope`. 332 bool isValidSymbol(Value value); 333 334 /// Returns true if the given Value can be used as a symbol for `region`, i.e., 335 /// for all its uses in `region`. 336 bool isValidSymbol(Value value, Region *region); 337 338 /// Modifies both `map` and `operands` in-place so as to: 339 /// 1. drop duplicate operands 340 /// 2. drop unused dims and symbols from map 341 /// 3. promote valid symbols to symbolic operands in case they appeared as 342 /// dimensional operands 343 /// 4. propagate constant operands and drop them 344 void canonicalizeMapAndOperands(AffineMap *map, 345 SmallVectorImpl<Value> *operands); 346 347 /// Canonicalizes an integer set the same way canonicalizeMapAndOperands does 348 /// for affine maps. 349 void canonicalizeSetAndOperands(IntegerSet *set, 350 SmallVectorImpl<Value> *operands); 351 352 /// Returns a composed AffineApplyOp by composing `map` and `operands` with 353 /// other AffineApplyOps supplying those operands. The operands of the resulting 354 /// AffineApplyOp do not change the length of AffineApplyOp chains. 355 AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, 356 ArrayRef<Value> operands); 357 358 /// Given an affine map `map` and its input `operands`, this method composes 359 /// into `map`, maps of AffineApplyOps whose results are the values in 360 /// `operands`, iteratively until no more of `operands` are the result of an 361 /// AffineApplyOp. When this function returns, `map` becomes the composed affine 362 /// map, and each Value in `operands` is guaranteed to be either a loop IV or a 363 /// terminal symbol, i.e., a symbol defined at the top level or a block/function 364 /// argument. 365 void fullyComposeAffineMapAndOperands(AffineMap *map, 366 SmallVectorImpl<Value> *operands); 367 368 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.h.inc" 369 370 #define GET_OP_CLASSES 371 #include "mlir/Dialect/Affine/IR/AffineOps.h.inc" 372 373 /// Returns if the provided value is the induction variable of a AffineForOp. 374 bool isForInductionVar(Value val); 375 376 /// Returns the loop parent of an induction variable. If the provided value is 377 /// not an induction variable, then return nullptr. 378 AffineForOp getForInductionVarOwner(Value val); 379 380 /// Extracts the induction variables from a list of AffineForOps and places them 381 /// in the output argument `ivs`. 382 void extractForInductionVars(ArrayRef<AffineForOp> forInsts, 383 SmallVectorImpl<Value> *ivs); 384 385 /// Builds a perfect nest of affine "for" loops, i.e. each loop except the 386 /// innermost only contains another loop and a terminator. The loops iterate 387 /// from "lbs" to "ubs" with "steps". The body of the innermost loop is 388 /// populated by calling "bodyBuilderFn" and providing it with an OpBuilder, a 389 /// Location and a list of loop induction variables. 390 void buildAffineLoopNest(OpBuilder &builder, Location loc, 391 ArrayRef<int64_t> lbs, ArrayRef<int64_t> ubs, 392 ArrayRef<int64_t> steps, 393 function_ref<void(OpBuilder &, Location, ValueRange)> 394 bodyBuilderFn = nullptr); 395 void buildAffineLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, 396 ValueRange ubs, ArrayRef<int64_t> steps, 397 function_ref<void(OpBuilder &, Location, ValueRange)> 398 bodyBuilderFn = nullptr); 399 400 /// AffineBound represents a lower or upper bound in the for operation. 401 /// This class does not own the underlying operands. Instead, it refers 402 /// to the operands stored in the AffineForOp. Its life span should not exceed 403 /// that of the for operation it refers to. 404 class AffineBound { 405 public: getAffineForOp()406 AffineForOp getAffineForOp() { return op; } getMap()407 AffineMap getMap() { return map; } 408 getNumOperands()409 unsigned getNumOperands() { return opEnd - opStart; } getOperand(unsigned idx)410 Value getOperand(unsigned idx) { return op.getOperand(opStart + idx); } 411 412 using operand_iterator = AffineForOp::operand_iterator; 413 using operand_range = AffineForOp::operand_range; 414 operand_begin()415 operand_iterator operand_begin() { return op.operand_begin() + opStart; } operand_end()416 operand_iterator operand_end() { return op.operand_begin() + opEnd; } getOperands()417 operand_range getOperands() { return {operand_begin(), operand_end()}; } 418 419 private: 420 // 'affine.for' operation that contains this bound. 421 AffineForOp op; 422 // Start and end positions of this affine bound operands in the list of 423 // the containing 'affine.for' operation operands. 424 unsigned opStart, opEnd; 425 // Affine map for this bound. 426 AffineMap map; 427 AffineBound(AffineForOp op,unsigned opStart,unsigned opEnd,AffineMap map)428 AffineBound(AffineForOp op, unsigned opStart, unsigned opEnd, AffineMap map) 429 : op(op), opStart(opStart), opEnd(opEnd), map(map) {} 430 431 friend class AffineForOp; 432 }; 433 434 /// An `AffineApplyNormalizer` is a helper class that supports renumbering 435 /// operands of AffineApplyOp. This acts as a reindexing map of Value to 436 /// positional dims or symbols and allows simplifications such as: 437 /// 438 /// ```mlir 439 /// %1 = affine.apply (d0, d1) -> (d0 - d1) (%0, %0) 440 /// ``` 441 /// 442 /// into: 443 /// 444 /// ```mlir 445 /// %1 = affine.apply () -> (0) 446 /// ``` 447 struct AffineApplyNormalizer { 448 AffineApplyNormalizer(AffineMap map, ArrayRef<Value> operands); 449 450 /// Returns the AffineMap resulting from normalization. getAffineMapAffineApplyNormalizer451 AffineMap getAffineMap() { return affineMap; } 452 getOperandsAffineApplyNormalizer453 SmallVector<Value, 8> getOperands() { 454 SmallVector<Value, 8> res(reorderedDims); 455 res.append(concatenatedSymbols.begin(), concatenatedSymbols.end()); 456 return res; 457 } 458 getNumSymbolsAffineApplyNormalizer459 unsigned getNumSymbols() { return concatenatedSymbols.size(); } getNumDimsAffineApplyNormalizer460 unsigned getNumDims() { return reorderedDims.size(); } 461 462 /// Normalizes 'otherMap' and its operands 'otherOperands' to map to this 463 /// normalizer's coordinate space. 464 void normalize(AffineMap *otherMap, SmallVectorImpl<Value> *otherOperands); 465 466 private: 467 /// Helper function to insert `v` into the coordinate system of the current 468 /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding 469 /// renumbered position. 470 AffineDimExpr renumberOneDim(Value v); 471 472 /// Given an `other` normalizer, this rewrites `other.affineMap` in the 473 /// coordinate system of the current AffineApplyNormalizer. 474 /// Returns the rewritten AffineMap and updates the dims and symbols of 475 /// `this`. 476 AffineMap renumber(const AffineApplyNormalizer &other); 477 478 /// Maps of Value to position in `affineMap`. 479 DenseMap<Value, unsigned> dimValueToPosition; 480 481 /// Ordered dims and symbols matching positional dims and symbols in 482 /// `affineMap`. 483 SmallVector<Value, 8> reorderedDims; 484 SmallVector<Value, 8> concatenatedSymbols; 485 486 /// The number of symbols in concatenated symbols that belong to the original 487 /// map as opposed to those concatendated during map composition. 488 unsigned numProperSymbols; 489 490 AffineMap affineMap; 491 492 /// Used with RAII to control the depth at which AffineApply are composed 493 /// recursively. Only accepts depth 1 for now to allow a behavior where a 494 /// newly composed AffineApplyOp does not increase the length of the chain of 495 /// AffineApplyOps. Full composition is implemented iteratively on top of 496 /// this behavior. affineApplyDepthAffineApplyNormalizer497 static unsigned &affineApplyDepth() { 498 static thread_local unsigned depth = 0; 499 return depth; 500 } 501 static constexpr unsigned kMaxAffineApplyDepth = 1; 502 AffineApplyNormalizerAffineApplyNormalizer503 AffineApplyNormalizer() : numProperSymbols(0) { affineApplyDepth()++; } 504 505 public: ~AffineApplyNormalizerAffineApplyNormalizer506 ~AffineApplyNormalizer() { affineApplyDepth()--; } 507 }; 508 509 } // end namespace mlir 510 511 #endif 512