1 //===- EnumsGen.cpp - MLIR enum utility generator -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // EnumsGen generates common utility functions for enums.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/TableGen/Attribute.h"
14 #include "mlir/TableGen/GenInfo.h"
15 #include "llvm/ADT/SmallVector.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/Support/raw_ostream.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 #include "llvm/TableGen/TableGenBackend.h"
22
23 using llvm::formatv;
24 using llvm::isDigit;
25 using llvm::raw_ostream;
26 using llvm::Record;
27 using llvm::RecordKeeper;
28 using llvm::StringRef;
29 using mlir::tblgen::EnumAttr;
30 using mlir::tblgen::EnumAttrCase;
31
makeIdentifier(StringRef str)32 static std::string makeIdentifier(StringRef str) {
33 if (!str.empty() && isDigit(static_cast<unsigned char>(str.front()))) {
34 std::string newStr = std::string("_") + str.str();
35 return newStr;
36 }
37 return str.str();
38 }
39
emitEnumClass(const Record & enumDef,StringRef enumName,StringRef underlyingType,StringRef description,const std::vector<EnumAttrCase> & enumerants,raw_ostream & os)40 static void emitEnumClass(const Record &enumDef, StringRef enumName,
41 StringRef underlyingType, StringRef description,
42 const std::vector<EnumAttrCase> &enumerants,
43 raw_ostream &os) {
44 os << "// " << description << "\n";
45 os << "enum class " << enumName;
46
47 if (!underlyingType.empty())
48 os << " : " << underlyingType;
49 os << " {\n";
50
51 for (const auto &enumerant : enumerants) {
52 auto symbol = makeIdentifier(enumerant.getSymbol());
53 auto value = enumerant.getValue();
54 if (value >= 0) {
55 os << formatv(" {0} = {1},\n", symbol, value);
56 } else {
57 os << formatv(" {0},\n", symbol);
58 }
59 }
60 os << "};\n\n";
61 }
62
emitDenseMapInfo(StringRef enumName,std::string underlyingType,StringRef cppNamespace,raw_ostream & os)63 static void emitDenseMapInfo(StringRef enumName, std::string underlyingType,
64 StringRef cppNamespace, raw_ostream &os) {
65 std::string qualName =
66 std::string(formatv("{0}::{1}", cppNamespace, enumName));
67 if (underlyingType.empty())
68 underlyingType =
69 std::string(formatv("std::underlying_type<{0}>::type", qualName));
70
71 const char *const mapInfo = R"(
72 namespace llvm {
73 template<> struct DenseMapInfo<{0}> {{
74 using StorageInfo = ::llvm::DenseMapInfo<{1}>;
75
76 static inline {0} getEmptyKey() {{
77 return static_cast<{0}>(StorageInfo::getEmptyKey());
78 }
79
80 static inline {0} getTombstoneKey() {{
81 return static_cast<{0}>(StorageInfo::getTombstoneKey());
82 }
83
84 static unsigned getHashValue(const {0} &val) {{
85 return StorageInfo::getHashValue(static_cast<{1}>(val));
86 }
87
88 static bool isEqual(const {0} &lhs, const {0} &rhs) {{
89 return lhs == rhs;
90 }
91 };
92 })";
93 os << formatv(mapInfo, qualName, underlyingType);
94 os << "\n\n";
95 }
96
emitMaxValueFn(const Record & enumDef,raw_ostream & os)97 static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
98 EnumAttr enumAttr(enumDef);
99 StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName();
100 auto enumerants = enumAttr.getAllCases();
101
102 unsigned maxEnumVal = 0;
103 for (const auto &enumerant : enumerants) {
104 int64_t value = enumerant.getValue();
105 // Avoid generating the max value function if there is an enumerant without
106 // explicit value.
107 if (value < 0)
108 return;
109
110 maxEnumVal = std::max(maxEnumVal, static_cast<unsigned>(value));
111 }
112
113 // Emit the function to return the max enum value
114 os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName);
115 os << formatv(" return {0};\n", maxEnumVal);
116 os << "}\n\n";
117 }
118
119 // Returns the EnumAttrCase whose value is zero if exists; returns llvm::None
120 // otherwise.
121 static llvm::Optional<EnumAttrCase>
getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases)122 getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
123 for (auto attrCase : cases) {
124 if (attrCase.getValue() == 0)
125 return attrCase;
126 }
127 return llvm::None;
128 }
129
130 // Emits the following inline function for bit enums:
131 //
132 // inline <enum-type> operator|(<enum-type> a, <enum-type> b);
133 // inline <enum-type> operator&(<enum-type> a, <enum-type> b);
134 // inline <enum-type> bitEnumContains(<enum-type> a, <enum-type> b);
emitOperators(const Record & enumDef,raw_ostream & os)135 static void emitOperators(const Record &enumDef, raw_ostream &os) {
136 EnumAttr enumAttr(enumDef);
137 StringRef enumName = enumAttr.getEnumClassName();
138 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
139 os << formatv("inline {0} operator|({0} lhs, {0} rhs) {{\n", enumName)
140 << formatv(" return static_cast<{0}>("
141 "static_cast<{1}>(lhs) | static_cast<{1}>(rhs));\n",
142 enumName, underlyingType)
143 << "}\n";
144 os << formatv("inline {0} operator&({0} lhs, {0} rhs) {{\n", enumName)
145 << formatv(" return static_cast<{0}>("
146 "static_cast<{1}>(lhs) & static_cast<{1}>(rhs));\n",
147 enumName, underlyingType)
148 << "}\n";
149 os << formatv(
150 "inline bool bitEnumContains({0} bits, {0} bit) {{\n"
151 " return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;\n",
152 enumName, underlyingType)
153 << "}\n";
154 }
155
emitSymToStrFnForIntEnum(const Record & enumDef,raw_ostream & os)156 static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
157 EnumAttr enumAttr(enumDef);
158 StringRef enumName = enumAttr.getEnumClassName();
159 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
160 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
161 auto enumerants = enumAttr.getAllCases();
162
163 os << formatv("{2} {1}({0} val) {{\n", enumName, symToStrFnName,
164 symToStrFnRetType);
165 os << " switch (val) {\n";
166 for (const auto &enumerant : enumerants) {
167 auto symbol = enumerant.getSymbol();
168 auto str = enumerant.getStr();
169 os << formatv(" case {0}::{1}: return \"{2}\";\n", enumName,
170 makeIdentifier(symbol), str);
171 }
172 os << " }\n";
173 os << " return \"\";\n";
174 os << "}\n\n";
175 }
176
emitSymToStrFnForBitEnum(const Record & enumDef,raw_ostream & os)177 static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
178 EnumAttr enumAttr(enumDef);
179 StringRef enumName = enumAttr.getEnumClassName();
180 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
181 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
182 StringRef separator = enumDef.getValueAsString("separator");
183 auto enumerants = enumAttr.getAllCases();
184 auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
185
186 os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName,
187 symToStrFnRetType);
188
189 os << formatv(" auto val = static_cast<{0}>(symbol);\n",
190 enumAttr.getUnderlyingType());
191 if (allBitsUnsetCase) {
192 os << " // Special case for all bits unset.\n";
193 os << formatv(" if (val == 0) return \"{0}\";\n\n",
194 allBitsUnsetCase->getSymbol());
195 }
196 os << " ::llvm::SmallVector<::llvm::StringRef, 2> strs;\n";
197 for (const auto &enumerant : enumerants) {
198 // Skip the special enumerant for None.
199 if (auto val = enumerant.getValue())
200 os << formatv(" if ({0}u & val) {{ strs.push_back(\"{1}\"); "
201 "val &= ~{0}u; }\n",
202 val, enumerant.getSymbol());
203 }
204 // If we have unknown bit set, return an empty string to signal errors.
205 os << "\n if (val) return \"\";\n";
206 os << formatv(" return ::llvm::join(strs, \"{0}\");\n", separator);
207
208 os << "}\n\n";
209 }
210
emitStrToSymFnForIntEnum(const Record & enumDef,raw_ostream & os)211 static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
212 EnumAttr enumAttr(enumDef);
213 StringRef enumName = enumAttr.getEnumClassName();
214 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
215 auto enumerants = enumAttr.getAllCases();
216
217 os << formatv("::llvm::Optional<{0}> {1}(::llvm::StringRef str) {{\n",
218 enumName, strToSymFnName);
219 os << formatv(" return ::llvm::StringSwitch<::llvm::Optional<{0}>>(str)\n",
220 enumName);
221 for (const auto &enumerant : enumerants) {
222 auto symbol = enumerant.getSymbol();
223 auto str = enumerant.getStr();
224 os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, str,
225 makeIdentifier(symbol));
226 }
227 os << " .Default(::llvm::None);\n";
228 os << "}\n";
229 }
230
emitStrToSymFnForBitEnum(const Record & enumDef,raw_ostream & os)231 static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
232 EnumAttr enumAttr(enumDef);
233 StringRef enumName = enumAttr.getEnumClassName();
234 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
235 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
236 StringRef separator = enumDef.getValueAsString("separator");
237 auto enumerants = enumAttr.getAllCases();
238 auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
239
240 os << formatv("::llvm::Optional<{0}> {1}(::llvm::StringRef str) {{\n",
241 enumName, strToSymFnName);
242
243 if (allBitsUnsetCase) {
244 os << " // Special case for all bits unset.\n";
245 StringRef caseSymbol = allBitsUnsetCase->getSymbol();
246 os << formatv(" if (str == \"{1}\") return {0}::{2};\n\n", enumName,
247 caseSymbol, makeIdentifier(caseSymbol));
248 }
249
250 // Split the string to get symbols for all the bits.
251 os << " ::llvm::SmallVector<::llvm::StringRef, 2> symbols;\n";
252 os << formatv(" str.split(symbols, \"{0}\");\n\n", separator);
253
254 os << formatv(" {0} val = 0;\n", underlyingType);
255 os << " for (auto symbol : symbols) {\n";
256
257 // Convert each symbol to the bit ordinal and set the corresponding bit.
258 os << formatv(
259 " auto bit = llvm::StringSwitch<::llvm::Optional<{0}>>(symbol)\n",
260 underlyingType);
261 for (const auto &enumerant : enumerants) {
262 // Skip the special enumerant for None.
263 if (auto val = enumerant.getValue())
264 os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getSymbol(),
265 val);
266 }
267 os.indent(6) << ".Default(::llvm::None);\n";
268
269 os << " if (bit) { val |= *bit; } else { return ::llvm::None; }\n";
270 os << " }\n";
271
272 os << formatv(" return static_cast<{0}>(val);\n", enumName);
273 os << "}\n\n";
274 }
275
emitUnderlyingToSymFnForIntEnum(const Record & enumDef,raw_ostream & os)276 static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
277 raw_ostream &os) {
278 EnumAttr enumAttr(enumDef);
279 StringRef enumName = enumAttr.getEnumClassName();
280 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
281 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
282 auto enumerants = enumAttr.getAllCases();
283
284 // Avoid generating the underlying value to symbol conversion function if
285 // there is an enumerant without explicit value.
286 if (llvm::any_of(enumerants, [](EnumAttrCase enumerant) {
287 return enumerant.getValue() < 0;
288 }))
289 return;
290
291 os << formatv("::llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
292 underlyingToSymFnName,
293 underlyingType.empty() ? std::string("unsigned")
294 : underlyingType)
295 << " switch (value) {\n";
296 for (const auto &enumerant : enumerants) {
297 auto symbol = enumerant.getSymbol();
298 auto value = enumerant.getValue();
299 os << formatv(" case {0}: return {1}::{2};\n", value, enumName,
300 makeIdentifier(symbol));
301 }
302 os << " default: return ::llvm::None;\n"
303 << " }\n"
304 << "}\n\n";
305 }
306
emitUnderlyingToSymFnForBitEnum(const Record & enumDef,raw_ostream & os)307 static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
308 raw_ostream &os) {
309 EnumAttr enumAttr(enumDef);
310 StringRef enumName = enumAttr.getEnumClassName();
311 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
312 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
313 auto enumerants = enumAttr.getAllCases();
314 auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
315
316 os << formatv("::llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
317 underlyingToSymFnName, underlyingType);
318 if (allBitsUnsetCase) {
319 os << " // Special case for all bits unset.\n";
320 os << formatv(" if (value == 0) return {0}::{1};\n\n", enumName,
321 makeIdentifier(allBitsUnsetCase->getSymbol()));
322 }
323 llvm::SmallVector<std::string, 8> values;
324 for (const auto &enumerant : enumerants) {
325 if (auto val = enumerant.getValue())
326 values.push_back(std::string(formatv("{0}u", val)));
327 }
328 os << formatv(" if (value & ~({0})) return llvm::None;\n",
329 llvm::join(values, " | "));
330 os << formatv(" return static_cast<{0}>(value);\n", enumName);
331 os << "}\n";
332 }
333
emitEnumDecl(const Record & enumDef,raw_ostream & os)334 static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
335 EnumAttr enumAttr(enumDef);
336 StringRef enumName = enumAttr.getEnumClassName();
337 StringRef cppNamespace = enumAttr.getCppNamespace();
338 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
339 StringRef description = enumAttr.getDescription();
340 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
341 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
342 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
343 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
344 auto enumerants = enumAttr.getAllCases();
345
346 llvm::SmallVector<StringRef, 2> namespaces;
347 llvm::SplitString(cppNamespace, namespaces, "::");
348
349 for (auto ns : namespaces)
350 os << "namespace " << ns << " {\n";
351
352 // Emit the enum class definition
353 emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
354
355 // Emit conversion function declarations
356 if (llvm::all_of(enumerants, [](EnumAttrCase enumerant) {
357 return enumerant.getValue() >= 0;
358 })) {
359 os << formatv(
360 "::llvm::Optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName,
361 underlyingType.empty() ? std::string("unsigned") : underlyingType);
362 }
363 os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, symToStrFnRetType);
364 os << formatv("::llvm::Optional<{0}> {1}(::llvm::StringRef);\n", enumName,
365 strToSymFnName);
366
367 if (enumAttr.isBitEnum()) {
368 emitOperators(enumDef, os);
369 } else {
370 emitMaxValueFn(enumDef, os);
371 }
372
373 // Generate a generic `stringifyEnum` function that forwards to the method
374 // specified by the user.
375 const char *const stringifyEnumStr = R"(
376 inline {0} stringifyEnum({1} enumValue) {{
377 return {2}(enumValue);
378 }
379 )";
380 os << formatv(stringifyEnumStr, symToStrFnRetType, enumName, symToStrFnName);
381
382 // Generate a generic `symbolizeEnum` function that forwards to the method
383 // specified by the user.
384 const char *const symbolizeEnumStr = R"(
385 template <typename EnumType>
386 ::llvm::Optional<EnumType> symbolizeEnum(::llvm::StringRef);
387
388 template <>
389 inline ::llvm::Optional<{0}> symbolizeEnum<{0}>(::llvm::StringRef str) {
390 return {1}(str);
391 }
392 )";
393 os << formatv(symbolizeEnumStr, enumName, strToSymFnName);
394
395 for (auto ns : llvm::reverse(namespaces))
396 os << "} // namespace " << ns << "\n";
397
398 // Emit DenseMapInfo for this enum class
399 emitDenseMapInfo(enumName, underlyingType, cppNamespace, os);
400 }
401
emitEnumDecls(const RecordKeeper & recordKeeper,raw_ostream & os)402 static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
403 llvm::emitSourceFileHeader("Enum Utility Declarations", os);
404
405 auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
406 for (const auto *def : defs)
407 emitEnumDecl(*def, os);
408
409 return false;
410 }
411
emitEnumDef(const Record & enumDef,raw_ostream & os)412 static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
413 EnumAttr enumAttr(enumDef);
414 StringRef cppNamespace = enumAttr.getCppNamespace();
415
416 llvm::SmallVector<StringRef, 2> namespaces;
417 llvm::SplitString(cppNamespace, namespaces, "::");
418
419 for (auto ns : namespaces)
420 os << "namespace " << ns << " {\n";
421
422 if (enumAttr.isBitEnum()) {
423 emitSymToStrFnForBitEnum(enumDef, os);
424 emitStrToSymFnForBitEnum(enumDef, os);
425 emitUnderlyingToSymFnForBitEnum(enumDef, os);
426 } else {
427 emitSymToStrFnForIntEnum(enumDef, os);
428 emitStrToSymFnForIntEnum(enumDef, os);
429 emitUnderlyingToSymFnForIntEnum(enumDef, os);
430 }
431
432 for (auto ns : llvm::reverse(namespaces))
433 os << "} // namespace " << ns << "\n";
434 os << "\n";
435 }
436
emitEnumDefs(const RecordKeeper & recordKeeper,raw_ostream & os)437 static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
438 llvm::emitSourceFileHeader("Enum Utility Definitions", os);
439
440 auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
441 for (const auto *def : defs)
442 emitEnumDef(*def, os);
443
444 return false;
445 }
446
447 // Registers the enum utility generator to mlir-tblgen.
448 static mlir::GenRegistration
449 genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
__anon51e48abf0302(const RecordKeeper &records, raw_ostream &os) 450 [](const RecordKeeper &records, raw_ostream &os) {
451 return emitEnumDecls(records, os);
452 });
453
454 // Registers the enum utility generator to mlir-tblgen.
455 static mlir::GenRegistration
456 genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
__anon51e48abf0402(const RecordKeeper &records, raw_ostream &os) 457 [](const RecordKeeper &records, raw_ostream &os) {
458 return emitEnumDefs(records, os);
459 });
460