1 //===- TypeDefGen.cpp - MLIR typeDef definitions generator ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // TypeDefGen uses the description of typeDefs to generate C++ definitions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Support/LogicalResult.h"
14 #include "mlir/TableGen/CodeGenHelpers.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/TypeDef.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/TableGen/Error.h"
21 #include "llvm/TableGen/TableGenBackend.h"
22 
23 #define DEBUG_TYPE "mlir-tblgen-typedefgen"
24 
25 using namespace mlir;
26 using namespace mlir::tblgen;
27 
/// Command-line category that groups the options of the typedef generators.
static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
/// The `--typedefs-dialect` option: names the dialect whose TypeDefs are
/// generated. findAllTypeDefs reads it and, when it was not given, falls back
/// to the single dialect that owns all TypeDef records.
static llvm::cl::opt<std::string>
    selectedDialect("typedefs-dialect",
                    llvm::cl::desc("Gen types for this dialect"),
                    llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
33 
34 /// Find all the TypeDefs for the specified dialect. If no dialect specified and
35 /// can only find one dialect's types, use that.
findAllTypeDefs(const llvm::RecordKeeper & recordKeeper,SmallVectorImpl<TypeDef> & typeDefs)36 static void findAllTypeDefs(const llvm::RecordKeeper &recordKeeper,
37                             SmallVectorImpl<TypeDef> &typeDefs) {
38   auto recDefs = recordKeeper.getAllDerivedDefinitions("TypeDef");
39   auto defs = llvm::map_range(
40       recDefs, [&](const llvm::Record *rec) { return TypeDef(rec); });
41   if (defs.empty())
42     return;
43 
44   StringRef dialectName;
45   if (selectedDialect.getNumOccurrences() == 0) {
46     if (defs.empty())
47       return;
48 
49     llvm::SmallSet<Dialect, 4> dialects;
50     for (const TypeDef typeDef : defs)
51       dialects.insert(typeDef.getDialect());
52     if (dialects.size() != 1)
53       llvm::PrintFatalError("TypeDefs belonging to more than one dialect. Must "
54                             "select one via '--typedefs-dialect'");
55 
56     dialectName = (*dialects.begin()).getName();
57   } else if (selectedDialect.getNumOccurrences() == 1) {
58     dialectName = selectedDialect.getValue();
59   } else {
60     llvm::PrintFatalError("Cannot select multiple dialects for which to "
61                           "generate types via '--typedefs-dialect'.");
62   }
63 
64   for (const TypeDef typeDef : defs)
65     if (typeDef.getDialect().getName().equals(dialectName))
66       typeDefs.push_back(typeDef);
67 }
68 
69 namespace {
70 
71 /// Pass an instance of this class to llvm::formatv() to emit a comma separated
72 /// list of parameters in the format by 'EmitFormat'.
73 class TypeParamCommaFormatter : public llvm::detail::format_adapter {
74 public:
75   /// Choose the output format
76   enum EmitFormat {
77     /// Emit "parameter1Type parameter1Name, parameter2Type parameter2Name,
78     /// [...]".
79     TypeNamePairs,
80 
81     /// Emit "parameter1(parameter1), parameter2(parameter2), [...]".
82     TypeNameInitializer,
83 
84     /// Emit "param1Name, param2Name, [...]".
85     JustParams,
86   };
87 
TypeParamCommaFormatter(EmitFormat emitFormat,ArrayRef<TypeParameter> params,bool prependComma=true)88   TypeParamCommaFormatter(EmitFormat emitFormat, ArrayRef<TypeParameter> params,
89                           bool prependComma = true)
90       : emitFormat(emitFormat), params(params), prependComma(prependComma) {}
91 
92   /// llvm::formatv will call this function when using an instance as a
93   /// replacement value.
format(raw_ostream & os,StringRef options)94   void format(raw_ostream &os, StringRef options) override {
95     if (!params.empty() && prependComma)
96       os << ", ";
97 
98     switch (emitFormat) {
99     case EmitFormat::TypeNamePairs:
100       interleaveComma(params, os,
101                       [&](const TypeParameter &p) { emitTypeNamePair(p, os); });
102       break;
103     case EmitFormat::TypeNameInitializer:
104       interleaveComma(params, os, [&](const TypeParameter &p) {
105         emitTypeNameInitializer(p, os);
106       });
107       break;
108     case EmitFormat::JustParams:
109       interleaveComma(params, os,
110                       [&](const TypeParameter &p) { os << p.getName(); });
111       break;
112     }
113   }
114 
115 private:
116   // Emit "paramType paramName".
emitTypeNamePair(const TypeParameter & param,raw_ostream & os)117   static void emitTypeNamePair(const TypeParameter &param, raw_ostream &os) {
118     os << param.getCppType() << " " << param.getName();
119   }
120   // Emit "paramName(paramName)"
emitTypeNameInitializer(const TypeParameter & param,raw_ostream & os)121   void emitTypeNameInitializer(const TypeParameter &param, raw_ostream &os) {
122     os << param.getName() << "(" << param.getName() << ")";
123   }
124 
125   EmitFormat emitFormat;
126   ArrayRef<TypeParameter> params;
127   bool prependComma;
128 };
129 
130 } // end anonymous namespace
131 
132 //===----------------------------------------------------------------------===//
133 // GEN: TypeDef declarations
134 //===----------------------------------------------------------------------===//
135 
/// Preamble printed once, above all the generated type declarations. It
/// forward-declares the MLIR dialect parser and printer classes that the
/// generated parse/print method signatures refer to.
static const char *const typeDefDeclHeader = R"(
namespace mlir {
class DialectAsmParser;
class DialectAsmPrinter;
} // namespace mlir
)";
144 
/// The code block for the start of a typeDef class declaration -- singleton
/// case (a type without parameters, which uses the default
/// ::mlir::TypeStorage). This is a formatv template, so literal braces are
/// written as `{{`.
///
/// {0}: The name of the typeDef class.
/// {1}: The name of the type base class.
static const char *const typeDefDeclSingletonBeginStr = R"(
  class {0} : public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{
  public:
    /// Inherit some necessary constructors from 'TypeBase'.
    using Base::Base;

)";
157 
/// The code block for the start of a typeDef class declaration -- parametric
/// case. It forward-declares the storage struct in its namespace and then
/// opens the type class, which uses that struct as its storage.
///
/// This is a formatv template: every literal opening brace is escaped as
/// `{{`.
///
/// {0}: The name of the typeDef class.
/// {1}: The name of the type base class.
/// {2}: The typeDef storage class namespace.
/// {3}: The storage class name.
static const char *const typeDefDeclParametricBeginStr = R"(
  namespace {2} {{
    struct {3};
  } // end namespace {2}
  class {0} : public ::mlir::Type::TypeBase<{0}, {1},
                                         {2}::{3}> {{
  public:
    /// Inherit some necessary constructors from 'TypeBase'.
    using Base::Base;

)";
177 
/// The snippet declaring the `parse` and `print` members of a type class. It
/// is streamed verbatim (not through formatv), so it has no placeholders.
static const char *const typeDefParsePrint = R"(
    static ::mlir::Type parse(::mlir::MLIRContext *context,
                              ::mlir::DialectAsmParser &parser);
    void print(::mlir::DialectAsmPrinter &printer) const;
)";
184 
/// The code block declaring verifyConstructionInvariants. (The matching
/// getChecked declarations are emitted by emitTypeBuilderDecls, not by this
/// template.)
///
/// {0}: The name of the typeDef class (passed by the caller, but not
///      referenced by the template).
/// {1}: List of parameters, parameters style, with a leading comma.
static const char *const typeDefDeclVerifyStr = R"(
    static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc{1});
)";
192 
193 /// Emit the builders for the given type.
emitTypeBuilderDecls(const TypeDef & typeDef,raw_ostream & os,TypeParamCommaFormatter & paramTypes)194 static void emitTypeBuilderDecls(const TypeDef &typeDef, raw_ostream &os,
195                                  TypeParamCommaFormatter &paramTypes) {
196   StringRef typeClass = typeDef.getCppClassName();
197   bool genCheckedMethods = typeDef.genVerifyInvariantsDecl();
198   if (!typeDef.skipDefaultBuilders()) {
199     os << llvm::formatv(
200         "    static {0} get(::mlir::MLIRContext *context{1});\n", typeClass,
201         paramTypes);
202     if (genCheckedMethods) {
203       os << llvm::formatv(
204           "    static {0} getChecked(::mlir::Location loc{1});\n", typeClass,
205           paramTypes);
206     }
207   }
208 
209   // Generate the builders specified by the user.
210   for (const TypeBuilder &builder : typeDef.getBuilders()) {
211     std::string paramStr;
212     llvm::raw_string_ostream paramOS(paramStr);
213     llvm::interleaveComma(
214         builder.getParameters(), paramOS,
215         [&](const TypeBuilder::Parameter &param) {
216           // Note: TypeBuilder parameters are guaranteed to have names.
217           paramOS << param.getCppType() << " " << *param.getName();
218           if (Optional<StringRef> defaultParamValue = param.getDefaultValue())
219             paramOS << " = " << *defaultParamValue;
220         });
221     paramOS.flush();
222 
223     // Generate the `get` variant of the builder.
224     os << "    static " << typeClass << " get(";
225     if (!builder.hasInferredContextParameter()) {
226       os << "::mlir::MLIRContext *context";
227       if (!paramStr.empty())
228         os << ", ";
229     }
230     os << paramStr << ");\n";
231 
232     // Generate the `getChecked` variant of the builder.
233     if (genCheckedMethods) {
234       os << "    static " << typeClass << " getChecked(::mlir::Location loc";
235       if (!paramStr.empty())
236         os << ", " << paramStr;
237       os << ");\n";
238     }
239   }
240 }
241 
242 /// Generate the declaration for the given typeDef class.
emitTypeDefDecl(const TypeDef & typeDef,raw_ostream & os)243 static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
244   SmallVector<TypeParameter, 4> params;
245   typeDef.getParameters(params);
246 
247   // Emit the beginning string template: either the singleton or parametric
248   // template.
249   if (typeDef.getNumParameters() == 0)
250     os << formatv(typeDefDeclSingletonBeginStr, typeDef.getCppClassName(),
251                   typeDef.getCppBaseClassName());
252   else
253     os << formatv(typeDefDeclParametricBeginStr, typeDef.getCppClassName(),
254                   typeDef.getCppBaseClassName(), typeDef.getStorageNamespace(),
255                   typeDef.getStorageClassName());
256 
257   // Emit the extra declarations first in case there's a type definition in
258   // there.
259   if (Optional<StringRef> extraDecl = typeDef.getExtraDecls())
260     os << *extraDecl << "\n";
261 
262   TypeParamCommaFormatter emitTypeNamePairsAfterComma(
263       TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params);
264   if (!params.empty()) {
265     emitTypeBuilderDecls(typeDef, os, emitTypeNamePairsAfterComma);
266 
267     // Emit the verify invariants declaration.
268     if (typeDef.genVerifyInvariantsDecl())
269       os << llvm::formatv(typeDefDeclVerifyStr, typeDef.getCppClassName(),
270                           emitTypeNamePairsAfterComma);
271   }
272 
273   // Emit the mnenomic, if specified.
274   if (auto mnenomic = typeDef.getMnemonic()) {
275     os << "    static ::llvm::StringRef getMnemonic() { return \"" << mnenomic
276        << "\"; }\n";
277 
278     // If mnemonic specified, emit print/parse declarations.
279     if (typeDef.getParserCode() || typeDef.getPrinterCode() || !params.empty())
280       os << typeDefParsePrint;
281   }
282 
283   if (typeDef.genAccessors()) {
284     SmallVector<TypeParameter, 4> parameters;
285     typeDef.getParameters(parameters);
286 
287     for (TypeParameter &parameter : parameters) {
288       SmallString<16> name = parameter.getName();
289       name[0] = llvm::toUpper(name[0]);
290       os << formatv("    {0} get{1}() const;\n", parameter.getCppType(), name);
291     }
292   }
293 
294   // End the typeDef decl.
295   os << "  };\n";
296 }
297 
298 /// Main entry point for decls.
emitTypeDefDecls(const llvm::RecordKeeper & recordKeeper,raw_ostream & os)299 static bool emitTypeDefDecls(const llvm::RecordKeeper &recordKeeper,
300                              raw_ostream &os) {
301   emitSourceFileHeader("TypeDef Declarations", os);
302 
303   SmallVector<TypeDef, 16> typeDefs;
304   findAllTypeDefs(recordKeeper, typeDefs);
305 
306   IfDefScope scope("GET_TYPEDEF_CLASSES", os);
307 
308   // Output the common "header".
309   os << typeDefDeclHeader;
310 
311   if (!typeDefs.empty()) {
312     NamespaceEmitter nsEmitter(os, typeDefs.begin()->getDialect());
313 
314     // Declare all the type classes first (in case they reference each other).
315     for (const TypeDef &typeDef : typeDefs)
316       os << "  class " << typeDef.getCppClassName() << ";\n";
317 
318     // Declare all the typedefs.
319     for (const TypeDef &typeDef : typeDefs)
320       emitTypeDefDecl(typeDef, os);
321   }
322 
323   return false;
324 }
325 
326 //===----------------------------------------------------------------------===//
327 // GEN: TypeDef list
328 //===----------------------------------------------------------------------===//
329 
emitTypeDefList(SmallVectorImpl<TypeDef> & typeDefs,raw_ostream & os)330 static void emitTypeDefList(SmallVectorImpl<TypeDef> &typeDefs,
331                             raw_ostream &os) {
332   IfDefScope scope("GET_TYPEDEF_LIST", os);
333   for (auto *i = typeDefs.begin(); i != typeDefs.end(); i++) {
334     os << i->getDialect().getCppNamespace() << "::" << i->getCppClassName();
335     if (i < typeDefs.end() - 1)
336       os << ",\n";
337     else
338       os << "\n";
339   }
340 }
341 
342 //===----------------------------------------------------------------------===//
343 // GEN: TypeDef definitions
344 //===----------------------------------------------------------------------===//
345 
/// Beginning of storage class: opens the storage namespace and the storage
/// struct, then defines its constructor, the KeyTy alias and the comparison
/// operator against a key. The struct and the namespace are closed by
/// emitStorageClass.
///
/// {0}: Storage class namespace.
/// {1}: Storage class c++ name.
/// {2}: Constructor parameters ("type name" pairs).
/// {3}: Parameter initializer string.
/// {4}: Parameter name list.
/// {5}: Parameter types.
static const char *const typeDefStorageClassBegin = R"(
namespace {0} {{
  struct {1} : public ::mlir::TypeStorage {{
    {1} ({2})
      : {3} {{ }

    /// The hash key for this storage is a pair of the integer and type params.
    using KeyTy = std::tuple<{5}>;

    /// Define the comparison function for the key type.
    bool operator==(const KeyTy &key) const {{
      return key == KeyTy({4});
    }
)";
367 
/// Opening of the storage class's static `construct` method; the body and the
/// closing brace are emitted separately by emitStorageClass.
/// {0}: storage class name.
static const char *const typeDefStorageClassConstructorBegin = R"(
    /// Define a construction method for creating a new instance of this storage.
    static {0} *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {{
)";
374 
/// End of the storage class's `construct` method: placement-news the storage
/// object in memory obtained from the allocator and closes the method.
/// {0}: storage class name.
/// {1}: list of parameters.
static const char *const typeDefStorageClassConstructorReturn = R"(
      return new (allocator.allocate<{0}>())
          {0}({1});
    }
)";
383 
384 /// Use tgfmt to emit custom allocation code for each parameter, if necessary.
emitParameterAllocationCode(TypeDef & typeDef,raw_ostream & os)385 static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) {
386   SmallVector<TypeParameter, 4> parameters;
387   typeDef.getParameters(parameters);
388   auto fmtCtxt = FmtContext().addSubst("_allocator", "allocator");
389   for (TypeParameter &parameter : parameters) {
390     auto allocCode = parameter.getAllocator();
391     if (allocCode) {
392       fmtCtxt.withSelf(parameter.getName());
393       fmtCtxt.addSubst("_dst", parameter.getName());
394       os << "      " << tgfmt(*allocCode, &fmtCtxt) << "\n";
395     }
396   }
397 }
398 
399 /// Emit the storage class code for type 'typeDef'.
400 /// This includes (in-order):
401 ///  1) typeDefStorageClassBegin, which includes:
402 ///      - The class constructor.
403 ///      - The KeyTy definition.
404 ///      - The equality (==) operator.
405 ///  2) The hashKey method.
406 ///  3) The construct method.
407 ///  4) The list of parameters as the storage class member variables.
emitStorageClass(TypeDef typeDef,raw_ostream & os)408 static void emitStorageClass(TypeDef typeDef, raw_ostream &os) {
409   SmallVector<TypeParameter, 4> parameters;
410   typeDef.getParameters(parameters);
411 
412   // Initialize a bunch of variables to be used later on.
413   auto parameterNames = map_range(
414       parameters, [](TypeParameter parameter) { return parameter.getName(); });
415   auto parameterTypes = map_range(parameters, [](TypeParameter parameter) {
416     return parameter.getCppType();
417   });
418   auto parameterList = join(parameterNames, ", ");
419   auto parameterTypeList = join(parameterTypes, ", ");
420 
421   // 1) Emit most of the storage class up until the hashKey body.
422   os << formatv(typeDefStorageClassBegin, typeDef.getStorageNamespace(),
423                 typeDef.getStorageClassName(),
424                 TypeParamCommaFormatter(
425                     TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
426                     parameters, /*prependComma=*/false),
427                 TypeParamCommaFormatter(
428                     TypeParamCommaFormatter::EmitFormat::TypeNameInitializer,
429                     parameters, /*prependComma=*/false),
430                 parameterList, parameterTypeList);
431 
432   // 2) Emit the haskKey method.
433   os << "  static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
434   // Extract each parameter from the key.
435   for (size_t i = 0, e = parameters.size(); i < e; ++i)
436     os << llvm::formatv("      const auto &{0} = std::get<{1}>(key);\n",
437                         parameters[i].getName(), i);
438   // Then combine them all. This requires all the parameters types to have a
439   // hash_value defined.
440   os << llvm::formatv(
441       "      return ::llvm::hash_combine({0});\n    }\n",
442       TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
443                               parameters, /* prependComma */ false));
444 
445   // 3) Emit the construct method.
446   if (typeDef.hasStorageCustomConstructor()) {
447     // If user wants to build the storage constructor themselves, declare it
448     // here and then they can write the definition elsewhere.
449     os << "    static " << typeDef.getStorageClassName()
450        << " *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy "
451           "&key);\n";
452   } else {
453     // If not, autogenerate one.
454 
455     // First, unbox the parameters.
456     os << formatv(typeDefStorageClassConstructorBegin,
457                   typeDef.getStorageClassName());
458     for (size_t i = 0; i < parameters.size(); ++i) {
459       os << formatv("      auto {0} = std::get<{1}>(key);\n",
460                     parameters[i].getName(), i);
461     }
462     // Second, reassign the parameter variables with allocation code, if it's
463     // specified.
464     emitParameterAllocationCode(typeDef, os);
465 
466     // Last, return an allocated copy.
467     os << formatv(typeDefStorageClassConstructorReturn,
468                   typeDef.getStorageClassName(), parameterList);
469   }
470 
471   // 4) Emit the parameters as storage class members.
472   for (auto parameter : parameters) {
473     os << "      " << parameter.getCppType() << " " << parameter.getName()
474        << ";\n";
475   }
476   os << "  };\n";
477 
478   os << "} // namespace " << typeDef.getStorageNamespace() << "\n";
479 }
480 
481 /// Emit the parser and printer for a particular type, if they're specified.
emitParserPrinter(TypeDef typeDef,raw_ostream & os)482 void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
483   // Emit the printer code, if specified.
484   if (auto printerCode = typeDef.getPrinterCode()) {
485     // Both the mnenomic and printerCode must be defined (for parity with
486     // parserCode).
487     os << "void " << typeDef.getCppClassName()
488        << "::print(::mlir::DialectAsmPrinter &printer) const {\n";
489     if (*printerCode == "") {
490       // If no code specified, emit error.
491       PrintFatalError(typeDef.getLoc(),
492                       typeDef.getName() +
493                           ": printer (if specified) must have non-empty code");
494     }
495     auto fmtCtxt = FmtContext().addSubst("_printer", "printer");
496     os << tgfmt(*printerCode, &fmtCtxt) << "\n}\n";
497   }
498 
499   // emit a parser, if specified.
500   if (auto parserCode = typeDef.getParserCode()) {
501     // The mnenomic must be defined so the dispatcher knows how to dispatch.
502     os << "::mlir::Type " << typeDef.getCppClassName()
503        << "::parse(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &"
504           "parser) "
505           "{\n";
506     if (*parserCode == "") {
507       // if no code specified, emit error.
508       PrintFatalError(typeDef.getLoc(),
509                       typeDef.getName() +
510                           ": parser (if specified) must have non-empty code");
511     }
512     auto fmtCtxt =
513         FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "context");
514     os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n";
515   }
516 }
517 
518 /// Emit the builders for the given type.
emitTypeBuilderDefs(const TypeDef & typeDef,raw_ostream & os,ArrayRef<TypeParameter> typeDefParams)519 static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os,
520                                 ArrayRef<TypeParameter> typeDefParams) {
521   bool genCheckedMethods = typeDef.genVerifyInvariantsDecl();
522   StringRef typeClass = typeDef.getCppClassName();
523   if (!typeDef.skipDefaultBuilders()) {
524     os << llvm::formatv(
525         "{0} {0}::get(::mlir::MLIRContext *context{1}) {{\n"
526         "  return Base::get(context{2});\n}\n",
527         typeClass,
528         TypeParamCommaFormatter(
529             TypeParamCommaFormatter::EmitFormat::TypeNamePairs, typeDefParams),
530         TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
531                                 typeDefParams));
532     if (genCheckedMethods) {
533       os << llvm::formatv(
534           "{0} {0}::getChecked(::mlir::Location loc{1}) {{\n"
535           "  return Base::getChecked(loc{2});\n}\n",
536           typeClass,
537           TypeParamCommaFormatter(
538               TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
539               typeDefParams),
540           TypeParamCommaFormatter(
541               TypeParamCommaFormatter::EmitFormat::JustParams, typeDefParams));
542     }
543   }
544 
545   // Generate the builders specified by the user.
546   auto builderFmtCtx = FmtContext().addSubst("_ctxt", "context");
547   auto checkedBuilderFmtCtx = FmtContext()
548                                   .addSubst("_loc", "loc")
549                                   .addSubst("_ctxt", "loc.getContext()");
550   for (const TypeBuilder &builder : typeDef.getBuilders()) {
551     Optional<StringRef> body = builder.getBody();
552     Optional<StringRef> checkedBody =
553         genCheckedMethods ? builder.getCheckedBody() : llvm::None;
554     if (!body && !checkedBody)
555       continue;
556     std::string paramStr;
557     llvm::raw_string_ostream paramOS(paramStr);
558     llvm::interleaveComma(builder.getParameters(), paramOS,
559                           [&](const TypeBuilder::Parameter &param) {
560                             // Note: TypeBuilder parameters are guaranteed to
561                             // have names.
562                             paramOS << param.getCppType() << " "
563                                     << *param.getName();
564                           });
565     paramOS.flush();
566 
567     // Emit the `get` variant of the builder.
568     if (body) {
569       os << llvm::formatv("{0} {0}::get(", typeClass);
570       if (!builder.hasInferredContextParameter()) {
571         os << "::mlir::MLIRContext *context";
572         if (!paramStr.empty())
573           os << ", ";
574         os << llvm::formatv("{0}) {{\n  {1};\n}\n", paramStr,
575                             tgfmt(*body, &builderFmtCtx).str());
576       } else {
577         os << llvm::formatv("{0}) {{\n  {1};\n}\n", paramStr, *body);
578       }
579     }
580 
581     // Emit the `getChecked` variant of the builder.
582     if (checkedBody) {
583       os << llvm::formatv("{0} {0}::getChecked(::mlir::Location loc",
584                           typeClass);
585       if (!paramStr.empty())
586         os << ", " << paramStr;
587       os << llvm::formatv(") {{\n  {0};\n}\n",
588                           tgfmt(*checkedBody, &checkedBuilderFmtCtx));
589     }
590   }
591 }
592 
593 /// Print all the typedef-specific definition code.
emitTypeDefDef(const TypeDef & typeDef,raw_ostream & os)594 static void emitTypeDefDef(const TypeDef &typeDef, raw_ostream &os) {
595   NamespaceEmitter ns(os, typeDef.getDialect());
596 
597   SmallVector<TypeParameter, 4> parameters;
598   typeDef.getParameters(parameters);
599   if (!parameters.empty()) {
600     // Emit the storage class, if requested and necessary.
601     if (typeDef.genStorageClass())
602       emitStorageClass(typeDef, os);
603 
604     // Emit the builders for this type.
605     emitTypeBuilderDefs(typeDef, os, parameters);
606 
607     // Generate accessor definitions only if we also generate the storage class.
608     // Otherwise, let the user define the exact accessor definition.
609     if (typeDef.genAccessors() && typeDef.genStorageClass()) {
610       // Emit the parameter accessors.
611       for (const TypeParameter &parameter : parameters) {
612         SmallString<16> name = parameter.getName();
613         name[0] = llvm::toUpper(name[0]);
614         os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n",
615                       parameter.getCppType(), name, parameter.getName(),
616                       typeDef.getCppClassName());
617       }
618     }
619   }
620 
621   // If mnemonic is specified maybe print definitions for the parser and printer
622   // code, if they're specified.
623   if (typeDef.getMnemonic())
624     emitParserPrinter(typeDef, os);
625 }
626 
627 /// Emit the dialect printer/parser dispatcher. User's code should call these
628 /// functions from their dialect's print/parse methods.
emitParsePrintDispatch(ArrayRef<TypeDef> types,raw_ostream & os)629 static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
630   if (llvm::none_of(types, [](const TypeDef &type) {
631         return type.getMnemonic().hasValue();
632       })) {
633     return;
634   }
635 
636   // The parser dispatch is just a list of if-elses, matching on the
637   // mnemonic and calling the class's parse function.
638   os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext *"
639         "context, ::mlir::DialectAsmParser &parser, "
640         "::llvm::StringRef mnemonic) {\n";
641   for (const TypeDef &type : types) {
642     if (type.getMnemonic()) {
643       os << formatv("  if (mnemonic == {0}::{1}::getMnemonic()) return "
644                     "{0}::{1}::",
645                     type.getDialect().getCppNamespace(),
646                     type.getCppClassName());
647 
648       // If the type has no parameters and no parser code, just invoke a normal
649       // `get`.
650       if (type.getNumParameters() == 0 && !type.getParserCode())
651         os << "get(context);\n";
652       else
653         os << "parse(context, parser);\n";
654     }
655   }
656   os << "  return ::mlir::Type();\n";
657   os << "}\n\n";
658 
659   // The printer dispatch uses llvm::TypeSwitch to find and call the correct
660   // printer.
661   os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type "
662         "type, "
663         "::mlir::DialectAsmPrinter &printer) {\n"
664      << "  return ::llvm::TypeSwitch<::mlir::Type, "
665         "::mlir::LogicalResult>(type)\n";
666   for (const TypeDef &type : types) {
667     if (Optional<StringRef> mnemonic = type.getMnemonic()) {
668       StringRef cppNamespace = type.getDialect().getCppNamespace();
669       StringRef cppClassName = type.getCppClassName();
670       os << formatv("    .Case<{0}::{1}>([&]({0}::{1} t) {{\n      ",
671                     cppNamespace, cppClassName);
672 
673       // If the type has no parameters and no printer code, just print the
674       // mnemonic.
675       if (type.getNumParameters() == 0 && !type.getPrinterCode())
676         os << formatv("printer << {0}::{1}::getMnemonic();", cppNamespace,
677                       cppClassName);
678       else
679         os << "t.print(printer);";
680       os << "\n      return ::mlir::success();\n    })\n";
681     }
682   }
683   os << "    .Default([](::mlir::Type) { return ::mlir::failure(); });\n"
684      << "}\n\n";
685 }
686 
687 /// Entry point for typedef definitions.
emitTypeDefDefs(const llvm::RecordKeeper & recordKeeper,raw_ostream & os)688 static bool emitTypeDefDefs(const llvm::RecordKeeper &recordKeeper,
689                             raw_ostream &os) {
690   emitSourceFileHeader("TypeDef Definitions", os);
691 
692   SmallVector<TypeDef, 16> typeDefs;
693   findAllTypeDefs(recordKeeper, typeDefs);
694   emitTypeDefList(typeDefs, os);
695 
696   IfDefScope scope("GET_TYPEDEF_CLASSES", os);
697   emitParsePrintDispatch(typeDefs, os);
698   for (const TypeDef &typeDef : typeDefs)
699     emitTypeDefDef(typeDef, os);
700 
701   return false;
702 }
703 
704 //===----------------------------------------------------------------------===//
705 // GEN: TypeDef registration hooks
706 //===----------------------------------------------------------------------===//
707 
708 static mlir::GenRegistration
709     genTypeDefDefs("gen-typedef-defs", "Generate TypeDef definitions",
__anond31b88460b02(const llvm::RecordKeeper &records, raw_ostream &os) 710                    [](const llvm::RecordKeeper &records, raw_ostream &os) {
711                      return emitTypeDefDefs(records, os);
712                    });
713 
714 static mlir::GenRegistration
715     genTypeDefDecls("gen-typedef-decls", "Generate TypeDef declarations",
__anond31b88460c02(const llvm::RecordKeeper &records, raw_ostream &os) 716                     [](const llvm::RecordKeeper &records, raw_ostream &os) {
717                       return emitTypeDefDecls(records, os);
718                     });
719