1 //===- OpClass.h - Helper classes for Op C++ code emission ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines several classes for Op C++ code emission. They are only 10 // expected to be used by MLIR TableGen backends. 11 // 12 // We emit the op declaration and definition into separate files: *Ops.h.inc 13 // and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and 14 // the latter for dialect *Ops.cpp. This way provides a cleaner interface. 15 // 16 // In order to do this split, we need to track method signature and 17 // implementation logic separately. Signature information is used for both 18 // declaration and definition, while implementation logic is only for 19 // definition. So we have the following classes for C++ code emission. 20 // 21 //===----------------------------------------------------------------------===// 22 23 #ifndef MLIR_TABLEGEN_OPCLASS_H_ 24 #define MLIR_TABLEGEN_OPCLASS_H_ 25 26 #include "mlir/Support/LLVM.h" 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/ADT/SmallVector.h" 29 #include "llvm/ADT/StringRef.h" 30 #include "llvm/ADT/StringSet.h" 31 #include "llvm/Support/raw_ostream.h" 32 33 #include <set> 34 #include <string> 35 36 namespace mlir { 37 namespace tblgen { 38 class FmtObjectBase; 39 40 // Class for holding a single parameter of an op's method for C++ code emission. 41 class OpMethodParameter { 42 public: 43 // Properties (qualifiers) for the parameter. 44 enum Property { 45 PP_None = 0x0, 46 PP_Optional = 0x1, 47 }; 48 49 OpMethodParameter(StringRef type, StringRef name, StringRef defaultValue = "", 50 Property properties = PP_None) type(type)51 : type(type), name(name), defaultValue(defaultValue), 52 properties(properties) {} 53 OpMethodParameter(StringRef type,StringRef name,Property property)54 OpMethodParameter(StringRef type, StringRef name, Property property) 55 : OpMethodParameter(type, name, "", property) {} 56 57 // Writes the parameter as a part of a method declaration to `os`. writeDeclTo(raw_ostream & os)58 void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); } 59 60 // Writes the parameter as a part of a method definition to `os` writeDefTo(raw_ostream & os)61 void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); } 62 getType()63 const std::string &getType() const { return type; } hasDefaultValue()64 bool hasDefaultValue() const { return !defaultValue.empty(); } 65 66 private: 67 void writeTo(raw_ostream &os, bool emitDefault) const; 68 69 std::string type; 70 std::string name; 71 std::string defaultValue; 72 Property properties; 73 }; 74 75 // Base class for holding parameters of an op's method for C++ code emission. 76 class OpMethodParameters { 77 public: 78 // Discriminator for LLVM-style RTTI. 79 enum ParamsKind { 80 // Separate type and name for each parameter is not known. 81 PK_Unresolved, 82 // Each parameter is resolved to a type and name. 83 PK_Resolved, 84 }; 85 OpMethodParameters(ParamsKind kind)86 OpMethodParameters(ParamsKind kind) : kind(kind) {} ~OpMethodParameters()87 virtual ~OpMethodParameters() {} 88 89 // LLVM-style RTTI support. getKind()90 ParamsKind getKind() const { return kind; } 91 92 // Writes the parameters as a part of a method declaration to `os`. 93 virtual void writeDeclTo(raw_ostream &os) const = 0; 94 95 // Writes the parameters as a part of a method definition to `os` 96 virtual void writeDefTo(raw_ostream &os) const = 0; 97 98 // Factory methods to create the correct type of `OpMethodParameters` 99 // object based on the arguments. 100 static std::unique_ptr<OpMethodParameters> create(); 101 102 static std::unique_ptr<OpMethodParameters> create(StringRef params); 103 104 static std::unique_ptr<OpMethodParameters> 105 create(llvm::SmallVectorImpl<OpMethodParameter> &¶ms); 106 107 static std::unique_ptr<OpMethodParameters> 108 create(StringRef type, StringRef name, StringRef defaultValue = ""); 109 110 private: 111 const ParamsKind kind; 112 }; 113 114 // Class for holding unresolved parameters. 115 class OpMethodUnresolvedParameters : public OpMethodParameters { 116 public: OpMethodUnresolvedParameters(StringRef params)117 OpMethodUnresolvedParameters(StringRef params) 118 : OpMethodParameters(PK_Unresolved), parameters(params) {} 119 120 // write the parameters as a part of a method declaration to the given `os`. 121 void writeDeclTo(raw_ostream &os) const override; 122 123 // write the parameters as a part of a method definition to the given `os` 124 void writeDefTo(raw_ostream &os) const override; 125 126 // LLVM-style RTTI support. classof(const OpMethodParameters * params)127 static bool classof(const OpMethodParameters *params) { 128 return params->getKind() == PK_Unresolved; 129 } 130 131 private: 132 std::string parameters; 133 }; 134 135 // Class for holding resolved parameters. 136 class OpMethodResolvedParameters : public OpMethodParameters { 137 public: OpMethodResolvedParameters()138 OpMethodResolvedParameters() : OpMethodParameters(PK_Resolved) {} 139 OpMethodResolvedParameters(llvm::SmallVectorImpl<OpMethodParameter> && params)140 OpMethodResolvedParameters(llvm::SmallVectorImpl<OpMethodParameter> &¶ms) 141 : OpMethodParameters(PK_Resolved) { 142 for (OpMethodParameter ¶m : params) 143 parameters.emplace_back(std::move(param)); 144 } 145 OpMethodResolvedParameters(StringRef type,StringRef name,StringRef defaultValue)146 OpMethodResolvedParameters(StringRef type, StringRef name, 147 StringRef defaultValue) 148 : OpMethodParameters(PK_Resolved) { 149 parameters.emplace_back(type, name, defaultValue); 150 } 151 152 // Returns the number of parameters. getNumParameters()153 size_t getNumParameters() const { return parameters.size(); } 154 155 // Returns if this method makes the `other` method redundant. Note that this 156 // is more than just finding conflicting methods. This method determines if 157 // the 2 set of parameters are conflicting and if so, returns true if this 158 // method has a more general set of parameters that can replace all possible 159 // calls to the `other` method. 160 bool makesRedundant(const OpMethodResolvedParameters &other) const; 161 162 // write the parameters as a part of a method declaration to the given `os`. 163 void writeDeclTo(raw_ostream &os) const override; 164 165 // write the parameters as a part of a method definition to the given `os` 166 void writeDefTo(raw_ostream &os) const override; 167 168 // LLVM-style RTTI support. classof(const OpMethodParameters * params)169 static bool classof(const OpMethodParameters *params) { 170 return params->getKind() == PK_Resolved; 171 } 172 173 private: 174 llvm::SmallVector<OpMethodParameter, 4> parameters; 175 }; 176 177 // Class for holding the signature of an op's method for C++ code emission 178 class OpMethodSignature { 179 public: 180 template <typename... Args> OpMethodSignature(StringRef retType,StringRef name,Args &&...args)181 OpMethodSignature(StringRef retType, StringRef name, Args &&...args) 182 : returnType(retType), methodName(name), 183 parameters(OpMethodParameters::create(std::forward<Args>(args)...)) {} 184 OpMethodSignature(OpMethodSignature &&) = default; 185 186 // Returns if a method with this signature makes a method with `other` 187 // signature redundant. Only supports resolved parameters. 188 bool makesRedundant(const OpMethodSignature &other) const; 189 190 // Returns the number of parameters (for resolved parameters). getNumParameters()191 size_t getNumParameters() const { 192 return cast<OpMethodResolvedParameters>(parameters.get()) 193 ->getNumParameters(); 194 } 195 196 // Returns the name of the method. getName()197 StringRef getName() const { return methodName; } 198 199 // Writes the signature as a method declaration to the given `os`. 200 void writeDeclTo(raw_ostream &os) const; 201 202 // Writes the signature as the start of a method definition to the given `os`. 203 // `namePrefix` is the prefix to be prepended to the method name (typically 204 // namespaces for qualifying the method definition). 205 void writeDefTo(raw_ostream &os, StringRef namePrefix) const; 206 207 private: 208 std::string returnType; 209 std::string methodName; 210 std::unique_ptr<OpMethodParameters> parameters; 211 }; 212 213 // Class for holding the body of an op's method for C++ code emission 214 class OpMethodBody { 215 public: 216 explicit OpMethodBody(bool declOnly); 217 218 OpMethodBody &operator<<(Twine content); 219 OpMethodBody &operator<<(int content); 220 OpMethodBody &operator<<(const FmtObjectBase &content); 221 222 void writeTo(raw_ostream &os) const; 223 224 private: 225 // Whether this class should record method body. 226 bool isEffective; 227 std::string body; 228 }; 229 230 // Class for holding an op's method for C++ code emission 231 class OpMethod { 232 public: 233 // Properties (qualifiers) of class methods. Bitfield is used here to help 234 // querying properties. 235 enum Property { 236 MP_None = 0x0, 237 MP_Static = 0x1, 238 MP_Constructor = 0x2, 239 MP_Private = 0x4, 240 MP_Declaration = 0x8, 241 MP_Inline = 0x10, 242 MP_Constexpr = 0x20 | MP_Inline, 243 MP_StaticDeclaration = MP_Static | MP_Declaration, 244 }; 245 246 template <typename... Args> OpMethod(StringRef retType,StringRef name,Property property,unsigned id,Args &&...args)247 OpMethod(StringRef retType, StringRef name, Property property, unsigned id, 248 Args &&...args) 249 : properties(property), 250 methodSignature(retType, name, std::forward<Args>(args)...), 251 methodBody(properties & MP_Declaration), id(id) {} 252 253 OpMethod(OpMethod &&) = default; 254 255 virtual ~OpMethod() = default; 256 body()257 OpMethodBody &body() { return methodBody; } 258 259 // Returns true if this is a static method. isStatic()260 bool isStatic() const { return properties & MP_Static; } 261 262 // Returns true if this is a private method. isPrivate()263 bool isPrivate() const { return properties & MP_Private; } 264 265 // Returns true if this is an inline method. isInline()266 bool isInline() const { return properties & MP_Inline; } 267 268 // Returns the name of this method. getName()269 StringRef getName() const { return methodSignature.getName(); } 270 271 // Returns the ID for this method getID()272 unsigned getID() const { return id; } 273 274 // Returns if this method makes the `other` method redundant. makesRedundant(const OpMethod & other)275 bool makesRedundant(const OpMethod &other) const { 276 return methodSignature.makesRedundant(other.methodSignature); 277 } 278 279 // Writes the method as a declaration to the given `os`. 280 virtual void writeDeclTo(raw_ostream &os) const; 281 282 // Writes the method as a definition to the given `os`. `namePrefix` is the 283 // prefix to be prepended to the method name (typically namespaces for 284 // qualifying the method definition). 285 virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const; 286 287 protected: 288 Property properties; 289 OpMethodSignature methodSignature; 290 OpMethodBody methodBody; 291 const unsigned id; 292 }; 293 294 // Class for holding an op's constructor method for C++ code emission. 295 class OpConstructor : public OpMethod { 296 public: 297 template <typename... Args> OpConstructor(StringRef className,Property property,unsigned id,Args &&...args)298 OpConstructor(StringRef className, Property property, unsigned id, 299 Args &&...args) 300 : OpMethod("", className, property, id, std::forward<Args>(args)...) {} 301 302 // Add member initializer to constructor initializing `name` with `value`. 303 void addMemberInitializer(StringRef name, StringRef value); 304 305 // Writes the method as a definition to the given `os`. `namePrefix` is the 306 // prefix to be prepended to the method name (typically namespaces for 307 // qualifying the method definition). 308 void writeDefTo(raw_ostream &os, StringRef namePrefix) const override; 309 310 private: 311 // Member initializers. 312 std::string memberInitializers; 313 }; 314 315 // A class used to emit C++ classes from Tablegen. Contains a list of public 316 // methods and a list of private fields to be emitted. 317 class Class { 318 public: 319 explicit Class(StringRef name); 320 321 // Adds a new method to this class and prune redundant methods. Returns null 322 // if the method was not added (because an existing method would make it 323 // redundant), else returns a pointer to the added method. Note that this call 324 // may also delete existing methods that are made redundant by a method to the 325 // class. 326 template <typename... Args> addMethodAndPrune(StringRef retType,StringRef name,OpMethod::Property properties,Args &&...args)327 OpMethod *addMethodAndPrune(StringRef retType, StringRef name, 328 OpMethod::Property properties, Args &&...args) { 329 auto newMethod = std::make_unique<OpMethod>( 330 retType, name, properties, nextMethodID++, std::forward<Args>(args)...); 331 return addMethodAndPrune(methods, std::move(newMethod)); 332 } 333 334 template <typename... Args> addMethodAndPrune(StringRef retType,StringRef name,Args &&...args)335 OpMethod *addMethodAndPrune(StringRef retType, StringRef name, 336 Args &&...args) { 337 return addMethodAndPrune(retType, name, OpMethod::MP_None, 338 std::forward<Args>(args)...); 339 } 340 341 template <typename... Args> addConstructorAndPrune(Args &&...args)342 OpConstructor *addConstructorAndPrune(Args &&...args) { 343 auto newConstructor = std::make_unique<OpConstructor>( 344 getClassName(), OpMethod::MP_Constructor, nextMethodID++, 345 std::forward<Args>(args)...); 346 return addMethodAndPrune(constructors, std::move(newConstructor)); 347 } 348 349 // Creates a new field in this class. 350 void newField(StringRef type, StringRef name, StringRef defaultValue = ""); 351 352 // Writes this op's class as a declaration to the given `os`. 353 void writeDeclTo(raw_ostream &os) const; 354 // Writes the method definitions in this op's class to the given `os`. 355 void writeDefTo(raw_ostream &os) const; 356 357 // Returns the C++ class name of the op. getClassName()358 StringRef getClassName() const { return className; } 359 360 protected: 361 // Get a list of all the methods to emit, filtering out hidden ones. forAllMethods(llvm::function_ref<void (const OpMethod &)> func)362 void forAllMethods(llvm::function_ref<void(const OpMethod &)> func) const { 363 using ConsRef = const std::unique_ptr<OpConstructor> &; 364 using MethodRef = const std::unique_ptr<OpMethod> &; 365 llvm::for_each(constructors, [&](ConsRef ptr) { func(*ptr); }); 366 llvm::for_each(methods, [&](MethodRef ptr) { func(*ptr); }); 367 } 368 369 // For deterministic code generation, keep methods sorted in the order in 370 // which they were generated. 371 template <typename MethodTy> 372 struct MethodCompare { operatorMethodCompare373 bool operator()(const std::unique_ptr<MethodTy> &x, 374 const std::unique_ptr<MethodTy> &y) const { 375 return x->getID() < y->getID(); 376 } 377 }; 378 379 template <typename MethodTy> 380 using MethodSet = 381 std::set<std::unique_ptr<MethodTy>, MethodCompare<MethodTy>>; 382 383 template <typename MethodTy> addMethodAndPrune(MethodSet<MethodTy> & set,std::unique_ptr<MethodTy> && newMethod)384 MethodTy *addMethodAndPrune(MethodSet<MethodTy> &set, 385 std::unique_ptr<MethodTy> &&newMethod) { 386 // Check if the new method will be made redundant by existing methods. 387 for (auto &method : set) 388 if (method->makesRedundant(*newMethod)) 389 return nullptr; 390 391 // We can add this a method to the set. Prune any existing methods that will 392 // be made redundant by adding this new method. Note that the redundant 393 // check between two methods is more than a conflict check. makesRedundant() 394 // below will check if the new method conflicts with an existing method and 395 // if so, returns true if the new method makes the existing method redundant 396 // because all calls to the existing method can be subsumed by the new 397 // method. So makesRedundant() does a combined job of finding conflicts and 398 // deciding which of the 2 conflicting methods survive. 399 // 400 // Note: llvm::erase_if does not work with sets of std::unique_ptr, so doing 401 // it manually here. 402 for (auto it = set.begin(), end = set.end(); it != end;) { 403 if (newMethod->makesRedundant(*(it->get()))) 404 it = set.erase(it); 405 else 406 ++it; 407 } 408 409 MethodTy *ret = newMethod.get(); 410 set.insert(std::move(newMethod)); 411 return ret; 412 } 413 414 std::string className; 415 MethodSet<OpConstructor> constructors; 416 MethodSet<OpMethod> methods; 417 unsigned nextMethodID = 0; 418 SmallVector<std::string, 4> fields; 419 }; 420 421 // Class for holding an op for C++ code emission 422 class OpClass : public Class { 423 public: 424 explicit OpClass(StringRef name, StringRef extraClassDeclaration = ""); 425 426 // Adds an op trait. 427 void addTrait(Twine trait); 428 429 // Writes this op's class as a declaration to the given `os`. Redefines 430 // Class::writeDeclTo to also emit traits and extra class declarations. 431 void writeDeclTo(raw_ostream &os) const; 432 433 private: 434 StringRef extraClassDeclaration; 435 SmallVector<std::string, 4> traitsVec; 436 StringSet<> traitsSet; 437 }; 438 439 } // namespace tblgen 440 } // namespace mlir 441 442 #endif // MLIR_TABLEGEN_OPCLASS_H_ 443