1 //===- AffineOps.h - MLIR Affine Operations -------------------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines convenience types for working with Affine operations 10 // in the MLIR operation set. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_DIALECT_AFFINEOPS_AFFINEOPS_H 15 #define MLIR_DIALECT_AFFINEOPS_AFFINEOPS_H 16 17 #include "mlir/IR/AffineMap.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/Dialect.h" 20 #include "mlir/IR/OpDefinition.h" 21 #include "mlir/IR/StandardTypes.h" 22 #include "mlir/Transforms/LoopLikeInterface.h" 23 24 namespace mlir { 25 class AffineBound; 26 class AffineDimExpr; 27 class AffineValueMap; 28 class AffineTerminatorOp; 29 class FlatAffineConstraints; 30 class OpBuilder; 31 32 /// A utility function to check if a value is defined at the top level of a 33 /// function. A value of index type defined at the top level is always a valid 34 /// symbol. 35 bool isTopLevelValue(Value value); 36 37 class AffineOpsDialect : public Dialect { 38 public: 39 AffineOpsDialect(MLIRContext *context); getDialectNamespace()40 static StringRef getDialectNamespace() { return "affine"; } 41 42 /// Materialize a single constant operation from a given attribute value with 43 /// the desired resultant type. 44 Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, 45 Location loc) override; 46 }; 47 48 /// The "affine.apply" operation applies an affine map to a list of operands, 49 /// yielding a single result. The operand list must be the same size as the 50 /// number of arguments to the affine mapping. All operands and the result are 51 /// of type 'Index'. This operation requires a single affine map attribute named 52 /// "map". For example: 53 /// 54 /// %y = "affine.apply" (%x) { map: (d0) -> (d0 + 1) } : 55 /// (index) -> (index) 56 /// 57 /// equivalently: 58 /// 59 /// #map42 = (d0)->(d0+1) 60 /// %y = affine.apply #map42(%x) 61 /// 62 class AffineApplyOp : public Op<AffineApplyOp, OpTrait::VariadicOperands, 63 OpTrait::OneResult, OpTrait::HasNoSideEffect> { 64 public: 65 using Op::Op; 66 67 /// Builds an affine apply op with the specified map and operands. 68 static void build(Builder *builder, OperationState &result, AffineMap map, 69 ValueRange operands); 70 71 /// Returns the affine map to be applied by this operation. getAffineMap()72 AffineMap getAffineMap() { 73 return getAttrOfType<AffineMapAttr>("map").getValue(); 74 } 75 76 /// Returns true if the result of this operation can be used as dimension id. 77 bool isValidDim(); 78 79 /// Returns true if the result of this operation is a symbol. 80 bool isValidSymbol(); 81 getOperationName()82 static StringRef getOperationName() { return "affine.apply"; } 83 getMapOperands()84 operand_range getMapOperands() { return getOperands(); } 85 86 // Hooks to customize behavior of this op. 87 static ParseResult parse(OpAsmParser &parser, OperationState &result); 88 void print(OpAsmPrinter &p); 89 LogicalResult verify(); 90 OpFoldResult fold(ArrayRef<Attribute> operands); 91 92 static void getCanonicalizationPatterns(OwningRewritePatternList &results, 93 MLIRContext *context); 94 }; 95 96 /// AffineDmaStartOp starts a non-blocking DMA operation that transfers data 97 /// from a source memref to a destination memref. The source and destination 98 /// memref need not be of the same dimensionality, but need to have the same 99 /// elemental type. The operands include the source and destination memref's 100 /// each followed by its indices, size of the data transfer in terms of the 101 /// number of elements (of the elemental type of the memref), a tag memref with 102 /// its indices, and optionally at the end, a stride and a 103 /// number_of_elements_per_stride arguments. The tag location is used by an 104 /// AffineDmaWaitOp to check for completion. The indices of the source memref, 105 /// destination memref, and the tag memref have the same restrictions as any 106 /// affine.load/store. In particular, index for each memref dimension must be an 107 /// affine expression of loop induction variables and symbols. 108 /// The optional stride arguments should be of 'index' type, and specify a 109 /// stride for the slower memory space (memory space with a lower memory space 110 /// id), transferring chunks of number_of_elements_per_stride every stride until 111 /// %num_elements are transferred. Either both or no stride arguments should be 112 /// specified. The value of 'num_elements' must be a multiple of 113 /// 'number_of_elements_per_stride'. 114 // 115 // For example, a DmaStartOp operation that transfers 256 elements of a memref 116 // '%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in memory 117 // space 1 at indices [%k + 7, %l], would be specified as follows: 118 // 119 // %num_elements = constant 256 120 // %idx = constant 0 : index 121 // %tag = alloc() : memref<1xi32, 4> 122 // affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx], 123 // %num_elements : 124 // memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2> 125 // 126 // If %stride and %num_elt_per_stride are specified, the DMA is expected to 127 // transfer %num_elt_per_stride elements every %stride elements apart from 128 // memory space 0 until %num_elements are transferred. 129 // 130 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%idx], %num_elements, 131 // %stride, %num_elt_per_stride : ... 132 // 133 // TODO(mlir-team): add additional operands to allow source and destination 134 // striding, and multiple stride levels (possibly using AffineMaps to specify 135 // multiple levels of striding). 136 // TODO(andydavis) Consider replacing src/dst memref indices with view memrefs. 137 class AffineDmaStartOp : public Op<AffineDmaStartOp, OpTrait::VariadicOperands, 138 OpTrait::ZeroResult> { 139 public: 140 using Op::Op; 141 142 static void build(Builder *builder, OperationState &result, Value srcMemRef, 143 AffineMap srcMap, ValueRange srcIndices, Value destMemRef, 144 AffineMap dstMap, ValueRange destIndices, Value tagMemRef, 145 AffineMap tagMap, ValueRange tagIndices, Value numElements, 146 Value stride = nullptr, Value elementsPerStride = nullptr); 147 148 /// Returns the operand index of the src memref. getSrcMemRefOperandIndex()149 unsigned getSrcMemRefOperandIndex() { return 0; } 150 151 /// Returns the source MemRefType for this DMA operation. getSrcMemRef()152 Value getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); } getSrcMemRefType()153 MemRefType getSrcMemRefType() { 154 return getSrcMemRef().getType().cast<MemRefType>(); 155 } 156 157 /// Returns the rank (number of indices) of the source MemRefType. getSrcMemRefRank()158 unsigned getSrcMemRefRank() { return getSrcMemRefType().getRank(); } 159 160 /// Returns the affine map used to access the src memref. getSrcMap()161 AffineMap getSrcMap() { return getSrcMapAttr().getValue(); } getSrcMapAttr()162 AffineMapAttr getSrcMapAttr() { 163 return getAttr(getSrcMapAttrName()).cast<AffineMapAttr>(); 164 } 165 166 /// Returns the source memref affine map indices for this DMA operation. getSrcIndices()167 operand_range getSrcIndices() { 168 return {operand_begin() + getSrcMemRefOperandIndex() + 1, 169 operand_begin() + getSrcMemRefOperandIndex() + 1 + 170 getSrcMap().getNumInputs()}; 171 } 172 173 /// Returns the memory space of the src memref. getSrcMemorySpace()174 unsigned getSrcMemorySpace() { 175 return getSrcMemRef().getType().cast<MemRefType>().getMemorySpace(); 176 } 177 178 /// Returns the operand index of the dst memref. getDstMemRefOperandIndex()179 unsigned getDstMemRefOperandIndex() { 180 return getSrcMemRefOperandIndex() + 1 + getSrcMap().getNumInputs(); 181 } 182 183 /// Returns the destination MemRefType for this DMA operations. getDstMemRef()184 Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); } getDstMemRefType()185 MemRefType getDstMemRefType() { 186 return getDstMemRef().getType().cast<MemRefType>(); 187 } 188 189 /// Returns the rank (number of indices) of the destination MemRefType. getDstMemRefRank()190 unsigned getDstMemRefRank() { 191 return getDstMemRef().getType().cast<MemRefType>().getRank(); 192 } 193 194 /// Returns the memory space of the src memref. getDstMemorySpace()195 unsigned getDstMemorySpace() { 196 return getDstMemRef().getType().cast<MemRefType>().getMemorySpace(); 197 } 198 199 /// Returns the affine map used to access the dst memref. getDstMap()200 AffineMap getDstMap() { return getDstMapAttr().getValue(); } getDstMapAttr()201 AffineMapAttr getDstMapAttr() { 202 return getAttr(getDstMapAttrName()).cast<AffineMapAttr>(); 203 } 204 205 /// Returns the destination memref indices for this DMA operation. getDstIndices()206 operand_range getDstIndices() { 207 return {operand_begin() + getDstMemRefOperandIndex() + 1, 208 operand_begin() + getDstMemRefOperandIndex() + 1 + 209 getDstMap().getNumInputs()}; 210 } 211 212 /// Returns the operand index of the tag memref. getTagMemRefOperandIndex()213 unsigned getTagMemRefOperandIndex() { 214 return getDstMemRefOperandIndex() + 1 + getDstMap().getNumInputs(); 215 } 216 217 /// Returns the Tag MemRef for this DMA operation. getTagMemRef()218 Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); } getTagMemRefType()219 MemRefType getTagMemRefType() { 220 return getTagMemRef().getType().cast<MemRefType>(); 221 } 222 223 /// Returns the rank (number of indices) of the tag MemRefType. getTagMemRefRank()224 unsigned getTagMemRefRank() { 225 return getTagMemRef().getType().cast<MemRefType>().getRank(); 226 } 227 228 /// Returns the affine map used to access the tag memref. getTagMap()229 AffineMap getTagMap() { return getTagMapAttr().getValue(); } getTagMapAttr()230 AffineMapAttr getTagMapAttr() { 231 return getAttr(getTagMapAttrName()).cast<AffineMapAttr>(); 232 } 233 234 /// Returns the tag memref indices for this DMA operation. getTagIndices()235 operand_range getTagIndices() { 236 return {operand_begin() + getTagMemRefOperandIndex() + 1, 237 operand_begin() + getTagMemRefOperandIndex() + 1 + 238 getTagMap().getNumInputs()}; 239 } 240 241 /// Returns the number of elements being transferred by this DMA operation. getNumElements()242 Value getNumElements() { 243 return getOperand(getTagMemRefOperandIndex() + 1 + 244 getTagMap().getNumInputs()); 245 } 246 247 /// Returns the AffineMapAttr associated with 'memref'. getAffineMapAttrForMemRef(Value memref)248 NamedAttribute getAffineMapAttrForMemRef(Value memref) { 249 if (memref == getSrcMemRef()) 250 return {Identifier::get(getSrcMapAttrName(), getContext()), 251 getSrcMapAttr()}; 252 else if (memref == getDstMemRef()) 253 return {Identifier::get(getDstMapAttrName(), getContext()), 254 getDstMapAttr()}; 255 assert(memref == getTagMemRef() && 256 "DmaStartOp expected source, destination or tag memref"); 257 return {Identifier::get(getTagMapAttrName(), getContext()), 258 getTagMapAttr()}; 259 } 260 261 /// Returns true if this is a DMA from a faster memory space to a slower one. isDestMemorySpaceFaster()262 bool isDestMemorySpaceFaster() { 263 return (getSrcMemorySpace() < getDstMemorySpace()); 264 } 265 266 /// Returns true if this is a DMA from a slower memory space to a faster one. isSrcMemorySpaceFaster()267 bool isSrcMemorySpaceFaster() { 268 // Assumes that a lower number is for a slower memory space. 269 return (getDstMemorySpace() < getSrcMemorySpace()); 270 } 271 272 /// Given a DMA start operation, returns the operand position of either the 273 /// source or destination memref depending on the one that is at the higher 274 /// level of the memory hierarchy. Asserts failure if neither is true. getFasterMemPos()275 unsigned getFasterMemPos() { 276 assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); 277 return isSrcMemorySpaceFaster() ? 0 : getDstMemRefOperandIndex(); 278 } 279 getSrcMapAttrName()280 static StringRef getSrcMapAttrName() { return "src_map"; } getDstMapAttrName()281 static StringRef getDstMapAttrName() { return "dst_map"; } getTagMapAttrName()282 static StringRef getTagMapAttrName() { return "tag_map"; } 283 getOperationName()284 static StringRef getOperationName() { return "affine.dma_start"; } 285 static ParseResult parse(OpAsmParser &parser, OperationState &result); 286 void print(OpAsmPrinter &p); 287 LogicalResult verify(); 288 LogicalResult fold(ArrayRef<Attribute> cstOperands, 289 SmallVectorImpl<OpFoldResult> &results); 290 291 /// Returns true if this DMA operation is strided, returns false otherwise. isStrided()292 bool isStrided() { 293 return getNumOperands() != 294 getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs() + 1; 295 } 296 297 /// Returns the stride value for this DMA operation. getStride()298 Value getStride() { 299 if (!isStrided()) 300 return nullptr; 301 return getOperand(getNumOperands() - 1 - 1); 302 } 303 304 /// Returns the number of elements to transfer per stride for this DMA op. getNumElementsPerStride()305 Value getNumElementsPerStride() { 306 if (!isStrided()) 307 return nullptr; 308 return getOperand(getNumOperands() - 1); 309 } 310 }; 311 312 /// AffineDmaWaitOp blocks until the completion of a DMA operation associated 313 /// with the tag element '%tag[%index]'. %tag is a memref, and %index has to be 314 /// an index with the same restrictions as any load/store index. In particular, 315 /// index for each memref dimension must be an affine expression of loop 316 /// induction variables and symbols. %num_elements is the number of elements 317 /// associated with the DMA operation. For example: 318 // 319 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %num_elements : 320 // memref<2048xf32, 0>, memref<256xf32, 1>, memref<1xi32, 2> 321 // ... 322 // ... 323 // affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2> 324 // 325 class AffineDmaWaitOp : public Op<AffineDmaWaitOp, OpTrait::VariadicOperands, 326 OpTrait::ZeroResult> { 327 public: 328 using Op::Op; 329 330 static void build(Builder *builder, OperationState &result, Value tagMemRef, 331 AffineMap tagMap, ValueRange tagIndices, Value numElements); 332 getOperationName()333 static StringRef getOperationName() { return "affine.dma_wait"; } 334 335 // Returns the Tag MemRef associated with the DMA operation being waited on. getTagMemRef()336 Value getTagMemRef() { return getOperand(0); } getTagMemRefType()337 MemRefType getTagMemRefType() { 338 return getTagMemRef().getType().cast<MemRefType>(); 339 } 340 341 /// Returns the affine map used to access the tag memref. getTagMap()342 AffineMap getTagMap() { return getTagMapAttr().getValue(); } getTagMapAttr()343 AffineMapAttr getTagMapAttr() { 344 return getAttr(getTagMapAttrName()).cast<AffineMapAttr>(); 345 } 346 347 // Returns the tag memref index for this DMA operation. getTagIndices()348 operand_range getTagIndices() { 349 return {operand_begin() + 1, 350 operand_begin() + 1 + getTagMap().getNumInputs()}; 351 } 352 353 // Returns the rank (number of indices) of the tag memref. getTagMemRefRank()354 unsigned getTagMemRefRank() { 355 return getTagMemRef().getType().cast<MemRefType>().getRank(); 356 } 357 358 /// Returns the AffineMapAttr associated with 'memref'. getAffineMapAttrForMemRef(Value memref)359 NamedAttribute getAffineMapAttrForMemRef(Value memref) { 360 assert(memref == getTagMemRef()); 361 return {Identifier::get(getTagMapAttrName(), getContext()), 362 getTagMapAttr()}; 363 } 364 365 /// Returns the number of elements transferred in the associated DMA op. getNumElements()366 Value getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); } 367 getTagMapAttrName()368 static StringRef getTagMapAttrName() { return "tag_map"; } 369 static ParseResult parse(OpAsmParser &parser, OperationState &result); 370 void print(OpAsmPrinter &p); 371 LogicalResult verify(); 372 LogicalResult fold(ArrayRef<Attribute> cstOperands, 373 SmallVectorImpl<OpFoldResult> &results); 374 }; 375 376 /// The "affine.load" op reads an element from a memref, where the index 377 /// for each memref dimension is an affine expression of loop induction 378 /// variables and symbols. The output of 'affine.load' is a new value with the 379 /// same type as the elements of the memref. An affine expression of loop IVs 380 /// and symbols must be specified for each dimension of the memref. The keyword 381 /// 'symbol' can be used to indicate SSA identifiers which are symbolic. 382 // 383 // Example 1: 384 // 385 // %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32> 386 // 387 // Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'. 388 // 389 // %1 = affine.load %0[%i0 + symbol(%n), %i1 + symbol(%m)] 390 // : memref<100x100xf32> 391 // 392 class AffineLoadOp : public Op<AffineLoadOp, OpTrait::OneResult, 393 OpTrait::AtLeastNOperands<1>::Impl> { 394 public: 395 using Op::Op; 396 397 /// Builds an affine load op with the specified map and operands. 398 static void build(Builder *builder, OperationState &result, AffineMap map, 399 ValueRange operands); 400 /// Builds an affine load op with an identity map and operands. 401 static void build(Builder *builder, OperationState &result, Value memref, 402 ValueRange indices = {}); 403 /// Builds an affine load op with the specified map and its operands. 404 static void build(Builder *builder, OperationState &result, Value memref, 405 AffineMap map, ValueRange mapOperands); 406 407 /// Returns the operand index of the memref. getMemRefOperandIndex()408 unsigned getMemRefOperandIndex() { return 0; } 409 410 /// Get memref operand. getMemRef()411 Value getMemRef() { return getOperand(getMemRefOperandIndex()); } setMemRef(Value value)412 void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); } getMemRefType()413 MemRefType getMemRefType() { 414 return getMemRef().getType().cast<MemRefType>(); 415 } 416 417 /// Get affine map operands. getMapOperands()418 operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); } 419 420 /// Returns the affine map used to index the memref for this operation. getAffineMap()421 AffineMap getAffineMap() { return getAffineMapAttr().getValue(); } getAffineMapAttr()422 AffineMapAttr getAffineMapAttr() { 423 return getAttr(getMapAttrName()).cast<AffineMapAttr>(); 424 } 425 426 /// Returns the AffineMapAttr associated with 'memref'. getAffineMapAttrForMemRef(Value memref)427 NamedAttribute getAffineMapAttrForMemRef(Value memref) { 428 assert(memref == getMemRef()); 429 return {Identifier::get(getMapAttrName(), getContext()), 430 getAffineMapAttr()}; 431 } 432 getMapAttrName()433 static StringRef getMapAttrName() { return "map"; } getOperationName()434 static StringRef getOperationName() { return "affine.load"; } 435 436 // Hooks to customize behavior of this op. 437 static ParseResult parse(OpAsmParser &parser, OperationState &result); 438 void print(OpAsmPrinter &p); 439 LogicalResult verify(); 440 static void getCanonicalizationPatterns(OwningRewritePatternList &results, 441 MLIRContext *context); 442 OpFoldResult fold(ArrayRef<Attribute> operands); 443 }; 444 445 /// The "affine.store" op writes an element to a memref, where the index 446 /// for each memref dimension is an affine expression of loop induction 447 /// variables and symbols. The 'affine.store' op stores a new value which is the 448 /// same type as the elements of the memref. An affine expression of loop IVs 449 /// and symbols must be specified for each dimension of the memref. The keyword 450 /// 'symbol' can be used to indicate SSA identifiers which are symbolic. 451 // 452 // Example 1: 453 // 454 // affine.store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32> 455 // 456 // Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'. 457 // 458 // affine.store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)] 459 // : memref<100x100xf32> 460 // 461 class AffineStoreOp : public Op<AffineStoreOp, OpTrait::ZeroResult, 462 OpTrait::AtLeastNOperands<1>::Impl> { 463 public: 464 using Op::Op; 465 466 /// Builds an affine store operation with the provided indices (identity map). 467 static void build(Builder *builder, OperationState &result, 468 Value valueToStore, Value memref, ValueRange indices); 469 /// Builds an affine store operation with the specified map and its operands. 470 static void build(Builder *builder, OperationState &result, 471 Value valueToStore, Value memref, AffineMap map, 472 ValueRange mapOperands); 473 474 /// Get value to be stored by store operation. getValueToStore()475 Value getValueToStore() { return getOperand(0); } 476 477 /// Returns the operand index of the memref. getMemRefOperandIndex()478 unsigned getMemRefOperandIndex() { return 1; } 479 480 /// Get memref operand. getMemRef()481 Value getMemRef() { return getOperand(getMemRefOperandIndex()); } setMemRef(Value value)482 void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); } 483 getMemRefType()484 MemRefType getMemRefType() { 485 return getMemRef().getType().cast<MemRefType>(); 486 } 487 488 /// Get affine map operands. getMapOperands()489 operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); } 490 491 /// Returns the affine map used to index the memref for this operation. getAffineMap()492 AffineMap getAffineMap() { return getAffineMapAttr().getValue(); } getAffineMapAttr()493 AffineMapAttr getAffineMapAttr() { 494 return getAttr(getMapAttrName()).cast<AffineMapAttr>(); 495 } 496 497 /// Returns the AffineMapAttr associated with 'memref'. getAffineMapAttrForMemRef(Value memref)498 NamedAttribute getAffineMapAttrForMemRef(Value memref) { 499 assert(memref == getMemRef()); 500 return {Identifier::get(getMapAttrName(), getContext()), 501 getAffineMapAttr()}; 502 } 503 getMapAttrName()504 static StringRef getMapAttrName() { return "map"; } getOperationName()505 static StringRef getOperationName() { return "affine.store"; } 506 507 // Hooks to customize behavior of this op. 508 static ParseResult parse(OpAsmParser &parser, OperationState &result); 509 void print(OpAsmPrinter &p); 510 LogicalResult verify(); 511 static void getCanonicalizationPatterns(OwningRewritePatternList &results, 512 MLIRContext *context); 513 LogicalResult fold(ArrayRef<Attribute> cstOperands, 514 SmallVectorImpl<OpFoldResult> &results); 515 }; 516 517 /// Returns true if the given Value can be used as a dimension id. 518 bool isValidDim(Value value); 519 520 /// Returns true if the given Value can be used as a symbol. 521 bool isValidSymbol(Value value); 522 523 /// Modifies both `map` and `operands` in-place so as to: 524 /// 1. drop duplicate operands 525 /// 2. drop unused dims and symbols from map 526 /// 3. promote valid symbols to symbolic operands in case they appeared as 527 /// dimensional operands 528 /// 4. propagate constant operands and drop them 529 void canonicalizeMapAndOperands(AffineMap *map, 530 SmallVectorImpl<Value> *operands); 531 /// Canonicalizes an integer set the same way canonicalizeMapAndOperands does 532 /// for affine maps. 533 void canonicalizeSetAndOperands(IntegerSet *set, 534 SmallVectorImpl<Value> *operands); 535 536 /// Returns a composed AffineApplyOp by composing `map` and `operands` with 537 /// other AffineApplyOps supplying those operands. The operands of the resulting 538 /// AffineApplyOp do not change the length of AffineApplyOp chains. 539 AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, 540 ArrayRef<Value> operands); 541 542 /// Given an affine map `map` and its input `operands`, this method composes 543 /// into `map`, maps of AffineApplyOps whose results are the values in 544 /// `operands`, iteratively until no more of `operands` are the result of an 545 /// AffineApplyOp. When this function returns, `map` becomes the composed affine 546 /// map, and each Value in `operands` is guaranteed to be either a loop IV or a 547 /// terminal symbol, i.e., a symbol defined at the top level or a block/function 548 /// argument. 549 void fullyComposeAffineMapAndOperands(AffineMap *map, 550 SmallVectorImpl<Value> *operands); 551 552 #define GET_OP_CLASSES 553 #include "mlir/Dialect/AffineOps/AffineOps.h.inc" 554 555 /// Returns if the provided value is the induction variable of a AffineForOp. 556 bool isForInductionVar(Value val); 557 558 /// Returns the loop parent of an induction variable. If the provided value is 559 /// not an induction variable, then return nullptr. 560 AffineForOp getForInductionVarOwner(Value val); 561 562 /// Extracts the induction variables from a list of AffineForOps and places them 563 /// in the output argument `ivs`. 564 void extractForInductionVars(ArrayRef<AffineForOp> forInsts, 565 SmallVectorImpl<Value> *ivs); 566 567 /// AffineBound represents a lower or upper bound in the for operation. 568 /// This class does not own the underlying operands. Instead, it refers 569 /// to the operands stored in the AffineForOp. Its life span should not exceed 570 /// that of the for operation it refers to. 571 class AffineBound { 572 public: getAffineForOp()573 AffineForOp getAffineForOp() { return op; } getMap()574 AffineMap getMap() { return map; } 575 576 /// Returns an AffineValueMap representing this bound. 577 AffineValueMap getAsAffineValueMap(); 578 getNumOperands()579 unsigned getNumOperands() { return opEnd - opStart; } getOperand(unsigned idx)580 Value getOperand(unsigned idx) { return op.getOperand(opStart + idx); } 581 582 using operand_iterator = AffineForOp::operand_iterator; 583 using operand_range = AffineForOp::operand_range; 584 operand_begin()585 operand_iterator operand_begin() { return op.operand_begin() + opStart; } operand_end()586 operand_iterator operand_end() { return op.operand_begin() + opEnd; } getOperands()587 operand_range getOperands() { return {operand_begin(), operand_end()}; } 588 589 private: 590 // 'affine.for' operation that contains this bound. 591 AffineForOp op; 592 // Start and end positions of this affine bound operands in the list of 593 // the containing 'affine.for' operation operands. 594 unsigned opStart, opEnd; 595 // Affine map for this bound. 596 AffineMap map; 597 AffineBound(AffineForOp op,unsigned opStart,unsigned opEnd,AffineMap map)598 AffineBound(AffineForOp op, unsigned opStart, unsigned opEnd, AffineMap map) 599 : op(op), opStart(opStart), opEnd(opEnd), map(map) {} 600 601 friend class AffineForOp; 602 }; 603 604 /// An `AffineApplyNormalizer` is a helper class that supports renumbering 605 /// operands of AffineApplyOp. This acts as a reindexing map of Value to 606 /// positional dims or symbols and allows simplifications such as: 607 /// 608 /// ```mlir 609 /// %1 = affine.apply (d0, d1) -> (d0 - d1) (%0, %0) 610 /// ``` 611 /// 612 /// into: 613 /// 614 /// ```mlir 615 /// %1 = affine.apply () -> (0) 616 /// ``` 617 struct AffineApplyNormalizer { 618 AffineApplyNormalizer(AffineMap map, ArrayRef<Value> operands); 619 620 /// Returns the AffineMap resulting from normalization. getAffineMapAffineApplyNormalizer621 AffineMap getAffineMap() { return affineMap; } 622 getOperandsAffineApplyNormalizer623 SmallVector<Value, 8> getOperands() { 624 SmallVector<Value, 8> res(reorderedDims); 625 res.append(concatenatedSymbols.begin(), concatenatedSymbols.end()); 626 return res; 627 } 628 getNumSymbolsAffineApplyNormalizer629 unsigned getNumSymbols() { return concatenatedSymbols.size(); } getNumDimsAffineApplyNormalizer630 unsigned getNumDims() { return reorderedDims.size(); } 631 632 /// Normalizes 'otherMap' and its operands 'otherOperands' to map to this 633 /// normalizer's coordinate space. 634 void normalize(AffineMap *otherMap, SmallVectorImpl<Value> *otherOperands); 635 636 private: 637 /// Helper function to insert `v` into the coordinate system of the current 638 /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding 639 /// renumbered position. 640 AffineDimExpr renumberOneDim(Value v); 641 642 /// Given an `other` normalizer, this rewrites `other.affineMap` in the 643 /// coordinate system of the current AffineApplyNormalizer. 644 /// Returns the rewritten AffineMap and updates the dims and symbols of 645 /// `this`. 646 AffineMap renumber(const AffineApplyNormalizer &other); 647 648 /// Maps of Value to position in `affineMap`. 649 DenseMap<Value, unsigned> dimValueToPosition; 650 651 /// Ordered dims and symbols matching positional dims and symbols in 652 /// `affineMap`. 653 SmallVector<Value, 8> reorderedDims; 654 SmallVector<Value, 8> concatenatedSymbols; 655 656 AffineMap affineMap; 657 658 /// Used with RAII to control the depth at which AffineApply are composed 659 /// recursively. Only accepts depth 1 for now to allow a behavior where a 660 /// newly composed AffineApplyOp does not increase the length of the chain of 661 /// AffineApplyOps. Full composition is implemented iteratively on top of 662 /// this behavior. affineApplyDepthAffineApplyNormalizer663 static unsigned &affineApplyDepth() { 664 static thread_local unsigned depth = 0; 665 return depth; 666 } 667 static constexpr unsigned kMaxAffineApplyDepth = 1; 668 AffineApplyNormalizerAffineApplyNormalizer669 AffineApplyNormalizer() { affineApplyDepth()++; } 670 671 public: ~AffineApplyNormalizerAffineApplyNormalizer672 ~AffineApplyNormalizer() { affineApplyDepth()--; } 673 }; 674 675 } // end namespace mlir 676 677 #endif 678