1 //===- OpClass.h - Helper classes for Op C++ code emission ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines several classes for Op C++ code emission. They are only
10 // expected to be used by MLIR TableGen backends.
11 //
12 // We emit the op declaration and definition into separate files: *Ops.h.inc
13 // and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and
14 // the latter for dialect *Ops.cpp. This way provides a cleaner interface.
15 //
16 // In order to do this split, we need to track method signature and
17 // implementation logic separately. Signature information is used for both
18 // declaration and definition, while implementation logic is only for
19 // definition. So we have the following classes for C++ code emission.
20 //
21 //===----------------------------------------------------------------------===//
22 
23 #ifndef MLIR_TABLEGEN_OPCLASS_H_
24 #define MLIR_TABLEGEN_OPCLASS_H_
25 
26 #include "mlir/Support/LLVM.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/ADT/StringSet.h"
31 #include "llvm/Support/raw_ostream.h"
32 
33 #include <set>
34 #include <string>
35 
36 namespace mlir {
37 namespace tblgen {
38 class FmtObjectBase;
39 
40 // Class for holding a single parameter of an op's method for C++ code emission.
41 class OpMethodParameter {
42 public:
43   // Properties (qualifiers) for the parameter.
44   enum Property {
45     PP_None = 0x0,
46     PP_Optional = 0x1,
47   };
48 
49   OpMethodParameter(StringRef type, StringRef name, StringRef defaultValue = "",
50                     Property properties = PP_None)
type(type)51       : type(type), name(name), defaultValue(defaultValue),
52         properties(properties) {}
53 
OpMethodParameter(StringRef type,StringRef name,Property property)54   OpMethodParameter(StringRef type, StringRef name, Property property)
55       : OpMethodParameter(type, name, "", property) {}
56 
57   // Writes the parameter as a part of a method declaration to `os`.
writeDeclTo(raw_ostream & os)58   void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); }
59 
60   // Writes the parameter as a part of a method definition to `os`
writeDefTo(raw_ostream & os)61   void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); }
62 
getType()63   const std::string &getType() const { return type; }
hasDefaultValue()64   bool hasDefaultValue() const { return !defaultValue.empty(); }
65 
66 private:
67   void writeTo(raw_ostream &os, bool emitDefault) const;
68 
69   std::string type;
70   std::string name;
71   std::string defaultValue;
72   Property properties;
73 };
74 
75 // Base class for holding parameters of an op's method for C++ code emission.
76 class OpMethodParameters {
77 public:
78   // Discriminator for LLVM-style RTTI.
79   enum ParamsKind {
80     // Separate type and name for each parameter is not known.
81     PK_Unresolved,
82     // Each parameter is resolved to a type and name.
83     PK_Resolved,
84   };
85 
OpMethodParameters(ParamsKind kind)86   OpMethodParameters(ParamsKind kind) : kind(kind) {}
~OpMethodParameters()87   virtual ~OpMethodParameters() {}
88 
89   // LLVM-style RTTI support.
getKind()90   ParamsKind getKind() const { return kind; }
91 
92   // Writes the parameters as a part of a method declaration to `os`.
93   virtual void writeDeclTo(raw_ostream &os) const = 0;
94 
95   // Writes the parameters as a part of a method definition to `os`
96   virtual void writeDefTo(raw_ostream &os) const = 0;
97 
98   // Factory methods to create the correct type of `OpMethodParameters`
99   // object based on the arguments.
100   static std::unique_ptr<OpMethodParameters> create();
101 
102   static std::unique_ptr<OpMethodParameters> create(StringRef params);
103 
104   static std::unique_ptr<OpMethodParameters>
105   create(llvm::SmallVectorImpl<OpMethodParameter> &&params);
106 
107   static std::unique_ptr<OpMethodParameters>
108   create(StringRef type, StringRef name, StringRef defaultValue = "");
109 
110 private:
111   const ParamsKind kind;
112 };
113 
114 // Class for holding unresolved parameters.
115 class OpMethodUnresolvedParameters : public OpMethodParameters {
116 public:
OpMethodUnresolvedParameters(StringRef params)117   OpMethodUnresolvedParameters(StringRef params)
118       : OpMethodParameters(PK_Unresolved), parameters(params) {}
119 
120   // write the parameters as a part of a method declaration to the given `os`.
121   void writeDeclTo(raw_ostream &os) const override;
122 
123   // write the parameters as a part of a method definition to the given `os`
124   void writeDefTo(raw_ostream &os) const override;
125 
126   // LLVM-style RTTI support.
classof(const OpMethodParameters * params)127   static bool classof(const OpMethodParameters *params) {
128     return params->getKind() == PK_Unresolved;
129   }
130 
131 private:
132   std::string parameters;
133 };
134 
135 // Class for holding resolved parameters.
136 class OpMethodResolvedParameters : public OpMethodParameters {
137 public:
OpMethodResolvedParameters()138   OpMethodResolvedParameters() : OpMethodParameters(PK_Resolved) {}
139 
OpMethodResolvedParameters(llvm::SmallVectorImpl<OpMethodParameter> && params)140   OpMethodResolvedParameters(llvm::SmallVectorImpl<OpMethodParameter> &&params)
141       : OpMethodParameters(PK_Resolved) {
142     for (OpMethodParameter &param : params)
143       parameters.emplace_back(std::move(param));
144   }
145 
OpMethodResolvedParameters(StringRef type,StringRef name,StringRef defaultValue)146   OpMethodResolvedParameters(StringRef type, StringRef name,
147                              StringRef defaultValue)
148       : OpMethodParameters(PK_Resolved) {
149     parameters.emplace_back(type, name, defaultValue);
150   }
151 
152   // Returns the number of parameters.
getNumParameters()153   size_t getNumParameters() const { return parameters.size(); }
154 
155   // Returns if this method makes the `other` method redundant. Note that this
156   // is more than just finding conflicting methods. This method determines if
157   // the 2 set of parameters are conflicting and if so, returns true if this
158   // method has a more general set of parameters that can replace all possible
159   // calls to the `other` method.
160   bool makesRedundant(const OpMethodResolvedParameters &other) const;
161 
162   // write the parameters as a part of a method declaration to the given `os`.
163   void writeDeclTo(raw_ostream &os) const override;
164 
165   // write the parameters as a part of a method definition to the given `os`
166   void writeDefTo(raw_ostream &os) const override;
167 
168   // LLVM-style RTTI support.
classof(const OpMethodParameters * params)169   static bool classof(const OpMethodParameters *params) {
170     return params->getKind() == PK_Resolved;
171   }
172 
173 private:
174   llvm::SmallVector<OpMethodParameter, 4> parameters;
175 };
176 
177 // Class for holding the signature of an op's method for C++ code emission
178 class OpMethodSignature {
179 public:
180   template <typename... Args>
OpMethodSignature(StringRef retType,StringRef name,Args &&...args)181   OpMethodSignature(StringRef retType, StringRef name, Args &&...args)
182       : returnType(retType), methodName(name),
183         parameters(OpMethodParameters::create(std::forward<Args>(args)...)) {}
184   OpMethodSignature(OpMethodSignature &&) = default;
185 
186   // Returns if a method with this signature makes a method with `other`
187   // signature redundant. Only supports resolved parameters.
188   bool makesRedundant(const OpMethodSignature &other) const;
189 
190   // Returns the number of parameters (for resolved parameters).
getNumParameters()191   size_t getNumParameters() const {
192     return cast<OpMethodResolvedParameters>(parameters.get())
193         ->getNumParameters();
194   }
195 
196   // Returns the name of the method.
getName()197   StringRef getName() const { return methodName; }
198 
199   // Writes the signature as a method declaration to the given `os`.
200   void writeDeclTo(raw_ostream &os) const;
201 
202   // Writes the signature as the start of a method definition to the given `os`.
203   // `namePrefix` is the prefix to be prepended to the method name (typically
204   // namespaces for qualifying the method definition).
205   void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
206 
207 private:
208   std::string returnType;
209   std::string methodName;
210   std::unique_ptr<OpMethodParameters> parameters;
211 };
212 
213 // Class for holding the body of an op's method for C++ code emission
214 class OpMethodBody {
215 public:
216   explicit OpMethodBody(bool declOnly);
217 
218   OpMethodBody &operator<<(Twine content);
219   OpMethodBody &operator<<(int content);
220   OpMethodBody &operator<<(const FmtObjectBase &content);
221 
222   void writeTo(raw_ostream &os) const;
223 
224 private:
225   // Whether this class should record method body.
226   bool isEffective;
227   std::string body;
228 };
229 
230 // Class for holding an op's method for C++ code emission
231 class OpMethod {
232 public:
233   // Properties (qualifiers) of class methods. Bitfield is used here to help
234   // querying properties.
235   enum Property {
236     MP_None = 0x0,
237     MP_Static = 0x1,
238     MP_Constructor = 0x2,
239     MP_Private = 0x4,
240     MP_Declaration = 0x8,
241     MP_Inline = 0x10,
242     MP_Constexpr = 0x20 | MP_Inline,
243     MP_StaticDeclaration = MP_Static | MP_Declaration,
244   };
245 
246   template <typename... Args>
OpMethod(StringRef retType,StringRef name,Property property,unsigned id,Args &&...args)247   OpMethod(StringRef retType, StringRef name, Property property, unsigned id,
248            Args &&...args)
249       : properties(property),
250         methodSignature(retType, name, std::forward<Args>(args)...),
251         methodBody(properties & MP_Declaration), id(id) {}
252 
253   OpMethod(OpMethod &&) = default;
254 
255   virtual ~OpMethod() = default;
256 
body()257   OpMethodBody &body() { return methodBody; }
258 
259   // Returns true if this is a static method.
isStatic()260   bool isStatic() const { return properties & MP_Static; }
261 
262   // Returns true if this is a private method.
isPrivate()263   bool isPrivate() const { return properties & MP_Private; }
264 
265   // Returns true if this is an inline method.
isInline()266   bool isInline() const { return properties & MP_Inline; }
267 
268   // Returns the name of this method.
getName()269   StringRef getName() const { return methodSignature.getName(); }
270 
271   // Returns the ID for this method
getID()272   unsigned getID() const { return id; }
273 
274   // Returns if this method makes the `other` method redundant.
makesRedundant(const OpMethod & other)275   bool makesRedundant(const OpMethod &other) const {
276     return methodSignature.makesRedundant(other.methodSignature);
277   }
278 
279   // Writes the method as a declaration to the given `os`.
280   virtual void writeDeclTo(raw_ostream &os) const;
281 
282   // Writes the method as a definition to the given `os`. `namePrefix` is the
283   // prefix to be prepended to the method name (typically namespaces for
284   // qualifying the method definition).
285   virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
286 
287 protected:
288   Property properties;
289   OpMethodSignature methodSignature;
290   OpMethodBody methodBody;
291   const unsigned id;
292 };
293 
294 // Class for holding an op's constructor method for C++ code emission.
295 class OpConstructor : public OpMethod {
296 public:
297   template <typename... Args>
OpConstructor(StringRef className,Property property,unsigned id,Args &&...args)298   OpConstructor(StringRef className, Property property, unsigned id,
299                 Args &&...args)
300       : OpMethod("", className, property, id, std::forward<Args>(args)...) {}
301 
302   // Add member initializer to constructor initializing `name` with `value`.
303   void addMemberInitializer(StringRef name, StringRef value);
304 
305   // Writes the method as a definition to the given `os`. `namePrefix` is the
306   // prefix to be prepended to the method name (typically namespaces for
307   // qualifying the method definition).
308   void writeDefTo(raw_ostream &os, StringRef namePrefix) const override;
309 
310 private:
311   // Member initializers.
312   std::string memberInitializers;
313 };
314 
315 // A class used to emit C++ classes from Tablegen.  Contains a list of public
316 // methods and a list of private fields to be emitted.
317 class Class {
318 public:
319   explicit Class(StringRef name);
320 
321   // Adds a new method to this class and prune redundant methods. Returns null
322   // if the method was not added (because an existing method would make it
323   // redundant), else returns a pointer to the added method. Note that this call
324   // may also delete existing methods that are made redundant by a method to the
325   // class.
326   template <typename... Args>
addMethodAndPrune(StringRef retType,StringRef name,OpMethod::Property properties,Args &&...args)327   OpMethod *addMethodAndPrune(StringRef retType, StringRef name,
328                               OpMethod::Property properties, Args &&...args) {
329     auto newMethod = std::make_unique<OpMethod>(
330         retType, name, properties, nextMethodID++, std::forward<Args>(args)...);
331     return addMethodAndPrune(methods, std::move(newMethod));
332   }
333 
334   template <typename... Args>
addMethodAndPrune(StringRef retType,StringRef name,Args &&...args)335   OpMethod *addMethodAndPrune(StringRef retType, StringRef name,
336                               Args &&...args) {
337     return addMethodAndPrune(retType, name, OpMethod::MP_None,
338                              std::forward<Args>(args)...);
339   }
340 
341   template <typename... Args>
addConstructorAndPrune(Args &&...args)342   OpConstructor *addConstructorAndPrune(Args &&...args) {
343     auto newConstructor = std::make_unique<OpConstructor>(
344         getClassName(), OpMethod::MP_Constructor, nextMethodID++,
345         std::forward<Args>(args)...);
346     return addMethodAndPrune(constructors, std::move(newConstructor));
347   }
348 
349   // Creates a new field in this class.
350   void newField(StringRef type, StringRef name, StringRef defaultValue = "");
351 
352   // Writes this op's class as a declaration to the given `os`.
353   void writeDeclTo(raw_ostream &os) const;
354   // Writes the method definitions in this op's class to the given `os`.
355   void writeDefTo(raw_ostream &os) const;
356 
357   // Returns the C++ class name of the op.
getClassName()358   StringRef getClassName() const { return className; }
359 
360 protected:
361   // Get a list of all the methods to emit, filtering out hidden ones.
forAllMethods(llvm::function_ref<void (const OpMethod &)> func)362   void forAllMethods(llvm::function_ref<void(const OpMethod &)> func) const {
363     using ConsRef = const std::unique_ptr<OpConstructor> &;
364     using MethodRef = const std::unique_ptr<OpMethod> &;
365     llvm::for_each(constructors, [&](ConsRef ptr) { func(*ptr); });
366     llvm::for_each(methods, [&](MethodRef ptr) { func(*ptr); });
367   }
368 
369   // For deterministic code generation, keep methods sorted in the order in
370   // which they were generated.
371   template <typename MethodTy>
372   struct MethodCompare {
operatorMethodCompare373     bool operator()(const std::unique_ptr<MethodTy> &x,
374                     const std::unique_ptr<MethodTy> &y) const {
375       return x->getID() < y->getID();
376     }
377   };
378 
379   template <typename MethodTy>
380   using MethodSet =
381       std::set<std::unique_ptr<MethodTy>, MethodCompare<MethodTy>>;
382 
383   template <typename MethodTy>
addMethodAndPrune(MethodSet<MethodTy> & set,std::unique_ptr<MethodTy> && newMethod)384   MethodTy *addMethodAndPrune(MethodSet<MethodTy> &set,
385                               std::unique_ptr<MethodTy> &&newMethod) {
386     // Check if the new method will be made redundant by existing methods.
387     for (auto &method : set)
388       if (method->makesRedundant(*newMethod))
389         return nullptr;
390 
391     // We can add this a method to the set. Prune any existing methods that will
392     // be made redundant by adding this new method. Note that the redundant
393     // check between two methods is more than a conflict check. makesRedundant()
394     // below will check if the new method conflicts with an existing method and
395     // if so, returns true if the new method makes the existing method redundant
396     // because all calls to the existing method can be subsumed by the new
397     // method. So makesRedundant() does a combined job of finding conflicts and
398     // deciding which of the 2 conflicting methods survive.
399     //
400     // Note: llvm::erase_if does not work with sets of std::unique_ptr, so doing
401     // it manually here.
402     for (auto it = set.begin(), end = set.end(); it != end;) {
403       if (newMethod->makesRedundant(*(it->get())))
404         it = set.erase(it);
405       else
406         ++it;
407     }
408 
409     MethodTy *ret = newMethod.get();
410     set.insert(std::move(newMethod));
411     return ret;
412   }
413 
414   std::string className;
415   MethodSet<OpConstructor> constructors;
416   MethodSet<OpMethod> methods;
417   unsigned nextMethodID = 0;
418   SmallVector<std::string, 4> fields;
419 };
420 
421 // Class for holding an op for C++ code emission
422 class OpClass : public Class {
423 public:
424   explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
425 
426   // Adds an op trait.
427   void addTrait(Twine trait);
428 
429   // Writes this op's class as a declaration to the given `os`.  Redefines
430   // Class::writeDeclTo to also emit traits and extra class declarations.
431   void writeDeclTo(raw_ostream &os) const;
432 
433 private:
434   StringRef extraClassDeclaration;
435   SmallVector<std::string, 4> traitsVec;
436   StringSet<> traitsSet;
437 };
438 
439 } // namespace tblgen
440 } // namespace mlir
441 
442 #endif // MLIR_TABLEGEN_OPCLASS_H_
443