1 //===- LinalgTraits.h - Linalg Traits ---------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_LINALG_LINALGTRAITS_H_ 10 #define MLIR_DIALECT_LINALG_LINALGTRAITS_H_ 11 12 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 13 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/Function.h" 16 #include "mlir/IR/OpDefinition.h" 17 #include "mlir/IR/StandardTypes.h" 18 #include "mlir/Support/LLVM.h" 19 20 namespace mlir { 21 namespace OpTrait { 22 namespace linalg { 23 24 /// This class provides the API for ops that are known to have a specified 25 /// number of inputs, all passed as operands. Use as a trait as follows: 26 /// 27 /// class DotOp : public Op<DotOp, OpTrait::NInputs<2>::Impl> { 28 /// 29 template <unsigned N> class NInputs { 30 public: 31 template <typename ConcreteType> 32 class Impl : public OpTrait::TraitBase<ConcreteType, NInputs<N>::Impl> { 33 public: getNumInputs()34 static unsigned getNumInputs() { return N; } 35 }; 36 }; 37 38 /// This class provides the API for ops that are known to have a specified 39 /// number of outputs, all passed as operands. Use as a trait as follows: 40 /// 41 /// class DotOp : public Op<DotOp, OpTrait::NOutputs<2>::Impl> { 42 /// 43 template <unsigned N> class NOutputs { 44 public: 45 template <typename ConcreteType> 46 class Impl : public OpTrait::TraitBase<ConcreteType, NOutputs<N>::Impl> { 47 public: getNumOutputs()48 static unsigned getNumOutputs() { return N; } 49 }; 50 }; 51 52 /// This class provides the API for structured ops that are known to operate on 53 /// buffers or tensors. This trait must be used in conjunction with an op 54 /// definition or a trait that provides the methods `getNumInputs` and 55 /// `getNumOutputs`. Use as a trait as follows: 56 /// 57 /// class DotOp : public Op<DotOp, OpTrait::StructuredOpTraits> { 58 /// 59 template <typename ConcreteType> 60 class StructuredOpTraits 61 : public OpTrait::TraitBase<ConcreteType, StructuredOpTraits> { 62 private: 63 /// Return the number of inputs, irrespective of their buffer or tensor type. 64 /// For internal use only. nInputs()65 unsigned nInputs() { 66 return cast<ConcreteType>(this->getOperation()).getNumInputs(); 67 } 68 /// Return the number of outputs, irrespective of their buffer or tensor type. 69 /// For internal use only. nOutputs()70 unsigned nOutputs() { 71 return cast<ConcreteType>(this->getOperation()).getNumOutputs(); 72 } 73 74 public: 75 //==========================================================================// 76 // Loop types handling. 77 //==========================================================================// getNumParallelLoops()78 unsigned getNumParallelLoops() { 79 return getNumIterators( 80 getParallelIteratorTypeName(), 81 cast<ConcreteType>(this->getOperation()).iterator_types()); 82 } getNumReductionLoops()83 unsigned getNumReductionLoops() { 84 return getNumIterators( 85 getReductionIteratorTypeName(), 86 cast<ConcreteType>(this->getOperation()).iterator_types()); 87 } getNumWindowLoops()88 unsigned getNumWindowLoops() { 89 return getNumIterators( 90 getWindowIteratorTypeName(), 91 cast<ConcreteType>(this->getOperation()).iterator_types()); 92 } getNumLoops()93 unsigned getNumLoops() { 94 return getNumIterators( 95 cast<ConcreteType>(this->getOperation()).iterator_types()); 96 } 97 hasSingleReductionLoop()98 bool hasSingleReductionLoop() { 99 auto iterators = cast<ConcreteType>(this->getOperation()).iterator_types(); 100 return iterators.size() == 1 && 101 getNumIterators(getReductionIteratorTypeName(), iterators); 102 } 103 104 //==========================================================================// 105 // Input arguments handling. 106 //==========================================================================// 107 // The `i^th` input argument is always the `i^th` operand regardless of 108 // whether we have tensors or buffers. 109 // 110 /// Return the `i`-th input value. getInput(unsigned i)111 Value getInput(unsigned i) { 112 assert(i < nInputs()); 113 return this->getOperation()->getOperand(i); 114 } 115 /// Return the index of `value` in the list of inputs if found, llvm::None 116 /// otherwise. getIndexOfInput(Value value)117 Optional<unsigned> getIndexOfInput(Value value) { 118 auto it = llvm::find(getInputs(), value); 119 if (it != getInputs().end()) 120 return it - getInputs().begin(); 121 return llvm::None; 122 } 123 /// Return the `i`-th input shaped type, irrespective of buffer or tensor 124 /// type. getInputShapedType(unsigned i)125 ShapedType getInputShapedType(unsigned i) { 126 return getInput(i).getType().template cast<ShapedType>(); 127 } 128 /// Return the range over inputs. getInputs()129 Operation::operand_range getInputs() { 130 auto range = this->getOperation()->getOperands(); 131 return {range.begin(), range.begin() + nInputs()}; 132 } 133 /// Query the subset of input operands that are of ranked tensor type. getInputTensorTypes()134 SmallVector<RankedTensorType, 4> getInputTensorTypes() { 135 SmallVector<RankedTensorType, 4> res; 136 for (Type type : getInputs().getTypes()) 137 if (auto t = type.template dyn_cast<RankedTensorType>()) 138 res.push_back(t); 139 return res; 140 } 141 142 //==========================================================================// 143 // Output arguments handling. 144 //==========================================================================// 145 // The `i^th` output argument is an operand (resp. a return value) iff it is 146 // a value of buffer type (resp. a return value of tensor type). 147 148 /// Return the `i`-th output, asserts that this is a buffer operand and not 149 /// a tensor result. getOutputBuffer(unsigned i)150 Value getOutputBuffer(unsigned i) { 151 assert(i + this->getOperation()->getNumResults() < nOutputs() && 152 "overflowing output buffer index"); 153 return this->getOperation()->getOperand(nInputs() + i); 154 } 155 /// Return the index of `value` in the list of output buffers if found, 156 /// llvm::None otherwise. getIndexOfOutputBuffer(Value value)157 Optional<unsigned> getIndexOfOutputBuffer(Value value) { 158 auto it = llvm::find(getOutputBuffers(), value); 159 if (it != getOutputBuffers().end()) 160 return it - getOutputBuffers().begin(); 161 return llvm::None; 162 } 163 /// Return the `i`-th output buffer type. getOutputBufferType(unsigned i)164 MemRefType getOutputBufferType(unsigned i) { 165 return getOutputBuffer(i).getType().template cast<MemRefType>(); 166 } 167 /// Return the `i`-th output shaped type, irrespective of buffer of tensor 168 /// type. getOutputShapedType(unsigned i)169 ShapedType getOutputShapedType(unsigned i) { 170 return getShapedType(i + nInputs()); 171 } 172 /// Query the subset of results that are of ranked tensor type. getOutputTensorTypes()173 SmallVector<RankedTensorType, 4> getOutputTensorTypes() { 174 SmallVector<RankedTensorType, 4> res; 175 for (Type type : this->getOperation()->getResults().getTypes()) 176 res.push_back(type.template cast<RankedTensorType>()); 177 return res; 178 } 179 /// Return the range over outputs. getOutputBuffers()180 Operation::operand_range getOutputBuffers() { 181 auto range = this->getOperation()->getOperands(); 182 return {range.begin() + nInputs(), 183 range.begin() + getNumInputsAndOutputBuffers()}; 184 } 185 186 //==========================================================================// 187 // Input and Output arguments handling. 188 //==========================================================================// getBuffer(unsigned i)189 Value getBuffer(unsigned i) { 190 assert(i < getNumInputsAndOutputBuffers() && "overflowing buffers index"); 191 return this->getOperation()->getOperand(i); 192 } 193 /// Return the number of inputs and outputs, irrespective of their buffer or 194 /// tensor type. getNumInputsAndOutputs()195 unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); } 196 /// Return the number of inputs, irrespective of their buffer or tensor type, 197 /// and output buffers. getNumInputsAndOutputBuffers()198 unsigned getNumInputsAndOutputBuffers() { 199 assert(this->getOperation()->getNumResults() <= nOutputs()); 200 return nInputs() + nOutputs() - this->getOperation()->getNumResults(); 201 } 202 /// Return the range over inputs (irrespective of type) and output buffers. getInputsAndOutputBuffers()203 Operation::operand_range getInputsAndOutputBuffers() { 204 auto range = this->getOperation()->getOperands(); 205 return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()}; 206 } 207 /// Return the `i`-th shaped type, there are 3 cases: 208 /// 1. if `i < nInputs()` then return `getInputShapedType(i)`; otherwise 209 /// 2. if `i < getNumInputsAndOutputBuffers()` then return the 210 /// `getOutputBufferType(i - nInputs())`; otherwise 211 /// 3. return the `i - getNumInputsAndOutputBuffers()` result type. getShapedType(unsigned i)212 ShapedType getShapedType(unsigned i) { 213 if (i < nInputs()) 214 return getInputShapedType(i); 215 if (i < getNumInputsAndOutputBuffers()) 216 return getOutputBufferType(i - nInputs()).template cast<ShapedType>(); 217 return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()] 218 .template cast<ShapedType>(); 219 } 220 /// Return the shaped types for all the inputs and outputs getInputOutputShapedTypes()221 SmallVector<ShapedType, 4> getInputOutputShapedTypes() { 222 SmallVector<Type, 4> inputOutputTypes( 223 this->getOperation()->operand_type_begin(), 224 this->getOperation()->operand_type_end()); 225 inputOutputTypes.append(this->getOperation()->result_type_begin(), 226 this->getOperation()->result_type_end()); 227 return llvm::to_vector<4>( 228 llvm::map_range(inputOutputTypes, [](Type type) -> ShapedType { 229 return type.cast<ShapedType>(); 230 })); 231 } 232 233 //==========================================================================// 234 // Other interface methods. 235 //==========================================================================// 236 237 // Get or build the indexing_maps ArrayAttr. iterator_types()238 ArrayAttr iterator_types() { 239 // Return the attribute if it is present. 240 if (auto attr = this->getOperation()->getAttr("iterator_types")) 241 return attr.template cast<ArrayAttr>(); 242 243 // If not, form the attribute using the reference iterator types for the 244 // ConcreteType. 245 auto maybeReferenceIteratorTypes = 246 cast<ConcreteType>(this->getOperation()).referenceIterators(); 247 248 // If there is no reference, this must be a generic op. 249 // TODO: Traits are used to define ops. Split into cpp to avoid cyclic 250 // dependency. 251 auto name = this->getOperation()->getName().getStringRef(); 252 if (!maybeReferenceIteratorTypes && name != "generic" && 253 name != "indexed_generic") { 254 this->getOperation()->dump(); 255 llvm_unreachable("Op missing referenceIterators"); 256 } 257 258 // If we have a reference, build the reference attribute and set it in the 259 // op before returning. 260 auto *ctx = this->getOperation()->getContext(); 261 auto attrRange = llvm::map_range(*maybeReferenceIteratorTypes, 262 [ctx](StringRef str) -> Attribute { 263 return StringAttr::get(str, ctx); 264 }); 265 auto attr = ArrayAttr::get(llvm::to_vector<4>(attrRange), ctx); 266 // TODO: Need to memoize this. Can't just store as an attribute atm as it 267 // will impact parser, printer and tests. 268 // this->getOperation()->setAttr("iterator_types", attr); 269 return attr; 270 } 271 272 // Get or build the indexing_maps ArrayAttr. indexing_maps()273 ArrayAttr indexing_maps() { 274 // Return the attribute if it is present. 275 if (auto attr = this->getOperation()->getAttr("indexing_maps")) 276 return attr.template cast<ArrayAttr>(); 277 278 // If not, form the attribute using the reference indexing map for the 279 // ConcreteType. 280 auto maybeReferenceIndexingMaps = 281 cast<ConcreteType>(this->getOperation()).referenceIndexingMaps(); 282 283 // If there is no reference, this must be a generic op. 284 auto name = this->getOperation()->getName().getStringRef(); 285 if (!maybeReferenceIndexingMaps && name != "generic" && 286 name != "indexed_generic") { 287 this->getOperation()->dump(); 288 llvm_unreachable("Op missing referenceIndexingMaps"); 289 } 290 291 // If we have a reference, build the reference attribute and set it in the 292 // op before returning. 293 auto *ctx = this->getOperation()->getContext(); 294 auto attrRange = 295 llvm::map_range(*maybeReferenceIndexingMaps, [ctx](AffineMap map) { 296 // 0-D corner case because there is no such thing as a concrete empty 297 // map type. 298 if (!map) 299 map = AffineMap::get(0, 0, getAffineConstantExpr(0, ctx)); 300 return AffineMapAttr::get(map); 301 }); 302 SmallVector<Attribute, 4> attrs{attrRange.begin(), attrRange.end()}; 303 auto attr = ArrayAttr::get(attrs, ctx); 304 // TODO: Need to memoize this. Can't just store as an attribute atm as it 305 // will impact parser, printer and tests. 306 // this->getOperation()->setAttr("indexing_maps", attr); 307 return attr; 308 } 309 getIndexingMaps()310 SmallVector<AffineMap, 4> getIndexingMaps() { 311 return llvm::to_vector<4>( 312 llvm::map_range(indexing_maps(), [](Attribute attr) -> AffineMap { 313 return attr.cast<AffineMapAttr>().getValue(); 314 })); 315 } 316 getIndexingMap(unsigned i)317 AffineMap getIndexingMap(unsigned i) { 318 assert(i < getNumInputsAndOutputs()); 319 return indexing_maps() 320 .getValue()[i] 321 .template cast<AffineMapAttr>() 322 .getValue(); 323 } 324 getInputIndexingMap(unsigned i)325 AffineMap getInputIndexingMap(unsigned i) { 326 assert(i < nInputs()); 327 return indexing_maps() 328 .getValue()[i] 329 .template cast<AffineMapAttr>() 330 .getValue(); 331 } 332 getOutputIndexingMap(unsigned i)333 AffineMap getOutputIndexingMap(unsigned i) { 334 assert(i < nOutputs()); 335 return indexing_maps() 336 .getValue()[i + nInputs()] 337 .template cast<AffineMapAttr>() 338 .getValue(); 339 } 340 341 /// Query whether the op has only buffer inputs and no returns. hasBufferSemantics()342 bool hasBufferSemantics() { 343 return this->getOperation()->getNumResults() == 0 && 344 llvm::all_of(getInputs(), 345 [](Value v) { return v.getType().isa<MemRefType>(); }); 346 } 347 348 /// Query whether the op has only tensor inputs and outputs. hasTensorSemantics()349 bool hasTensorSemantics() { 350 auto isTensorType = [](Value v) { 351 return v.getType().isa<RankedTensorType>(); 352 }; 353 return llvm::all_of(getInputs(), isTensorType) && 354 llvm::all_of(this->getOperation()->getResults(), isTensorType); 355 } 356 357 //==========================================================================// 358 // Other static interface methods. 359 //==========================================================================// verifyTrait(Operation * op)360 static LogicalResult verifyTrait(Operation *op) { 361 auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputBuffers(); 362 if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands))) 363 return failure(); 364 return success(); 365 } 366 }; 367 368 /// This class provides the API for named Linalg StructuredOps. 369 template <typename ConcreteType> 370 class NamedStructuredOpTraits 371 : public OpTrait::TraitBase<ConcreteType, NamedStructuredOpTraits> { 372 public: 373 static SmallVector<StringRef, 8> referenceIterators(TypeRange inputTypes, 374 TypeRange outputTypes); 375 376 static SmallVector<AffineMap, 8> referenceIndexingMaps(TypeRange inputTypes, 377 TypeRange outputTypes); 378 }; 379 380 } // namespace linalg 381 } // namespace OpTrait 382 } // namespace mlir 383 384 #endif // MLIR_DIALECT_LINALG_LINALGTRAITS_H_ 385