1 //===- TypeDefGen.cpp - MLIR typeDef definitions generator ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // TypeDefGen uses the description of typeDefs to generate C++ definitions.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Support/LogicalResult.h"
14 #include "mlir/TableGen/CodeGenHelpers.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/TypeDef.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/TableGen/Error.h"
21 #include "llvm/TableGen/TableGenBackend.h"
22
23 #define DEBUG_TYPE "mlir-tblgen-typedefgen"
24
25 using namespace mlir;
26 using namespace mlir::tblgen;
27
28 static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
29 static llvm::cl::opt<std::string>
30 selectedDialect("typedefs-dialect",
31 llvm::cl::desc("Gen types for this dialect"),
32 llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
33
34 /// Find all the TypeDefs for the specified dialect. If no dialect specified and
35 /// can only find one dialect's types, use that.
findAllTypeDefs(const llvm::RecordKeeper & recordKeeper,SmallVectorImpl<TypeDef> & typeDefs)36 static void findAllTypeDefs(const llvm::RecordKeeper &recordKeeper,
37 SmallVectorImpl<TypeDef> &typeDefs) {
38 auto recDefs = recordKeeper.getAllDerivedDefinitions("TypeDef");
39 auto defs = llvm::map_range(
40 recDefs, [&](const llvm::Record *rec) { return TypeDef(rec); });
41 if (defs.empty())
42 return;
43
44 StringRef dialectName;
45 if (selectedDialect.getNumOccurrences() == 0) {
46 if (defs.empty())
47 return;
48
49 llvm::SmallSet<Dialect, 4> dialects;
50 for (const TypeDef typeDef : defs)
51 dialects.insert(typeDef.getDialect());
52 if (dialects.size() != 1)
53 llvm::PrintFatalError("TypeDefs belonging to more than one dialect. Must "
54 "select one via '--typedefs-dialect'");
55
56 dialectName = (*dialects.begin()).getName();
57 } else if (selectedDialect.getNumOccurrences() == 1) {
58 dialectName = selectedDialect.getValue();
59 } else {
60 llvm::PrintFatalError("Cannot select multiple dialects for which to "
61 "generate types via '--typedefs-dialect'.");
62 }
63
64 for (const TypeDef typeDef : defs)
65 if (typeDef.getDialect().getName().equals(dialectName))
66 typeDefs.push_back(typeDef);
67 }
68
69 namespace {
70
71 /// Pass an instance of this class to llvm::formatv() to emit a comma separated
72 /// list of parameters in the format by 'EmitFormat'.
73 class TypeParamCommaFormatter : public llvm::detail::format_adapter {
74 public:
75 /// Choose the output format
76 enum EmitFormat {
77 /// Emit "parameter1Type parameter1Name, parameter2Type parameter2Name,
78 /// [...]".
79 TypeNamePairs,
80
81 /// Emit "parameter1(parameter1), parameter2(parameter2), [...]".
82 TypeNameInitializer,
83
84 /// Emit "param1Name, param2Name, [...]".
85 JustParams,
86 };
87
TypeParamCommaFormatter(EmitFormat emitFormat,ArrayRef<TypeParameter> params,bool prependComma=true)88 TypeParamCommaFormatter(EmitFormat emitFormat, ArrayRef<TypeParameter> params,
89 bool prependComma = true)
90 : emitFormat(emitFormat), params(params), prependComma(prependComma) {}
91
92 /// llvm::formatv will call this function when using an instance as a
93 /// replacement value.
format(raw_ostream & os,StringRef options)94 void format(raw_ostream &os, StringRef options) override {
95 if (!params.empty() && prependComma)
96 os << ", ";
97
98 switch (emitFormat) {
99 case EmitFormat::TypeNamePairs:
100 interleaveComma(params, os,
101 [&](const TypeParameter &p) { emitTypeNamePair(p, os); });
102 break;
103 case EmitFormat::TypeNameInitializer:
104 interleaveComma(params, os, [&](const TypeParameter &p) {
105 emitTypeNameInitializer(p, os);
106 });
107 break;
108 case EmitFormat::JustParams:
109 interleaveComma(params, os,
110 [&](const TypeParameter &p) { os << p.getName(); });
111 break;
112 }
113 }
114
115 private:
116 // Emit "paramType paramName".
emitTypeNamePair(const TypeParameter & param,raw_ostream & os)117 static void emitTypeNamePair(const TypeParameter ¶m, raw_ostream &os) {
118 os << param.getCppType() << " " << param.getName();
119 }
120 // Emit "paramName(paramName)"
emitTypeNameInitializer(const TypeParameter & param,raw_ostream & os)121 void emitTypeNameInitializer(const TypeParameter ¶m, raw_ostream &os) {
122 os << param.getName() << "(" << param.getName() << ")";
123 }
124
125 EmitFormat emitFormat;
126 ArrayRef<TypeParameter> params;
127 bool prependComma;
128 };
129
130 } // end anonymous namespace
131
132 //===----------------------------------------------------------------------===//
133 // GEN: TypeDef declarations
134 //===----------------------------------------------------------------------===//
135
/// Print this above all the other declarations. Contains type declarations used
/// later on.
static const char *const typeDefDeclHeader = R"(
namespace mlir {
class DialectAsmParser;
class DialectAsmPrinter;
} // namespace mlir
)";

