1 //===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// OpInterfacesGen generates declarations, definitions, and documentation for
// attribute, operation, and type interfaces.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "DocGenUtilities.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/GenInfo.h"
16 #include "mlir/TableGen/Interfaces.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "llvm/Support/raw_ostream.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
23 #include "llvm/TableGen/TableGenBackend.h"
24
25 using namespace mlir;
26 using mlir::tblgen::Interface;
27 using mlir::tblgen::InterfaceMethod;
28 using mlir::tblgen::OpInterface;
29
30 /// Emit a string corresponding to a C++ type, followed by a space if necessary.
emitCPPType(StringRef type,raw_ostream & os)31 static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
32 type = type.trim();
33 os << type;
34 if (type.back() != '&' && type.back() != '*')
35 os << " ";
36 return os;
37 }
38
39 /// Emit the method name and argument list for the given method. If 'addThisArg'
40 /// is true, then an argument is added to the beginning of the argument list for
41 /// the concrete value.
emitMethodNameAndArgs(const InterfaceMethod & method,raw_ostream & os,StringRef valueType,bool addThisArg,bool addConst)42 static void emitMethodNameAndArgs(const InterfaceMethod &method,
43 raw_ostream &os, StringRef valueType,
44 bool addThisArg, bool addConst) {
45 os << method.getName() << '(';
46 if (addThisArg) {
47 if (addConst)
48 os << "const ";
49 os << "const Concept *impl, ";
50 emitCPPType(valueType, os)
51 << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
52 }
53 llvm::interleaveComma(method.getArguments(), os,
54 [&](const InterfaceMethod::Argument &arg) {
55 os << arg.type << " " << arg.name;
56 });
57 os << ')';
58 if (addConst)
59 os << " const";
60 }
61
62 /// Get an array of all OpInterface definitions but exclude those subclassing
63 /// "DeclareOpInterfaceMethods".
64 static std::vector<llvm::Record *>
getAllOpInterfaceDefinitions(const llvm::RecordKeeper & recordKeeper)65 getAllOpInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper) {
66 std::vector<llvm::Record *> defs =
67 recordKeeper.getAllDerivedDefinitions("OpInterface");
68
69 llvm::erase_if(defs, [](const llvm::Record *def) {
70 return def->isSubClassOf("DeclareOpInterfaceMethods");
71 });
72 return defs;
73 }
74
75 namespace {
76 /// This struct is the base generator used when processing tablegen interfaces.
77 class InterfaceGenerator {
78 public:
79 bool emitInterfaceDefs();
80 bool emitInterfaceDecls();
81 bool emitInterfaceDocs();
82
83 protected:
InterfaceGenerator(std::vector<llvm::Record * > && defs,raw_ostream & os)84 InterfaceGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os)
85 : defs(std::move(defs)), os(os) {}
86
87 void emitConceptDecl(Interface &interface);
88 void emitModelDecl(Interface &interface);
89 void emitModelMethodsDef(Interface &interface);
90 void emitTraitDecl(Interface &interface, StringRef interfaceName,
91 StringRef interfaceTraitsName);
92 void emitInterfaceDecl(Interface interface);
93
94 /// The set of interface records to emit.
95 std::vector<llvm::Record *> defs;
96 // The stream to emit to.
97 raw_ostream &os;
98 /// The C++ value type of the interface, e.g. Operation*.
99 StringRef valueType;
100 /// The C++ base interface type.
101 StringRef interfaceBaseType;
102 /// The name of the typename for the value template.
103 StringRef valueTemplate;
104 /// The format context to use for methods.
105 tblgen::FmtContext nonStaticMethodFmt;
106 tblgen::FmtContext traitMethodFmt;
107 };
108
109 /// A specialized generator for attribute interfaces.
110 struct AttrInterfaceGenerator : public InterfaceGenerator {
AttrInterfaceGenerator__anonef0631450311::AttrInterfaceGenerator111 AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
112 : InterfaceGenerator(records.getAllDerivedDefinitions("AttrInterface"),
113 os) {
114 valueType = "::mlir::Attribute";
115 interfaceBaseType = "AttributeInterface";
116 valueTemplate = "ConcreteAttr";
117 StringRef castCode = "(tablegen_opaque_val.cast<ConcreteAttr>())";
118 nonStaticMethodFmt.addSubst("_attr", castCode).withSelf(castCode);
119 traitMethodFmt.addSubst("_attr",
120 "(*static_cast<const ConcreteAttr *>(this))");
121 }
122 };
123 /// A specialized generator for operation interfaces.
124 struct OpInterfaceGenerator : public InterfaceGenerator {
OpInterfaceGenerator__anonef0631450311::OpInterfaceGenerator125 OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
126 : InterfaceGenerator(getAllOpInterfaceDefinitions(records), os) {
127 valueType = "::mlir::Operation *";
128 interfaceBaseType = "OpInterface";
129 valueTemplate = "ConcreteOp";
130 StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
131 nonStaticMethodFmt.addSubst("_this", "impl")
132 .withOp(castCode)
133 .withSelf(castCode);
134 traitMethodFmt.withOp("(*static_cast<ConcreteOp *>(this))");
135 }
136 };
137 /// A specialized generator for type interfaces.
138 struct TypeInterfaceGenerator : public InterfaceGenerator {
TypeInterfaceGenerator__anonef0631450311::TypeInterfaceGenerator139 TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
140 : InterfaceGenerator(records.getAllDerivedDefinitions("TypeInterface"),
141 os) {
142 valueType = "::mlir::Type";
143 interfaceBaseType = "TypeInterface";
144 valueTemplate = "ConcreteType";
145 StringRef castCode = "(tablegen_opaque_val.cast<ConcreteType>())";
146 nonStaticMethodFmt.addSubst("_type", castCode).withSelf(castCode);
147 traitMethodFmt.addSubst("_type",
148 "(*static_cast<const ConcreteType *>(this))");
149 }
150 };
151 } // end anonymous namespace
152
153 //===----------------------------------------------------------------------===//
154 // GEN: Interface definitions
155 //===----------------------------------------------------------------------===//
156
emitInterfaceDef(Interface interface,StringRef valueType,raw_ostream & os)157 static void emitInterfaceDef(Interface interface, StringRef valueType,
158 raw_ostream &os) {
159 StringRef interfaceName = interface.getName();
160 StringRef cppNamespace = interface.getCppNamespace();
161 cppNamespace.consume_front("::");
162
163 // Insert the method definitions.
164 bool isOpInterface = isa<OpInterface>(interface);
165 for (auto &method : interface.getMethods()) {
166 emitCPPType(method.getReturnType(), os);
167 if (!cppNamespace.empty())
168 os << cppNamespace << "::";
169 os << interfaceName << "::";
170 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
171 /*addConst=*/!isOpInterface);
172
173 // Forward to the method on the concrete operation type.
174 os << " {\n return getImpl()->" << method.getName() << '(';
175 if (!method.isStatic()) {
176 os << "getImpl(), ";
177 os << (isOpInterface ? "getOperation()" : "*this");
178 os << (method.arg_empty() ? "" : ", ");
179 }
180 llvm::interleaveComma(
181 method.getArguments(), os,
182 [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
183 os << ");\n }\n";
184 }
185 }
186
emitInterfaceDefs()187 bool InterfaceGenerator::emitInterfaceDefs() {
188 llvm::emitSourceFileHeader("Interface Definitions", os);
189
190 for (const auto *def : defs)
191 emitInterfaceDef(Interface(def), valueType, os);
192 return false;
193 }
194
195 //===----------------------------------------------------------------------===//
196 // GEN: Interface declarations
197 //===----------------------------------------------------------------------===//
198
emitConceptDecl(Interface & interface)199 void InterfaceGenerator::emitConceptDecl(Interface &interface) {
200 os << " struct Concept {\n";
201
202 // Insert each of the pure virtual concept methods.
203 for (auto &method : interface.getMethods()) {
204 os << " ";
205 emitCPPType(method.getReturnType(), os);
206 os << "(*" << method.getName() << ")(";
207 if (!method.isStatic()) {
208 os << "const Concept *impl, ";
209 emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", ");
210 }
211 llvm::interleaveComma(
212 method.getArguments(), os,
213 [&](const InterfaceMethod::Argument &arg) { os << arg.type; });
214 os << ");\n";
215 }
216 os << " };\n";
217 }
218
emitModelDecl(Interface & interface)219 void InterfaceGenerator::emitModelDecl(Interface &interface) {
220 // Emit the basic model and the fallback model.
221 for (const char *modelClass : {"Model", "FallbackModel"}) {
222 os << " template<typename " << valueTemplate << ">\n";
223 os << " class " << modelClass << " : public Concept {\n public:\n";
224 os << " using Interface = " << interface.getCppNamespace()
225 << (interface.getCppNamespace().empty() ? "" : "::")
226 << interface.getName() << ";\n";
227 os << " " << modelClass << "() : Concept{";
228 llvm::interleaveComma(
229 interface.getMethods(), os,
230 [&](const InterfaceMethod &method) { os << method.getName(); });
231 os << "} {}\n\n";
232
233 // Insert each of the virtual method overrides.
234 for (auto &method : interface.getMethods()) {
235 emitCPPType(method.getReturnType(), os << " static inline ");
236 emitMethodNameAndArgs(method, os, valueType,
237 /*addThisArg=*/!method.isStatic(),
238 /*addConst=*/false);
239 os << ";\n";
240 }
241 os << " };\n";
242 }
243
244 // Emit the template for the external model.
245 os << " template<typename ConcreteModel, typename " << valueTemplate
246 << ">\n";
247 os << " class ExternalModel : public FallbackModel<ConcreteModel> {\n";
248 os << " public:\n";
249
250 // Emit declarations for methods that have default implementations. Other
251 // methods are expected to be implemented by the concrete derived model.
252 for (auto &method : interface.getMethods()) {
253 if (!method.getDefaultImplementation())
254 continue;
255 os << " ";
256 if (method.isStatic())
257 os << "static ";
258 emitCPPType(method.getReturnType(), os);
259 os << method.getName() << "(";
260 if (!method.isStatic()) {
261 emitCPPType(valueType, os);
262 os << "tablegen_opaque_val";
263 if (!method.arg_empty())
264 os << ", ";
265 }
266 llvm::interleaveComma(method.getArguments(), os,
267 [&](const InterfaceMethod::Argument &arg) {
268 emitCPPType(arg.type, os);
269 os << arg.name;
270 });
271 os << ")";
272 if (!method.isStatic())
273 os << " const";
274 os << ";\n";
275 }
276 os << " };\n";
277 }
278
emitModelMethodsDef(Interface & interface)279 void InterfaceGenerator::emitModelMethodsDef(Interface &interface) {
280 for (auto &method : interface.getMethods()) {
281 os << "template<typename " << valueTemplate << ">\n";
282 emitCPPType(method.getReturnType(), os);
283 os << "detail::" << interface.getName() << "InterfaceTraits::Model<"
284 << valueTemplate << ">::";
285 emitMethodNameAndArgs(method, os, valueType,
286 /*addThisArg=*/!method.isStatic(),
287 /*addConst=*/false);
288 os << " {\n ";
289
290 // Check for a provided body to the function.
291 if (Optional<StringRef> body = method.getBody()) {
292 if (method.isStatic())
293 os << body->trim();
294 else
295 os << tblgen::tgfmt(body->trim(), &nonStaticMethodFmt);
296 os << "\n}\n";
297 continue;
298 }
299
300 // Forward to the method on the concrete operation type.
301 if (method.isStatic())
302 os << "return " << valueTemplate << "::";
303 else
304 os << tblgen::tgfmt("return $_self.", &nonStaticMethodFmt);
305
306 // Add the arguments to the call.
307 os << method.getName() << '(';
308 llvm::interleaveComma(
309 method.getArguments(), os,
310 [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
311 os << ");\n}\n";
312 }
313
314 for (auto &method : interface.getMethods()) {
315 os << "template<typename " << valueTemplate << ">\n";
316 emitCPPType(method.getReturnType(), os);
317 os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<"
318 << valueTemplate << ">::";
319 emitMethodNameAndArgs(method, os, valueType,
320 /*addThisArg=*/!method.isStatic(),
321 /*addConst=*/false);
322 os << " {\n ";
323
324 // Forward to the method on the concrete Model implementation.
325 if (method.isStatic())
326 os << "return " << valueTemplate << "::";
327 else
328 os << "return static_cast<const " << valueTemplate << " *>(impl)->";
329
330 // Add the arguments to the call.
331 os << method.getName() << '(';
332 if (!method.isStatic())
333 os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
334 llvm::interleaveComma(
335 method.getArguments(), os,
336 [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
337 os << ");\n}\n";
338 }
339
340 // Emit default implementations for the external model.
341 for (auto &method : interface.getMethods()) {
342 if (!method.getDefaultImplementation())
343 continue;
344 os << "template<typename ConcreteModel, typename " << valueTemplate
345 << ">\n";
346 emitCPPType(method.getReturnType(), os);
347 os << "detail::" << interface.getName()
348 << "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate
349 << ">::";
350
351 os << method.getName() << "(";
352 if (!method.isStatic()) {
353 emitCPPType(valueType, os);
354 os << "tablegen_opaque_val";
355 if (!method.arg_empty())
356 os << ", ";
357 }
358 llvm::interleaveComma(method.getArguments(), os,
359 [&](const InterfaceMethod::Argument &arg) {
360 emitCPPType(arg.type, os);
361 os << arg.name;
362 });
363 os << ")";
364 if (!method.isStatic())
365 os << " const";
366
367 os << " {\n";
368
369 // Use the empty context for static methods.
370 tblgen::FmtContext ctx;
371 os << tblgen::tgfmt(method.getDefaultImplementation()->trim(),
372 method.isStatic() ? &ctx : &nonStaticMethodFmt);
373 os << "\n}\n";
374 }
375 }
376
emitTraitDecl(Interface & interface,StringRef interfaceName,StringRef interfaceTraitsName)377 void InterfaceGenerator::emitTraitDecl(Interface &interface,
378 StringRef interfaceName,
379 StringRef interfaceTraitsName) {
380 os << llvm::formatv(" template <typename {3}>\n"
381 " struct {0}Trait : public ::mlir::{2}<{0},"
382 " detail::{1}>::Trait<{3}> {{\n",
383 interfaceName, interfaceTraitsName, interfaceBaseType,
384 valueTemplate);
385
386 // Insert the default implementation for any methods.
387 bool isOpInterface = isa<OpInterface>(interface);
388 for (auto &method : interface.getMethods()) {
389 // Flag interface methods named verifyTrait.
390 if (method.getName() == "verifyTrait")
391 PrintFatalError(
392 formatv("'verifyTrait' method cannot be specified as interface "
393 "method for '{0}'; use the 'verify' field instead",
394 interfaceName));
395 auto defaultImpl = method.getDefaultImplementation();
396 if (!defaultImpl)
397 continue;
398
399 os << " " << (method.isStatic() ? "static " : "");
400 emitCPPType(method.getReturnType(), os);
401 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
402 /*addConst=*/!isOpInterface && !method.isStatic());
403 os << " {\n " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt)
404 << "\n }\n";
405 }
406
407 if (auto verify = interface.getVerify()) {
408 assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'");
409
410 tblgen::FmtContext verifyCtx;
411 verifyCtx.withOp("op");
412 os << " static ::mlir::LogicalResult verifyTrait(::mlir::Operation *op) "
413 "{\n "
414 << tblgen::tgfmt(verify->trim(), &verifyCtx) << "\n }\n";
415 }
416 if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
417 os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";
418
419 os << " };\n";
420 }
421
emitInterfaceDecl(Interface interface)422 void InterfaceGenerator::emitInterfaceDecl(Interface interface) {
423 llvm::SmallVector<StringRef, 2> namespaces;
424 llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
425 for (StringRef ns : namespaces)
426 os << "namespace " << ns << " {\n";
427
428 StringRef interfaceName = interface.getName();
429 auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
430
431 // Emit a forward declaration of the interface class so that it becomes usable
432 // in the signature of its methods.
433 os << "class " << interfaceName << ";\n";
434
435 // Emit the traits struct containing the concept and model declarations.
436 os << "namespace detail {\n"
437 << "struct " << interfaceTraitsName << " {\n";
438 emitConceptDecl(interface);
439 emitModelDecl(interface);
440 os << "};";
441
442 // Emit the derived trait for the interface.
443 os << "template <typename " << valueTemplate << ">\n";
444 os << "struct " << interface.getName() << "Trait;\n";
445
446 os << "\n} // end namespace detail\n";
447
448 // Emit the main interface class declaration.
449 os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n"
450 "public:\n"
451 " using ::mlir::{3}<{1}, detail::{2}>::{3};\n",
452 interfaceName, interfaceName, interfaceTraitsName,
453 interfaceBaseType);
454
455 // Emit a utility wrapper trait class.
456 os << llvm::formatv(" template <typename {1}>\n"
457 " struct Trait : public detail::{0}Trait<{1}> {{};\n",
458 interfaceName, valueTemplate);
459
460 // Insert the method declarations.
461 bool isOpInterface = isa<OpInterface>(interface);
462 for (auto &method : interface.getMethods()) {
463 emitCPPType(method.getReturnType(), os << " ");
464 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
465 /*addConst=*/!isOpInterface);
466 os << ";\n";
467 }
468
469 // Emit any extra declarations.
470 if (Optional<StringRef> extraDecls = interface.getExtraClassDeclaration())
471 os << *extraDecls << "\n";
472
473 os << "};\n";
474
475 os << "namespace detail {\n";
476 emitTraitDecl(interface, interfaceName, interfaceTraitsName);
477 os << "}// namespace detail\n";
478
479 emitModelMethodsDef(interface);
480
481 for (StringRef ns : llvm::reverse(namespaces))
482 os << "} // namespace " << ns << "\n";
483 }
484
emitInterfaceDecls()485 bool InterfaceGenerator::emitInterfaceDecls() {
486 llvm::emitSourceFileHeader("Interface Declarations", os);
487
488 for (const auto *def : defs)
489 emitInterfaceDecl(Interface(def));
490 return false;
491 }
492
493 //===----------------------------------------------------------------------===//
494 // GEN: Interface documentation
495 //===----------------------------------------------------------------------===//
496
emitInterfaceDoc(const llvm::Record & interfaceDef,raw_ostream & os)497 static void emitInterfaceDoc(const llvm::Record &interfaceDef,
498 raw_ostream &os) {
499 Interface interface(&interfaceDef);
500
501 // Emit the interface name followed by the description.
502 os << "## " << interface.getName() << " (`" << interfaceDef.getName()
503 << "`)\n\n";
504 if (auto description = interface.getDescription())
505 mlir::tblgen::emitDescription(*description, os);
506
507 // Emit the methods required by the interface.
508 os << "\n### Methods:\n";
509 for (const auto &method : interface.getMethods()) {
510 // Emit the method name.
511 os << "#### `" << method.getName() << "`\n\n```c++\n";
512
513 // Emit the method signature.
514 if (method.isStatic())
515 os << "static ";
516 emitCPPType(method.getReturnType(), os) << method.getName() << '(';
517 llvm::interleaveComma(method.getArguments(), os,
518 [&](const InterfaceMethod::Argument &arg) {
519 emitCPPType(arg.type, os) << arg.name;
520 });
521 os << ");\n```\n";
522
523 // Emit the description.
524 if (auto description = method.getDescription())
525 mlir::tblgen::emitDescription(*description, os);
526
527 // If the body is not provided, this method must be provided by the user.
528 if (!method.getBody())
529 os << "\nNOTE: This method *must* be implemented by the user.\n\n";
530 }
531 }
532
emitInterfaceDocs()533 bool InterfaceGenerator::emitInterfaceDocs() {
534 os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
535 os << "# " << interfaceBaseType << " definitions\n";
536
537 for (const auto *def : defs)
538 emitInterfaceDoc(*def, os);
539 return false;
540 }
541
542 //===----------------------------------------------------------------------===//
543 // GEN: Interface registration hooks
544 //===----------------------------------------------------------------------===//
545
546 namespace {
547 template <typename GeneratorT>
548 struct InterfaceGenRegistration {
InterfaceGenRegistration__anonef0631450c11::InterfaceGenRegistration549 InterfaceGenRegistration(StringRef genArg, StringRef genDesc)
550 : genDeclArg(("gen-" + genArg + "-interface-decls").str()),
551 genDefArg(("gen-" + genArg + "-interface-defs").str()),
552 genDocArg(("gen-" + genArg + "-interface-docs").str()),
553 genDeclDesc(("Generate " + genDesc + " interface declarations").str()),
554 genDefDesc(("Generate " + genDesc + " interface definitions").str()),
555 genDocDesc(("Generate " + genDesc + " interface documentation").str()),
556 genDecls(genDeclArg, genDeclDesc,
557 [](const llvm::RecordKeeper &records, raw_ostream &os) {
558 return GeneratorT(records, os).emitInterfaceDecls();
559 }),
560 genDefs(genDefArg, genDefDesc,
__anonef0631450e02__anonef0631450c11::InterfaceGenRegistration561 [](const llvm::RecordKeeper &records, raw_ostream &os) {
562 return GeneratorT(records, os).emitInterfaceDefs();
563 }),
564 genDocs(genDocArg, genDocDesc,
__anonef0631450f02__anonef0631450c11::InterfaceGenRegistration565 [](const llvm::RecordKeeper &records, raw_ostream &os) {
566 return GeneratorT(records, os).emitInterfaceDocs();
567 }) {}
568
569 std::string genDeclArg, genDefArg, genDocArg;
570 std::string genDeclDesc, genDefDesc, genDocDesc;
571 mlir::GenRegistration genDecls, genDefs, genDocs;
572 };
573 } // end anonymous namespace
574
575 static InterfaceGenRegistration<AttrInterfaceGenerator> attrGen("attr",
576 "attribute");
577 static InterfaceGenRegistration<OpInterfaceGenerator> opGen("op", "op");
578 static InterfaceGenRegistration<TypeInterfaceGenerator> typeGen("type", "type");
579