1 //===- Operator.h - Operator class ------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Operator wrapper to simplify using TableGen Record defining a MLIR Op.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TABLEGEN_OPERATOR_H_
14 #define MLIR_TABLEGEN_OPERATOR_H_
15 
16 #include "mlir/Support/LLVM.h"
17 #include "mlir/TableGen/Argument.h"
18 #include "mlir/TableGen/Attribute.h"
19 #include "mlir/TableGen/Dialect.h"
20 #include "mlir/TableGen/OpTrait.h"
21 #include "mlir/TableGen/Region.h"
22 #include "mlir/TableGen/Successor.h"
23 #include "mlir/TableGen/Type.h"
24 #include "llvm/ADT/PointerUnion.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringMap.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/Support/SMLoc.h"
29 
30 namespace llvm {
31 class CodeInit;
32 class DefInit;
33 class Record;
34 class StringInit;
35 } // end namespace llvm
36 
37 namespace mlir {
38 namespace tblgen {
39 
40 // Wrapper class that contains a MLIR op's information (e.g., operands,
41 // attributes) defined in TableGen and provides helper methods for
42 // accessing them.
43 class Operator {
44 public:
45   explicit Operator(const llvm::Record &def);
Operator(const llvm::Record * def)46   explicit Operator(const llvm::Record *def) : Operator(*def) {}
47 
48   // Returns this op's dialect name.
49   StringRef getDialectName() const;
50 
51   // Returns the operation name. The name will follow the "<dialect>.<op-name>"
52   // format if its dialect name is not empty.
53   std::string getOperationName() const;
54 
55   // Returns this op's C++ class name.
56   StringRef getCppClassName() const;
57 
58   // Returns this op's C++ class name prefixed with namespaces.
59   std::string getQualCppClassName() const;
60 
61   // Returns the name of op's adaptor C++ class.
62   std::string getAdaptorName() const;
63 
64   /// A class used to represent the decorators of an operator variable, i.e.
65   /// argument or result.
66   struct VariableDecorator {
67   public:
VariableDecoratorVariableDecorator68     explicit VariableDecorator(const llvm::Record *def) : def(def) {}
getDefVariableDecorator69     const llvm::Record &getDef() const { return *def; }
70 
71   protected:
72     // The TableGen definition of this decorator.
73     const llvm::Record *def;
74   };
75 
76   // A utility iterator over a list of variable decorators.
77   struct VariableDecoratorIterator
78       : public llvm::mapped_iterator<llvm::Init *const *,
79                                      VariableDecorator (*)(llvm::Init *)> {
80     using reference = VariableDecorator;
81 
82     /// Initializes the iterator to the specified iterator.
VariableDecoratorIteratorVariableDecoratorIterator83     VariableDecoratorIterator(llvm::Init *const *it)
84         : llvm::mapped_iterator<llvm::Init *const *,
85                                 VariableDecorator (*)(llvm::Init *)>(it,
86                                                                      &unwrap) {}
87     static VariableDecorator unwrap(llvm::Init *init);
88   };
89   using var_decorator_iterator = VariableDecoratorIterator;
90   using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>;
91 
92   using value_iterator = NamedTypeConstraint *;
93   using value_range = llvm::iterator_range<value_iterator>;
94 
95   // Returns true if this op has variable length operands or results.
96   bool isVariadic() const;
97 
98   // Returns true if default builders should not be generated.
99   bool skipDefaultBuilders() const;
100 
101   // Op result iterators.
102   value_iterator result_begin();
103   value_iterator result_end();
104   value_range getResults();
105 
106   // Returns the number of results this op produces.
107   int getNumResults() const;
108 
109   // Returns the op result at the given `index`.
getResult(int index)110   NamedTypeConstraint &getResult(int index) { return results[index]; }
getResult(int index)111   const NamedTypeConstraint &getResult(int index) const {
112     return results[index];
113   }
114 
115   // Returns the `index`-th result's type constraint.
116   TypeConstraint getResultTypeConstraint(int index) const;
117   // Returns the `index`-th result's name.
118   StringRef getResultName(int index) const;
119   // Returns the `index`-th result's decorators.
120   var_decorator_range getResultDecorators(int index) const;
121 
122   // Returns the number of variable length results in this operation.
123   unsigned getNumVariableLengthResults() const;
124 
125   // Op attribute iterators.
126   using attribute_iterator = const NamedAttribute *;
127   attribute_iterator attribute_begin() const;
128   attribute_iterator attribute_end() const;
129   llvm::iterator_range<attribute_iterator> getAttributes() const;
130 
getNumAttributes()131   int getNumAttributes() const { return attributes.size(); }
getNumNativeAttributes()132   int getNumNativeAttributes() const { return numNativeAttributes; }
133 
134   // Op attribute accessors.
getAttribute(int index)135   NamedAttribute &getAttribute(int index) { return attributes[index]; }
136 
137   // Op operand iterators.
138   value_iterator operand_begin();
139   value_iterator operand_end();
140   value_range getOperands();
141 
getNumOperands()142   int getNumOperands() const { return operands.size(); }
getOperand(int index)143   NamedTypeConstraint &getOperand(int index) { return operands[index]; }
getOperand(int index)144   const NamedTypeConstraint &getOperand(int index) const {
145     return operands[index];
146   }
147 
148   // Returns the number of variadic operands in this operation.
149   unsigned getNumVariableLengthOperands() const;
150 
151   // Returns the total number of arguments.
getNumArgs()152   int getNumArgs() const { return arguments.size(); }
153 
154   using arg_iterator = const Argument *;
155   using arg_range = llvm::iterator_range<arg_iterator>;
156 
157   // Op argument (attribute or operand) iterators.
158   arg_iterator arg_begin() const;
159   arg_iterator arg_end() const;
160   arg_range getArgs() const;
161 
162   // Op argument (attribute or operand) accessors.
163   Argument getArg(int index) const;
164   StringRef getArgName(int index) const;
165   var_decorator_range getArgDecorators(int index) const;
166 
167   // Returns the trait wrapper for the given MLIR C++ `trait`.
168   // TODO: We should add a C++ wrapper class for TableGen OpTrait instead of
169   // requiring the raw MLIR trait here.
170   const OpTrait *getTrait(llvm::StringRef trait) const;
171 
172   // Regions.
173   using const_region_iterator = const NamedRegion *;
174   const_region_iterator region_begin() const;
175   const_region_iterator region_end() const;
176   llvm::iterator_range<const_region_iterator> getRegions() const;
177 
178   // Returns the number of regions.
179   unsigned getNumRegions() const;
180   // Returns the `index`-th region.
181   const NamedRegion &getRegion(unsigned index) const;
182 
183   // Returns the number of variadic regions in this operation.
184   unsigned getNumVariadicRegions() const;
185 
186   // Successors.
187   using const_successor_iterator = const NamedSuccessor *;
188   const_successor_iterator successor_begin() const;
189   const_successor_iterator successor_end() const;
190   llvm::iterator_range<const_successor_iterator> getSuccessors() const;
191 
192   // Returns the number of successors.
193   unsigned getNumSuccessors() const;
194   // Returns the `index`-th successor.
195   const NamedSuccessor &getSuccessor(unsigned index) const;
196 
197   // Returns the number of variadic successors in this operation.
198   unsigned getNumVariadicSuccessors() const;
199 
200   // Trait.
201   using const_trait_iterator = const OpTrait *;
202   const_trait_iterator trait_begin() const;
203   const_trait_iterator trait_end() const;
204   llvm::iterator_range<const_trait_iterator> getTraits() const;
205 
206   ArrayRef<llvm::SMLoc> getLoc() const;
207 
208   // Query functions for the documentation of the operator.
209   bool hasDescription() const;
210   StringRef getDescription() const;
211   bool hasSummary() const;
212   StringRef getSummary() const;
213 
214   // Query functions for the assembly format of the operator.
215   bool hasAssemblyFormat() const;
216   StringRef getAssemblyFormat() const;
217 
218   // Returns this op's extra class declaration code.
219   StringRef getExtraClassDeclaration() const;
220 
221   // Returns the Tablegen definition this operator was constructed from.
222   // TODO: do not expose the TableGen record, this is a temporary solution to
223   // OpEmitter requiring a Record because Operator does not provide enough
224   // methods.
225   const llvm::Record &getDef() const;
226 
227   // Returns the dialect of the op.
getDialect()228   const Dialect &getDialect() const { return dialect; }
229 
230   // Prints the contents in this operator to the given `os`. This is used for
231   // debugging purposes.
232   void print(llvm::raw_ostream &os) const;
233 
234   // Return whether all the result types are known.
allResultTypesKnown()235   bool allResultTypesKnown() const { return allResultsHaveKnownTypes; };
236 
237   // Pair representing either a index to an argument or a type constraint. Only
238   // one of these entries should have the non-default value.
239   struct ArgOrType {
ArgOrTypeArgOrType240     explicit ArgOrType(int index) : index(index), constraint(None) {}
ArgOrTypeArgOrType241     explicit ArgOrType(TypeConstraint constraint)
242         : index(None), constraint(constraint) {}
isArgArgOrType243     bool isArg() const {
244       assert(constraint.hasValue() ^ index.hasValue());
245       return index.hasValue();
246     }
isTypeArgOrType247     bool isType() const {
248       assert(constraint.hasValue() ^ index.hasValue());
249       return constraint.hasValue();
250     }
251 
getArgArgOrType252     int getArg() const { return *index; }
getTypeArgOrType253     TypeConstraint getType() const { return *constraint; }
254 
255   private:
256     Optional<int> index;
257     Optional<TypeConstraint> constraint;
258   };
259 
260   // Return all arguments or type constraints with same type as result[index].
261   // Requires: all result types are known.
262   ArrayRef<ArgOrType> getSameTypeAsResult(int index) const;
263 
264   // Pair consisting kind of argument and index into operands or attributes.
265   struct OperandOrAttribute {
266     enum class Kind { Operand, Attribute };
OperandOrAttributeOperandOrAttribute267     OperandOrAttribute(Kind kind, int index) {
268       packed = (index << 1) & (kind == Kind::Attribute);
269     }
operandOrAttributeIndexOperandOrAttribute270     int operandOrAttributeIndex() const { return (packed >> 1); }
kindOperandOrAttribute271     Kind kind() { return (packed & 0x1) ? Kind::Attribute : Kind::Operand; }
272 
273   private:
274     int packed;
275   };
276 
277   // Returns the OperandOrAttribute corresponding to the index.
278   OperandOrAttribute getArgToOperandOrAttribute(int index) const;
279 
280 private:
281   // Populates the vectors containing operands, attributes, results and traits.
282   void populateOpStructure();
283 
284   // Populates type inference info (mostly equality) with input a mapping from
285   // names to indices for arguments and results.
286   void populateTypeInferenceInfo(
287       const llvm::StringMap<int> &argumentsAndResultsIndex);
288 
289   // The dialect of this op.
290   Dialect dialect;
291 
292   // The unqualified C++ class name of the op.
293   StringRef cppClassName;
294 
295   // The operands of the op.
296   SmallVector<NamedTypeConstraint, 4> operands;
297 
298   // The attributes of the op.  Contains native attributes (corresponding to the
299   // actual stored attributed of the operation) followed by derived attributes
300   // (corresponding to dynamic properties of the operation that are computed
301   // upon request).
302   SmallVector<NamedAttribute, 4> attributes;
303 
304   // The arguments of the op (operands and native attributes).
305   SmallVector<Argument, 4> arguments;
306 
307   // The results of the op.
308   SmallVector<NamedTypeConstraint, 4> results;
309 
310   // The successors of this op.
311   SmallVector<NamedSuccessor, 0> successors;
312 
313   // The traits of the op.
314   SmallVector<OpTrait, 4> traits;
315 
316   // The regions of this op.
317   SmallVector<NamedRegion, 1> regions;
318 
319   // The argument with the same type as the result.
320   SmallVector<SmallVector<ArgOrType, 2>, 4> resultTypeMapping;
321 
322   // Map from argument to attribute or operand number.
323   SmallVector<OperandOrAttribute, 4> attrOrOperandMapping;
324 
325   // The number of native attributes stored in the leading positions of
326   // `attributes`.
327   int numNativeAttributes;
328 
329   // The TableGen definition of this op.
330   const llvm::Record &def;
331 
332   // Whether the type of all results are known.
333   bool allResultsHaveKnownTypes;
334 };
335 
336 } // end namespace tblgen
337 } // end namespace mlir
338 
339 #endif // MLIR_TABLEGEN_OPERATOR_H_
340