/// The code block for the start of a typeDef class declaration -- singleton
/// case (the typeDef has no parameters, so the generic TypeStorage is used).
///
/// {0}: The name of the typeDef class.
/// {1}: The name of the type base class.
static const char *const typeDefDeclSingletonBeginStr = R"(
class {0} : public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{
public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;

)";

/// The code block for the start of a typeDef class declaration -- parametric
/// case. Forward-declares the storage struct in its namespace and opens the
/// class.
///
/// {0}: The name of the typeDef class.
/// {1}: The name of the type base class.
/// {2}: The typeDef storage class namespace.
/// {3}: The storage class name.
///
/// NOTE(review): the brace opening `namespace {2}` is not doubled like the
/// other literal opening braces in these templates -- confirm llvm::formatv
/// keeps treating it as a literal before touching this string.
static const char *const typeDefDeclParametricBeginStr = R"(
namespace {2} {
struct {3};
} // end namespace {2}
class {0} : public ::mlir::Type::TypeBase<{0}, {1},
{2}::{3}> {{
public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;

)";

/// The snippet declaring the static `parse` and the `print` member functions.
/// It contains no replacement fields and is streamed verbatim.
static const char *const typeDefParsePrint = R"(
static ::mlir::Type parse(::mlir::MLIRContext *context,
::mlir::DialectAsmParser &parser);
void print(::mlir::DialectAsmPrinter &printer) const;
)";

