1 //===- LinalgTraits.h - Linalg Traits ---------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_LINALG_LINALGTRAITS_H_
10 #define MLIR_DIALECT_LINALG_LINALGTRAITS_H_
11 
12 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
13 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Function.h"
16 #include "mlir/IR/OpDefinition.h"
17 #include "mlir/IR/StandardTypes.h"
18 #include "mlir/Support/LLVM.h"
19 
20 namespace mlir {
21 namespace OpTrait {
22 namespace linalg {
23 
24 /// This class provides the API for ops that are known to have a specified
25 /// number of inputs, all passed as operands. Use as a trait as follows:
26 ///
27 ///   class DotOp : public Op<DotOp, OpTrait::NInputs<2>::Impl> {
28 ///
29 template <unsigned N> class NInputs {
30 public:
31   template <typename ConcreteType>
32   class Impl : public OpTrait::TraitBase<ConcreteType, NInputs<N>::Impl> {
33   public:
getNumInputs()34     static unsigned getNumInputs() { return N; }
35   };
36 };
37 
38 /// This class provides the API for ops that are known to have a specified
39 /// number of outputs, all passed as operands. Use as a trait as follows:
40 ///
41 ///   class DotOp : public Op<DotOp, OpTrait::NOutputs<2>::Impl> {
42 ///
43 template <unsigned N> class NOutputs {
44 public:
45   template <typename ConcreteType>
46   class Impl : public OpTrait::TraitBase<ConcreteType, NOutputs<N>::Impl> {
47   public:
getNumOutputs()48     static unsigned getNumOutputs() { return N; }
49   };
50 };
51 
52 /// This class provides the API for structured ops that are known to operate on
53 /// buffers or tensors. This trait must be used in conjunction with an op
54 /// definition or a trait that provides the methods `getNumInputs` and
55 /// `getNumOutputs`. Use as a trait as follows:
56 ///
57 ///   class DotOp : public Op<DotOp, OpTrait::StructuredOpTraits> {
58 ///
59 template <typename ConcreteType>
60 class StructuredOpTraits
61     : public OpTrait::TraitBase<ConcreteType, StructuredOpTraits> {
62 private:
63   /// Return the number of inputs, irrespective of their buffer or tensor type.
64   /// For internal use only.
nInputs()65   unsigned nInputs() {
66     return cast<ConcreteType>(this->getOperation()).getNumInputs();
67   }
68   /// Return the number of outputs, irrespective of their buffer or tensor type.
69   /// For internal use only.
nOutputs()70   unsigned nOutputs() {
71     return cast<ConcreteType>(this->getOperation()).getNumOutputs();
72   }
73 
74 public:
75   //==========================================================================//
76   // Loop types handling.
77   //==========================================================================//
getNumParallelLoops()78   unsigned getNumParallelLoops() {
79     return getNumIterators(
80         getParallelIteratorTypeName(),
81         cast<ConcreteType>(this->getOperation()).iterator_types());
82   }
getNumReductionLoops()83   unsigned getNumReductionLoops() {
84     return getNumIterators(
85         getReductionIteratorTypeName(),
86         cast<ConcreteType>(this->getOperation()).iterator_types());
87   }
getNumWindowLoops()88   unsigned getNumWindowLoops() {
89     return getNumIterators(
90         getWindowIteratorTypeName(),
91         cast<ConcreteType>(this->getOperation()).iterator_types());
92   }
getNumLoops()93   unsigned getNumLoops() {
94     return getNumIterators(
95         cast<ConcreteType>(this->getOperation()).iterator_types());
96   }
97 
hasSingleReductionLoop()98   bool hasSingleReductionLoop() {
99     auto iterators = cast<ConcreteType>(this->getOperation()).iterator_types();
100     return iterators.size() == 1 &&
101            getNumIterators(getReductionIteratorTypeName(), iterators);
102   }
103 
104   //==========================================================================//
105   // Input arguments handling.
106   //==========================================================================//
107   // The `i^th` input argument is always the `i^th` operand regardless of
108   // whether we have tensors or buffers.
109   //
110   /// Return the `i`-th input value.
getInput(unsigned i)111   Value getInput(unsigned i) {
112     assert(i < nInputs());
113     return this->getOperation()->getOperand(i);
114   }
115   /// Return the index of `value` in the list of inputs if found, llvm::None
116   /// otherwise.
getIndexOfInput(Value value)117   Optional<unsigned> getIndexOfInput(Value value) {
118     auto it = llvm::find(getInputs(), value);
119     if (it != getInputs().end())
120       return it - getInputs().begin();
121     return llvm::None;
122   }
123   /// Return the `i`-th input shaped type, irrespective of buffer or tensor
124   /// type.
getInputShapedType(unsigned i)125   ShapedType getInputShapedType(unsigned i) {
126     return getInput(i).getType().template cast<ShapedType>();
127   }
128   /// Return the range over inputs.
getInputs()129   Operation::operand_range getInputs() {
130     auto range = this->getOperation()->getOperands();
131     return {range.begin(), range.begin() + nInputs()};
132   }
133   /// Query the subset of input operands that are of ranked tensor type.
getInputTensorTypes()134   SmallVector<RankedTensorType, 4> getInputTensorTypes() {
135     SmallVector<RankedTensorType, 4> res;
136     for (Type type : getInputs().getTypes())
137       if (auto t = type.template dyn_cast<RankedTensorType>())
138         res.push_back(t);
139     return res;
140   }
141 
142   //==========================================================================//
143   // Output arguments handling.
144   //==========================================================================//
145   // The `i^th` output argument is an operand (resp. a return value) iff it is
146   // a value of buffer type (resp. a return value of tensor type).
147 
148   /// Return the `i`-th output, asserts that this is a buffer operand and not
149   /// a tensor result.
getOutputBuffer(unsigned i)150   Value getOutputBuffer(unsigned i) {
151     assert(i + this->getOperation()->getNumResults() < nOutputs() &&
152            "overflowing output buffer index");
153     return this->getOperation()->getOperand(nInputs() + i);
154   }
155   /// Return the index of `value` in the list of output buffers if found,
156   /// llvm::None otherwise.
getIndexOfOutputBuffer(Value value)157   Optional<unsigned> getIndexOfOutputBuffer(Value value) {
158     auto it = llvm::find(getOutputBuffers(), value);
159     if (it != getOutputBuffers().end())
160       return it - getOutputBuffers().begin();
161     return llvm::None;
162   }
163   /// Return the `i`-th output buffer type.
getOutputBufferType(unsigned i)164   MemRefType getOutputBufferType(unsigned i) {
165     return getOutputBuffer(i).getType().template cast<MemRefType>();
166   }
167   /// Return the `i`-th output shaped type, irrespective of buffer of tensor
168   /// type.
getOutputShapedType(unsigned i)169   ShapedType getOutputShapedType(unsigned i) {
170     return getShapedType(i + nInputs());
171   }
172   /// Query the subset of results that are of ranked tensor type.
getOutputTensorTypes()173   SmallVector<RankedTensorType, 4> getOutputTensorTypes() {
174     SmallVector<RankedTensorType, 4> res;
175     for (Type type : this->getOperation()->getResults().getTypes())
176       res.push_back(type.template cast<RankedTensorType>());
177     return res;
178   }
179   /// Return the range over outputs.
getOutputBuffers()180   Operation::operand_range getOutputBuffers() {
181     auto range = this->getOperation()->getOperands();
182     return {range.begin() + nInputs(),
183             range.begin() + getNumInputsAndOutputBuffers()};
184   }
185 
186   //==========================================================================//
187   // Input and Output arguments handling.
188   //==========================================================================//
getBuffer(unsigned i)189   Value getBuffer(unsigned i) {
190     assert(i < getNumInputsAndOutputBuffers() && "overflowing buffers index");
191     return this->getOperation()->getOperand(i);
192   }
193   /// Return the number of inputs and outputs, irrespective of their buffer or
194   /// tensor type.
getNumInputsAndOutputs()195   unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }
196   /// Return the number of inputs, irrespective of their buffer or tensor type,
197   /// and output buffers.
getNumInputsAndOutputBuffers()198   unsigned getNumInputsAndOutputBuffers() {
199     assert(this->getOperation()->getNumResults() <= nOutputs());
200     return nInputs() + nOutputs() - this->getOperation()->getNumResults();
201   }
202   /// Return the range over inputs (irrespective of type) and output buffers.
getInputsAndOutputBuffers()203   Operation::operand_range getInputsAndOutputBuffers() {
204     auto range = this->getOperation()->getOperands();
205     return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()};
206   }
207   /// Return the `i`-th shaped type, there are 3 cases:
208   ///   1. if `i < nInputs()` then return `getInputShapedType(i)`; otherwise
209   ///   2. if `i < getNumInputsAndOutputBuffers()` then return the
210   ///      `getOutputBufferType(i - nInputs())`; otherwise
211   ///   3. return the `i - getNumInputsAndOutputBuffers()` result type.
getShapedType(unsigned i)212   ShapedType getShapedType(unsigned i) {
213     if (i < nInputs())
214       return getInputShapedType(i);
215     if (i < getNumInputsAndOutputBuffers())
216       return getOutputBufferType(i - nInputs()).template cast<ShapedType>();
217     return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()]
218         .template cast<ShapedType>();
219   }
220   /// Return the shaped types for all the inputs and outputs
getInputOutputShapedTypes()221   SmallVector<ShapedType, 4> getInputOutputShapedTypes() {
222     SmallVector<Type, 4> inputOutputTypes(
223         this->getOperation()->operand_type_begin(),
224         this->getOperation()->operand_type_end());
225     inputOutputTypes.append(this->getOperation()->result_type_begin(),
226                             this->getOperation()->result_type_end());
227     return llvm::to_vector<4>(
228         llvm::map_range(inputOutputTypes, [](Type type) -> ShapedType {
229           return type.cast<ShapedType>();
230         }));
231   }
232 
233   //==========================================================================//
234   // Other interface methods.
235   //==========================================================================//
236 
237   // Get or build the indexing_maps ArrayAttr.
iterator_types()238   ArrayAttr iterator_types() {
239     // Return the attribute if it is present.
240     if (auto attr = this->getOperation()->getAttr("iterator_types"))
241       return attr.template cast<ArrayAttr>();
242 
243     // If not, form the attribute using the reference iterator types for the
244     // ConcreteType.
245     auto maybeReferenceIteratorTypes =
246         cast<ConcreteType>(this->getOperation()).referenceIterators();
247 
248     // If there is no reference, this must be a generic op.
249     // TODO: Traits are used to define ops. Split into cpp to avoid cyclic
250     // dependency.
251     auto name = this->getOperation()->getName().getStringRef();
252     if (!maybeReferenceIteratorTypes && name != "generic" &&
253         name != "indexed_generic") {
254       this->getOperation()->dump();
255       llvm_unreachable("Op missing referenceIterators");
256     }
257 
258     // If we have a reference, build the reference attribute and set it in the
259     // op before returning.
260     auto *ctx = this->getOperation()->getContext();
261     auto attrRange = llvm::map_range(*maybeReferenceIteratorTypes,
262                                      [ctx](StringRef str) -> Attribute {
263                                        return StringAttr::get(str, ctx);
264                                      });
265     auto attr = ArrayAttr::get(llvm::to_vector<4>(attrRange), ctx);
266     // TODO: Need to memoize this. Can't just store as an attribute atm as it
267     // will impact parser, printer and tests.
268     // this->getOperation()->setAttr("iterator_types", attr);
269     return attr;
270   }
271 
272   // Get or build the indexing_maps ArrayAttr.
indexing_maps()273   ArrayAttr indexing_maps() {
274     // Return the attribute if it is present.
275     if (auto attr = this->getOperation()->getAttr("indexing_maps"))
276       return attr.template cast<ArrayAttr>();
277 
278     // If not, form the attribute using the reference indexing map for the
279     // ConcreteType.
280     auto maybeReferenceIndexingMaps =
281         cast<ConcreteType>(this->getOperation()).referenceIndexingMaps();
282 
283     // If there is no reference, this must be a generic op.
284     auto name = this->getOperation()->getName().getStringRef();
285     if (!maybeReferenceIndexingMaps && name != "generic" &&
286         name != "indexed_generic") {
287       this->getOperation()->dump();
288       llvm_unreachable("Op missing referenceIndexingMaps");
289     }
290 
291     // If we have a reference, build the reference attribute and set it in the
292     // op before returning.
293     auto *ctx = this->getOperation()->getContext();
294     auto attrRange =
295         llvm::map_range(*maybeReferenceIndexingMaps, [ctx](AffineMap map) {
296           // 0-D corner case because there is no such thing as a concrete empty
297           // map type.
298           if (!map)
299             map = AffineMap::get(0, 0, getAffineConstantExpr(0, ctx));
300           return AffineMapAttr::get(map);
301         });
302     SmallVector<Attribute, 4> attrs{attrRange.begin(), attrRange.end()};
303     auto attr = ArrayAttr::get(attrs, ctx);
304     // TODO: Need to memoize this. Can't just store as an attribute atm as it
305     // will impact parser, printer and tests.
306     // this->getOperation()->setAttr("indexing_maps", attr);
307     return attr;
308   }
309 
getIndexingMaps()310   SmallVector<AffineMap, 4> getIndexingMaps() {
311     return llvm::to_vector<4>(
312         llvm::map_range(indexing_maps(), [](Attribute attr) -> AffineMap {
313           return attr.cast<AffineMapAttr>().getValue();
314         }));
315   }
316 
getIndexingMap(unsigned i)317   AffineMap getIndexingMap(unsigned i) {
318     assert(i < getNumInputsAndOutputs());
319     return indexing_maps()
320         .getValue()[i]
321         .template cast<AffineMapAttr>()
322         .getValue();
323   }
324 
getInputIndexingMap(unsigned i)325   AffineMap getInputIndexingMap(unsigned i) {
326     assert(i < nInputs());
327     return indexing_maps()
328         .getValue()[i]
329         .template cast<AffineMapAttr>()
330         .getValue();
331   }
332 
getOutputIndexingMap(unsigned i)333   AffineMap getOutputIndexingMap(unsigned i) {
334     assert(i < nOutputs());
335     return indexing_maps()
336         .getValue()[i + nInputs()]
337         .template cast<AffineMapAttr>()
338         .getValue();
339   }
340 
341   /// Query whether the op has only buffer inputs and no returns.
hasBufferSemantics()342   bool hasBufferSemantics() {
343     return this->getOperation()->getNumResults() == 0 &&
344            llvm::all_of(getInputs(),
345                         [](Value v) { return v.getType().isa<MemRefType>(); });
346   }
347 
348   /// Query whether the op has only tensor inputs and outputs.
hasTensorSemantics()349   bool hasTensorSemantics() {
350     auto isTensorType = [](Value v) {
351       return v.getType().isa<RankedTensorType>();
352     };
353     return llvm::all_of(getInputs(), isTensorType) &&
354            llvm::all_of(this->getOperation()->getResults(), isTensorType);
355   }
356 
357   //==========================================================================//
358   // Other static interface methods.
359   //==========================================================================//
verifyTrait(Operation * op)360   static LogicalResult verifyTrait(Operation *op) {
361     auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputBuffers();
362     if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands)))
363       return failure();
364     return success();
365   }
366 };
367 
/// This class provides the API for named Linalg StructuredOps.
///
/// Unlike the argument-less `referenceIterators()` / `referenceIndexingMaps()`
/// hooks that StructuredOpTraits queries on the concrete op, the hooks declared
/// here take the op's input and output types as parameters. Only declarations
/// appear in this header. NOTE(review): the definitions are presumably
/// provided per named op elsewhere (e.g. generated code); confirm before
/// changing these signatures.
template <typename ConcreteType>
class NamedStructuredOpTraits
    : public OpTrait::TraitBase<ConcreteType, NamedStructuredOpTraits> {
public:
  /// Return the reference iterator types (as strings) for an op with the given
  /// `inputTypes` and `outputTypes`.
  static SmallVector<StringRef, 8> referenceIterators(TypeRange inputTypes,
                                                      TypeRange outputTypes);

  /// Return the reference indexing maps for an op with the given `inputTypes`
  /// and `outputTypes`. Presumably there is one map per input/output; confirm
  /// against the definitions.
  static SmallVector<AffineMap, 8> referenceIndexingMaps(TypeRange inputTypes,
                                                         TypeRange outputTypes);
};
379 
380 } // namespace linalg
381 } // namespace OpTrait
382 } // namespace mlir
383 
384 #endif // MLIR_DIALECT_LINALG_LINALGTRAITS_H_
385