/// The code block declaring verifyConstructionInvariants.
///
/// {0}: The name of the typeDef class (passed by the caller, not referenced by
///      the template text).
/// {1}: List of parameters, parameters style, with a leading comma.
static const char *const typeDefDeclVerifyStr = R"(
static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc{1});
)";
192
193 /// Emit the builders for the given type.
emitTypeBuilderDecls(const TypeDef & typeDef,raw_ostream & os,TypeParamCommaFormatter & paramTypes)194 static void emitTypeBuilderDecls(const TypeDef &typeDef, raw_ostream &os,
195 TypeParamCommaFormatter ¶mTypes) {
196 StringRef typeClass = typeDef.getCppClassName();
197 bool genCheckedMethods = typeDef.genVerifyInvariantsDecl();
198 if (!typeDef.skipDefaultBuilders()) {
199 os << llvm::formatv(
200 " static {0} get(::mlir::MLIRContext *context{1});\n", typeClass,
201 paramTypes);
202 if (genCheckedMethods) {
203 os << llvm::formatv(
204 " static {0} getChecked(::mlir::Location loc{1});\n", typeClass,
205 paramTypes);
206 }
207 }
208
209 // Generate the builders specified by the user.
210 for (const TypeBuilder &builder : typeDef.getBuilders()) {
211 std::string paramStr;
212 llvm::raw_string_ostream paramOS(paramStr);
213 llvm::interleaveComma(
214 builder.getParameters(), paramOS,
215 [&](const TypeBuilder::Parameter ¶m) {
216 // Note: TypeBuilder parameters are guaranteed to have names.
217 paramOS << param.getCppType() << " " << *param.getName();
218 if (Optional<StringRef> defaultParamValue = param.getDefaultValue())
219 paramOS << " = " << *defaultParamValue;
220 });
221 paramOS.flush();
222
223 // Generate the `get` variant of the builder.
224 os << " static " << typeClass << " get(";
225 if (!builder.hasInferredContextParameter()) {
226 os << "::mlir::MLIRContext *context";
227 if (!paramStr.empty())
228 os << ", ";
229 }
230 os << paramStr << ");\n";
231
232 // Generate the `getChecked` variant of the builder.
233 if (genCheckedMethods) {
234 os << " static " << typeClass << " getChecked(::mlir::Location loc";
235 if (!paramStr.empty())
236 os << ", " << paramStr;
237 os << ");\n";
238 }
239 }
240 }
241
242 /// Generate the declaration for the given typeDef class.
emitTypeDefDecl(const TypeDef & typeDef,raw_ostream & os)243 static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
244 SmallVector<TypeParameter, 4> params;
245 typeDef.getParameters(params);
246
247 // Emit the beginning string template: either the singleton or parametric
248 // template.
249 if (typeDef.getNumParameters() == 0)
250 os << formatv(typeDefDeclSingletonBeginStr, typeDef.getCppClassName(),
251 typeDef.getCppBaseClassName());
252 else
253 os << formatv(typeDefDeclParametricBeginStr, typeDef.getCppClassName(),
254 typeDef.getCppBaseClassName(), typeDef.getStorageNamespace(),
255 typeDef.getStorageClassName());
256
257 // Emit the extra declarations first in case there's a type definition in
258 // there.
259 if (Optional<StringRef> extraDecl = typeDef.getExtraDecls())
260 os << *extraDecl << "\n";
261
262 TypeParamCommaFormatter emitTypeNamePairsAfterComma(
263 TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params);
264 if (!params.empty()) {
265 emitTypeBuilderDecls(typeDef, os, emitTypeNamePairsAfterComma);
266
267 // Emit the verify invariants declaration.
268 if (typeDef.genVerifyInvariantsDecl())
269 os << llvm::formatv(typeDefDeclVerifyStr, typeDef.getCppClassName(),
270 emitTypeNamePairsAfterComma);
271 }
272
273 // Emit the mnenomic, if specified.
274 if (auto mnenomic = typeDef.getMnemonic()) {
275 os << " static ::llvm::StringRef getMnemonic() { return \"" << mnenomic
276 << "\"; }\n";
277
278 // If mnemonic specified, emit print/parse declarations.
279 if (typeDef.getParserCode() || typeDef.getPrinterCode() || !params.empty())
280 os << typeDefParsePrint;
281 }
282
283 if (typeDef.genAccessors()) {
284 SmallVector<TypeParameter, 4> parameters;
285 typeDef.getParameters(parameters);
286
287 for (TypeParameter ¶meter : parameters) {
288 SmallString<16> name = parameter.getName();
289 name[0] = llvm::toUpper(name[0]);
290 os << formatv(" {0} get{1}() const;\n", parameter.getCppType(), name);
291 }
292 }
293
294 // End the typeDef decl.
295 os << " };\n";
296 }
297
298 /// Main entry point for decls.
emitTypeDefDecls(const llvm::RecordKeeper & recordKeeper,raw_ostream & os)299 static bool emitTypeDefDecls(const llvm::RecordKeeper &recordKeeper,
300 raw_ostream &os) {
301 emitSourceFileHeader("TypeDef Declarations", os);
302
303 SmallVector<TypeDef, 16> typeDefs;
304 findAllTypeDefs(recordKeeper, typeDefs);
305
306 IfDefScope scope("GET_TYPEDEF_CLASSES", os);
307
308 // Output the common "header".
309 os << typeDefDeclHeader;
310
311 if (!typeDefs.empty()) {
312 NamespaceEmitter nsEmitter(os, typeDefs.begin()->getDialect());
313
314 // Declare all the type classes first (in case they reference each other).
315 for (const TypeDef &typeDef : typeDefs)
316 os << " class " << typeDef.getCppClassName() << ";\n";
317
318 // Declare all the typedefs.
319 for (const TypeDef &typeDef : typeDefs)
320 emitTypeDefDecl(typeDef, os);
321 }
322
323 return false;
324 }
325
326 //===----------------------------------------------------------------------===//
327 // GEN: TypeDef list
328 //===----------------------------------------------------------------------===//
329
emitTypeDefList(SmallVectorImpl<TypeDef> & typeDefs,raw_ostream & os)330 static void emitTypeDefList(SmallVectorImpl<TypeDef> &typeDefs,
331 raw_ostream &os) {
332 IfDefScope scope("GET_TYPEDEF_LIST", os);
333 for (auto *i = typeDefs.begin(); i != typeDefs.end(); i++) {
334 os << i->getDialect().getCppNamespace() << "::" << i->getCppClassName();
335 if (i < typeDefs.end() - 1)
336 os << ",\n";
337 else
338 os << "\n";
339 }
340 }
341
342 //===----------------------------------------------------------------------===//
343 // GEN: TypeDef definitions
344 //===----------------------------------------------------------------------===//
345
/// Beginning of storage class: opens the namespace and the struct, and emits
/// the constructor, the KeyTy alias and the key comparison operator.
/// {0}: Storage class namespace.
/// {1}: Storage class c++ name.
/// {2}: Constructor parameters ("type name" pairs).
/// {3}: Parameter initializer string.
/// {4}: Parameter name list.
/// {5}: Parameter types.
static const char *const typeDefStorageClassBegin = R"(
namespace {0} {{
struct {1} : public ::mlir::TypeStorage {{
{1} ({2})
: {3} {{ }

/// The hash key for this storage is a pair of the integer and type params.
using KeyTy = std::tuple<{5}>;

/// Define the comparison function for the key type.
bool operator==(const KeyTy &key) const {{
return key == KeyTy({4});
}
)";

/// The storage class' constructor template: the signature and opening brace of
/// the static `construct` method.
/// {0}: storage class name.
static const char *const typeDefStorageClassConstructorBegin = R"(
/// Define a construction method for creating a new instance of this storage.
static {0} *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {{
)";

/// The storage class' constructor return template: placement-news the storage
/// in memory obtained from the allocator and closes `construct`.
/// {0}: storage class name.
/// {1}: list of parameters.
static const char *const typeDefStorageClassConstructorReturn = R"(
return new (allocator.allocate<{0}>())
{0}({1});
}
)";
383
384 /// Use tgfmt to emit custom allocation code for each parameter, if necessary.
emitParameterAllocationCode(TypeDef & typeDef,raw_ostream & os)385 static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) {
386 SmallVector<TypeParameter, 4> parameters;
387 typeDef.getParameters(parameters);
388 auto fmtCtxt = FmtContext().addSubst("_allocator", "allocator");
389 for (TypeParameter ¶meter : parameters) {
390 auto allocCode = parameter.getAllocator();
391 if (allocCode) {
392 fmtCtxt.withSelf(parameter.getName());
393 fmtCtxt.addSubst("_dst", parameter.getName());
394 os << " " << tgfmt(*allocCode, &fmtCtxt) << "\n";
395 }
396 }
397 }
398
399 /// Emit the storage class code for type 'typeDef'.
400 /// This includes (in-order):
401 /// 1) typeDefStorageClassBegin, which includes:
402 /// - The class constructor.
403 /// - The KeyTy definition.
404 /// - The equality (==) operator.
405 /// 2) The hashKey method.
406 /// 3) The construct method.
407 /// 4) The list of parameters as the storage class member variables.
emitStorageClass(TypeDef typeDef,raw_ostream & os)408 static void emitStorageClass(TypeDef typeDef, raw_ostream &os) {
409 SmallVector<TypeParameter, 4> parameters;
410 typeDef.getParameters(parameters);
411
412 // Initialize a bunch of variables to be used later on.
413 auto parameterNames = map_range(
414 parameters, [](TypeParameter parameter) { return parameter.getName(); });
415 auto parameterTypes = map_range(parameters, [](TypeParameter parameter) {
416 return parameter.getCppType();
417 });
418 auto parameterList = join(parameterNames, ", ");
419 auto parameterTypeList = join(parameterTypes, ", ");
420
421 // 1) Emit most of the storage class up until the hashKey body.
422 os << formatv(typeDefStorageClassBegin, typeDef.getStorageNamespace(),
423 typeDef.getStorageClassName(),
424 TypeParamCommaFormatter(
425 TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
426 parameters, /*prependComma=*/false),
427 TypeParamCommaFormatter(
428 TypeParamCommaFormatter::EmitFormat::TypeNameInitializer,
429 parameters, /*prependComma=*/false),
430 parameterList, parameterTypeList);
431
432 // 2) Emit the haskKey method.
433 os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
434 // Extract each parameter from the key.
435 for (size_t i = 0, e = parameters.size(); i < e; ++i)
436 os << llvm::formatv(" const auto &{0} = std::get<{1}>(key);\n",
437 parameters[i].getName(), i);
438 // Then combine them all. This requires all the parameters types to have a
439 // hash_value defined.
440 os << llvm::formatv(
441 " return ::llvm::hash_combine({0});\n }\n",
442 TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
443 parameters, /* prependComma */ false));
444
445 // 3) Emit the construct method.
446 if (typeDef.hasStorageCustomConstructor()) {
447 // If user wants to build the storage constructor themselves, declare it
448 // here and then they can write the definition elsewhere.
449 os << " static " << typeDef.getStorageClassName()
450 << " *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy "
451 "&key);\n";
452 } else {
453 // If not, autogenerate one.
454
455 // First, unbox the parameters.
456 os << formatv(typeDefStorageClassConstructorBegin,
457 typeDef.getStorageClassName());
458 for (size_t i = 0; i < parameters.size(); ++i) {
459 os << formatv(" auto {0} = std::get<{1}>(key);\n",
460 parameters[i].getName(), i);
461 }
462 // Second, reassign the parameter variables with allocation code, if it's
463 // specified.
464 emitParameterAllocationCode(typeDef, os);
465
466 // Last, return an allocated copy.
467 os << formatv(typeDefStorageClassConstructorReturn,
468 typeDef.getStorageClassName(), parameterList);
469 }
470
471 // 4) Emit the parameters as storage class members.
472 for (auto parameter : parameters) {
473 os << " " << parameter.getCppType() << " " << parameter.getName()
474 << ";\n";
475 }
476 os << " };\n";
477
478 os << "} // namespace " << typeDef.getStorageNamespace() << "\n";
479 }
480
481 /// Emit the parser and printer for a particular type, if they're specified.
emitParserPrinter(TypeDef typeDef,raw_ostream & os)482 void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
483 // Emit the printer code, if specified.
484 if (auto printerCode = typeDef.getPrinterCode()) {
485 // Both the mnenomic and printerCode must be defined (for parity with
486 // parserCode).
487 os << "void " << typeDef.getCppClassName()
488 << "::print(::mlir::DialectAsmPrinter &printer) const {\n";
489 if (*printerCode == "") {
490 // If no code specified, emit error.
491 PrintFatalError(typeDef.getLoc(),
492 typeDef.getName() +
493 ": printer (if specified) must have non-empty code");
494 }
495 auto fmtCtxt = FmtContext().addSubst("_printer", "printer");
496 os << tgfmt(*printerCode, &fmtCtxt) << "\n}\n";
497 }
498
499 // emit a parser, if specified.
500 if (auto parserCode = typeDef.getParserCode()) {
501 // The mnenomic must be defined so the dispatcher knows how to dispatch.
502 os << "::mlir::Type " << typeDef.getCppClassName()
503 << "::parse(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &"
504 "parser) "
505 "{\n";
506 if (*parserCode == "") {
507 // if no code specified, emit error.
508 PrintFatalError(typeDef.getLoc(),
509 typeDef.getName() +
510 ": parser (if specified) must have non-empty code");
511 }
512 auto fmtCtxt =
513 FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "context");
514 os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n";
515 }
516 }
517
518 /// Emit the builders for the given type.
emitTypeBuilderDefs(const TypeDef & typeDef,raw_ostream & os,ArrayRef<TypeParameter> typeDefParams)519 static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os,
520 ArrayRef<TypeParameter> typeDefParams) {
521 bool genCheckedMethods = typeDef.genVerifyInvariantsDecl();
522 StringRef typeClass = typeDef.getCppClassName();
523 if (!typeDef.skipDefaultBuilders()) {
524 os << llvm::formatv(
525 "{0} {0}::get(::mlir::MLIRContext *context{1}) {{\n"
526 " return Base::get(context{2});\n}\n",
527 typeClass,
528 TypeParamCommaFormatter(
529 TypeParamCommaFormatter::EmitFormat::TypeNamePairs, typeDefParams),
530 TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
531 typeDefParams));
532 if (genCheckedMethods) {
533 os << llvm::formatv(
534 "{0} {0}::getChecked(::mlir::Location loc{1}) {{\n"
535 " return Base::getChecked(loc{2});\n}\n",
536 typeClass,
537 TypeParamCommaFormatter(
538 TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
539 typeDefParams),
540 TypeParamCommaFormatter(
541 TypeParamCommaFormatter::EmitFormat::JustParams, typeDefParams));
542 }
543 }
544
545 // Generate the builders specified by the user.
546 auto builderFmtCtx = FmtContext().addSubst("_ctxt", "context");
547 auto checkedBuilderFmtCtx = FmtContext()
548 .addSubst("_loc", "loc")
549 .addSubst("_ctxt", "loc.getContext()");
550 for (const TypeBuilder &builder : typeDef.getBuilders()) {
551 Optional<StringRef> body = builder.getBody();
552 Optional<StringRef> checkedBody =
553 genCheckedMethods ? builder.getCheckedBody() : llvm::None;
554 if (!body && !checkedBody)
555 continue;
556 std::string paramStr;
557 llvm::raw_string_ostream paramOS(paramStr);
558 llvm::interleaveComma(builder.getParameters(), paramOS,
559 [&](const TypeBuilder::Parameter ¶m) {
560 // Note: TypeBuilder parameters are guaranteed to
561 // have names.
562 paramOS << param.getCppType() << " "
563 << *param.getName();
564 });
565 paramOS.flush();
566
567 // Emit the `get` variant of the builder.
568 if (body) {
569 os << llvm::formatv("{0} {0}::get(", typeClass);
570 if (!builder.hasInferredContextParameter()) {
571 os << "::mlir::MLIRContext *context";
572 if (!paramStr.empty())
573 os << ", ";
574 os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr,
575 tgfmt(*body, &builderFmtCtx).str());
576 } else {
577 os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, *body);
578 }
579 }
580
581 // Emit the `getChecked` variant of the builder.
582 if (checkedBody) {
583 os << llvm::formatv("{0} {0}::getChecked(::mlir::Location loc",
584 typeClass);
585 if (!paramStr.empty())
586 os << ", " << paramStr;
587 os << llvm::formatv(") {{\n {0};\n}\n",
588 tgfmt(*checkedBody, &checkedBuilderFmtCtx));
589 }
590 }
591 }
592
593 /// Print all the typedef-specific definition code.
emitTypeDefDef(const TypeDef & typeDef,raw_ostream & os)594 static void emitTypeDefDef(const TypeDef &typeDef, raw_ostream &os) {
595 NamespaceEmitter ns(os, typeDef.getDialect());
596
597 SmallVector<TypeParameter, 4> parameters;
598 typeDef.getParameters(parameters);
599 if (!parameters.empty()) {
600 // Emit the storage class, if requested and necessary.
601 if (typeDef.genStorageClass())
602 emitStorageClass(typeDef, os);
603
604 // Emit the builders for this type.
605 emitTypeBuilderDefs(typeDef, os, parameters);
606
607 // Generate accessor definitions only if we also generate the storage class.
608 // Otherwise, let the user define the exact accessor definition.
609 if (typeDef.genAccessors() && typeDef.genStorageClass()) {
610 // Emit the parameter accessors.
611 for (const TypeParameter ¶meter : parameters) {
612 SmallString<16> name = parameter.getName();
613 name[0] = llvm::toUpper(name[0]);
614 os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n",
615 parameter.getCppType(), name, parameter.getName(),
616 typeDef.getCppClassName());
617 }
618 }
619 }
620
621 // If mnemonic is specified maybe print definitions for the parser and printer
622 // code, if they're specified.
623 if (typeDef.getMnemonic())
624 emitParserPrinter(typeDef, os);
625 }
626
627 /// Emit the dialect printer/parser dispatcher. User's code should call these
628 /// functions from their dialect's print/parse methods.
emitParsePrintDispatch(ArrayRef<TypeDef> types,raw_ostream & os)629 static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
630 if (llvm::none_of(types, [](const TypeDef &type) {
631 return type.getMnemonic().hasValue();
632 })) {
633 return;
634 }
635
636 // The parser dispatch is just a list of if-elses, matching on the
637 // mnemonic and calling the class's parse function.
638 os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext *"
639 "context, ::mlir::DialectAsmParser &parser, "
640 "::llvm::StringRef mnemonic) {\n";
641 for (const TypeDef &type : types) {
642 if (type.getMnemonic()) {
643 os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return "
644 "{0}::{1}::",
645 type.getDialect().getCppNamespace(),
646 type.getCppClassName());
647
648 // If the type has no parameters and no parser code, just invoke a normal
649 // `get`.
650 if (type.getNumParameters() == 0 && !type.getParserCode())
651 os << "get(context);\n";
652 else
653 os << "parse(context, parser);\n";
654 }
655 }
656 os << " return ::mlir::Type();\n";
657 os << "}\n\n";
658
659 // The printer dispatch uses llvm::TypeSwitch to find and call the correct
660 // printer.
661 os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type "
662 "type, "
663 "::mlir::DialectAsmPrinter &printer) {\n"
664 << " return ::llvm::TypeSwitch<::mlir::Type, "
665 "::mlir::LogicalResult>(type)\n";
666 for (const TypeDef &type : types) {
667 if (Optional<StringRef> mnemonic = type.getMnemonic()) {
668 StringRef cppNamespace = type.getDialect().getCppNamespace();
669 StringRef cppClassName = type.getCppClassName();
670 os << formatv(" .Case<{0}::{1}>([&]({0}::{1} t) {{\n ",
671 cppNamespace, cppClassName);
672
673 // If the type has no parameters and no printer code, just print the
674 // mnemonic.
675 if (type.getNumParameters() == 0 && !type.getPrinterCode())
676 os << formatv("printer << {0}::{1}::getMnemonic();", cppNamespace,
677 cppClassName);
678 else
679 os << "t.print(printer);";
680 os << "\n return ::mlir::success();\n })\n";
681 }
682 }
683 os << " .Default([](::mlir::Type) { return ::mlir::failure(); });\n"
684 << "}\n\n";
685 }
686
687 /// Entry point for typedef definitions.
emitTypeDefDefs(const llvm::RecordKeeper & recordKeeper,raw_ostream & os)688 static bool emitTypeDefDefs(const llvm::RecordKeeper &recordKeeper,
689 raw_ostream &os) {
690 emitSourceFileHeader("TypeDef Definitions", os);
691
692 SmallVector<TypeDef, 16> typeDefs;
693 findAllTypeDefs(recordKeeper, typeDefs);
694 emitTypeDefList(typeDefs, os);
695
696 IfDefScope scope("GET_TYPEDEF_CLASSES", os);
697 emitParsePrintDispatch(typeDefs, os);
698 for (const TypeDef &typeDef : typeDefs)
699 emitTypeDefDef(typeDef, os);
700
701 return false;
702 }
703
704 //===----------------------------------------------------------------------===//
705 // GEN: TypeDef registration hooks
706 //===----------------------------------------------------------------------===//
707
708 static mlir::GenRegistration
709 genTypeDefDefs("gen-typedef-defs", "Generate TypeDef definitions",
__anond31b88460b02(const llvm::RecordKeeper &records, raw_ostream &os) 710 [](const llvm::RecordKeeper &records, raw_ostream &os) {
711 return emitTypeDefDefs(records, os);
712 });
713
714 static mlir::GenRegistration
715 genTypeDefDecls("gen-typedef-decls", "Generate TypeDef declarations",
__anond31b88460c02(const llvm::RecordKeeper &records, raw_ostream &os) 716 [](const llvm::RecordKeeper &records, raw_ostream &os) {
717 return emitTypeDefDecls(records, os);
718 });
719