1 //===- EnumsGen.cpp - MLIR enum utility generator -------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // EnumsGen generates common utility functions for enums.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/TableGen/Attribute.h"
14 #include "mlir/TableGen/GenInfo.h"
15 #include "llvm/ADT/SmallVector.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/Support/raw_ostream.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 #include "llvm/TableGen/TableGenBackend.h"
22
23 using llvm::formatv;
24 using llvm::isDigit;
25 using llvm::raw_ostream;
26 using llvm::Record;
27 using llvm::RecordKeeper;
28 using llvm::StringRef;
29 using mlir::tblgen::EnumAttr;
30 using mlir::tblgen::EnumAttrCase;
31
makeIdentifier(StringRef str)32 static std::string makeIdentifier(StringRef str) {
33 if (!str.empty() && isDigit(static_cast<unsigned char>(str.front()))) {
34 std::string newStr = std::string("_") + str.str();
35 return newStr;
36 }
37 return str.str();
38 }
39
emitEnumClass(const Record & enumDef,StringRef enumName,StringRef underlyingType,StringRef description,const std::vector<EnumAttrCase> & enumerants,raw_ostream & os)40 static void emitEnumClass(const Record &enumDef, StringRef enumName,
41 StringRef underlyingType, StringRef description,
42 const std::vector<EnumAttrCase> &enumerants,
43 raw_ostream &os) {
44 os << "// " << description << "\n";
45 os << "enum class " << enumName;
46
47 if (!underlyingType.empty())
48 os << " : " << underlyingType;
49 os << " {\n";
50
51 for (const auto &enumerant : enumerants) {
52 auto symbol = makeIdentifier(enumerant.getSymbol());
53 auto value = enumerant.getValue();
54 if (value >= 0) {
55 os << formatv(" {0} = {1},\n", symbol, value);
56 } else {
57 os << formatv(" {0},\n", symbol);
58 }
59 }
60 os << "};\n\n";
61 }
62
emitDenseMapInfo(StringRef enumName,std::string underlyingType,StringRef cppNamespace,raw_ostream & os)63 static void emitDenseMapInfo(StringRef enumName, std::string underlyingType,
64 StringRef cppNamespace, raw_ostream &os) {
65 std::string qualName = formatv("{0}::{1}", cppNamespace, enumName);
66 if (underlyingType.empty())
67 underlyingType = formatv("std::underlying_type<{0}>::type", qualName);
68
69 const char *const mapInfo = R"(
70 namespace llvm {
71 template<> struct DenseMapInfo<{0}> {{
72 using StorageInfo = llvm::DenseMapInfo<{1}>;
73
74 static inline {0} getEmptyKey() {{
75 return static_cast<{0}>(StorageInfo::getEmptyKey());
76 }
77
78 static inline {0} getTombstoneKey() {{
79 return static_cast<{0}>(StorageInfo::getTombstoneKey());
80 }
81
82 static unsigned getHashValue(const {0} &val) {{
83 return StorageInfo::getHashValue(static_cast<{1}>(val));
84 }
85
86 static bool isEqual(const {0} &lhs, const {0} &rhs) {{
87 return lhs == rhs;
88 }
89 };
90 })";
91 os << formatv(mapInfo, qualName, underlyingType);
92 os << "\n\n";
93 }
94
emitMaxValueFn(const Record & enumDef,raw_ostream & os)95 static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
96 EnumAttr enumAttr(enumDef);
97 StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName();
98 auto enumerants = enumAttr.getAllCases();
99
100 unsigned maxEnumVal = 0;
101 for (const auto &enumerant : enumerants) {
102 int64_t value = enumerant.getValue();
103 // Avoid generating the max value function if there is an enumerant without
104 // explicit value.
105 if (value < 0)
106 return;
107
108 maxEnumVal = std::max(maxEnumVal, static_cast<unsigned>(value));
109 }
110
111 // Emit the function to return the max enum value
112 os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName);
113 os << formatv(" return {0};\n", maxEnumVal);
114 os << "}\n\n";
115 }
116
117 // Returns the EnumAttrCase whose value is zero if exists; returns llvm::None
118 // otherwise.
119 static llvm::Optional<EnumAttrCase>
getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases)120 getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
121 for (auto attrCase : cases) {
122 if (attrCase.getValue() == 0)
123 return attrCase;
124 }
125 return llvm::None;
126 }
127
128 // Emits the following inline function for bit enums:
129 //
130 // inline <enum-type> operator|(<enum-type> a, <enum-type> b);
131 // inline <enum-type> operator&(<enum-type> a, <enum-type> b);
132 // inline <enum-type> bitEnumContains(<enum-type> a, <enum-type> b);
emitOperators(const Record & enumDef,raw_ostream & os)133 static void emitOperators(const Record &enumDef, raw_ostream &os) {
134 EnumAttr enumAttr(enumDef);
135 StringRef enumName = enumAttr.getEnumClassName();
136 std::string underlyingType = enumAttr.getUnderlyingType();
137 os << formatv("inline {0} operator|({0} lhs, {0} rhs) {{\n", enumName)
138 << formatv(" return static_cast<{0}>("
139 "static_cast<{1}>(lhs) | static_cast<{1}>(rhs));\n",
140 enumName, underlyingType)
141 << "}\n";
142 os << formatv("inline {0} operator&({0} lhs, {0} rhs) {{\n", enumName)
143 << formatv(" return static_cast<{0}>("
144 "static_cast<{1}>(lhs) & static_cast<{1}>(rhs));\n",
145 enumName, underlyingType)
146 << "}\n";
147 os << formatv(
148 "inline bool bitEnumContains({0} bits, {0} bit) {{\n"
149 " return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;\n",
150 enumName, underlyingType)
151 << "}\n";
152 }
153
emitSymToStrFnForIntEnum(const Record & enumDef,raw_ostream & os)154 static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
155 EnumAttr enumAttr(enumDef);
156 StringRef enumName = enumAttr.getEnumClassName();
157 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
158 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
159 auto enumerants = enumAttr.getAllCases();
160
161 os << formatv("{2} {1}({0} val) {{\n", enumName, symToStrFnName,
162 symToStrFnRetType);
163 os << " switch (val) {\n";
164 for (const auto &enumerant : enumerants) {
165 auto symbol = enumerant.getSymbol();
166 os << formatv(" case {0}::{1}: return \"{2}\";\n", enumName,
167 makeIdentifier(symbol), symbol);
168 }
169 os << " }\n";
170 os << " return \"\";\n";
171 os << "}\n\n";
172 }
173
emitSymToStrFnForBitEnum(const Record & enumDef,raw_ostream & os)174 static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
175 EnumAttr enumAttr(enumDef);
176 StringRef enumName = enumAttr.getEnumClassName();
177 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
178 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
179 StringRef separator = enumDef.getValueAsString("separator");
180 auto enumerants = enumAttr.getAllCases();
181 auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
182
183 os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName,
184 symToStrFnRetType);
185
186 os << formatv(" auto val = static_cast<{0}>(symbol);\n",
187 enumAttr.getUnderlyingType());
188 if (allBitsUnsetCase) {
189 os << " // Special case for all bits unset.\n";
190 os << formatv(" if (val == 0) return \"{0}\";\n\n",
191 allBitsUnsetCase->getSymbol());
192 }
193 os << " llvm::SmallVector<llvm::StringRef, 2> strs;\n";
194 for (const auto &enumerant : enumerants) {
195 // Skip the special enumerant for None.
196 if (auto val = enumerant.getValue())
197 os << formatv(" if ({0}u & val) {{ strs.push_back(\"{1}\"); "
198 "val &= ~{0}u; }\n",
199 val, enumerant.getSymbol());
200 }
201 // If we have unknown bit set, return an empty string to signal errors.
202 os << "\n if (val) return \"\";\n";
203 os << formatv(" return llvm::join(strs, \"{0}\");\n", separator);
204
205 os << "}\n\n";
206 }
207
emitStrToSymFnForIntEnum(const Record & enumDef,raw_ostream & os)208 static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
209 EnumAttr enumAttr(enumDef);
210 StringRef enumName = enumAttr.getEnumClassName();
211 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
212 auto enumerants = enumAttr.getAllCases();
213
214 os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
215 strToSymFnName);
216 os << formatv(" return llvm::StringSwitch<llvm::Optional<{0}>>(str)\n",
217 enumName);
218 for (const auto &enumerant : enumerants) {
219 auto symbol = enumerant.getSymbol();
220 os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, symbol,
221 makeIdentifier(symbol));
222 }
223 os << " .Default(llvm::None);\n";
224 os << "}\n";
225 }
226
emitStrToSymFnForBitEnum(const Record & enumDef,raw_ostream & os)227 static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
228 EnumAttr enumAttr(enumDef);
229 StringRef enumName = enumAttr.getEnumClassName();
230 std::string underlyingType = enumAttr.getUnderlyingType();
231 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
232 StringRef separator = enumDef.getValueAsString("separator");
233 auto enumerants = enumAttr.getAllCases();
234 auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
235
236 os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
237 strToSymFnName);
238
239 if (allBitsUnsetCase) {
240 os << " // Special case for all bits unset.\n";
241 StringRef caseSymbol = allBitsUnsetCase->getSymbol();
242 os << formatv(" if (str == \"{1}\") return {0}::{2};\n\n", enumName,
243 caseSymbol, makeIdentifier(caseSymbol));
244 }
245
246 // Split the string to get symbols for all the bits.
247 os << " llvm::SmallVector<llvm::StringRef, 2> symbols;\n";
248 os << formatv(" str.split(symbols, \"{0}\");\n\n", separator);
249
250 os << formatv(" {0} val = 0;\n", underlyingType);
251 os << " for (auto symbol : symbols) {\n";
252
253 // Convert each symbol to the bit ordinal and set the corresponding bit.
254 os << formatv(
255 " auto bit = llvm::StringSwitch<llvm::Optional<{0}>>(symbol)\n",
256 underlyingType);
257 for (const auto &enumerant : enumerants) {
258 // Skip the special enumerant for None.
259 if (auto val = enumerant.getValue())
260 os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getSymbol(),
261 val);
262 }
263 os.indent(6) << ".Default(llvm::None);\n";
264
265 os << " if (bit) { val |= *bit; } else { return llvm::None; }\n";
266 os << " }\n";
267
268 os << formatv(" return static_cast<{0}>(val);\n", enumName);
269 os << "}\n\n";
270 }
271
emitUnderlyingToSymFnForIntEnum(const Record & enumDef,raw_ostream & os)272 static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
273 raw_ostream &os) {
274 EnumAttr enumAttr(enumDef);
275 StringRef enumName = enumAttr.getEnumClassName();
276 std::string underlyingType = enumAttr.getUnderlyingType();
277 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
278 auto enumerants = enumAttr.getAllCases();
279
280 // Avoid generating the underlying value to symbol conversion function if
281 // there is an enumerant without explicit value.
282 if (llvm::any_of(enumerants, [](EnumAttrCase enumerant) {
283 return enumerant.getValue() < 0;
284 }))
285 return;
286
287 os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
288 underlyingToSymFnName,
289 underlyingType.empty() ? std::string("unsigned")
290 : underlyingType)
291 << " switch (value) {\n";
292 for (const auto &enumerant : enumerants) {
293 auto symbol = enumerant.getSymbol();
294 auto value = enumerant.getValue();
295 os << formatv(" case {0}: return {1}::{2};\n", value, enumName,
296 makeIdentifier(symbol));
297 }
298 os << " default: return llvm::None;\n"
299 << " }\n"
300 << "}\n\n";
301 }
302
emitUnderlyingToSymFnForBitEnum(const Record & enumDef,raw_ostream & os)303 static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
304 raw_ostream &os) {
305 EnumAttr enumAttr(enumDef);
306 StringRef enumName = enumAttr.getEnumClassName();
307 std::string underlyingType = enumAttr.getUnderlyingType();
308 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
309 auto enumerants = enumAttr.getAllCases();
310 auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
311
312 os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
313 underlyingToSymFnName, underlyingType);
314 if (allBitsUnsetCase) {
315 os << " // Special case for all bits unset.\n";
316 os << formatv(" if (value == 0) return {0}::{1};\n\n", enumName,
317 makeIdentifier(allBitsUnsetCase->getSymbol()));
318 }
319 llvm::SmallVector<std::string, 8> values;
320 for (const auto &enumerant : enumerants) {
321 if (auto val = enumerant.getValue())
322 values.push_back(formatv("{0}u", val));
323 }
324 os << formatv(" if (value & ~({0})) return llvm::None;\n",
325 llvm::join(values, " | "));
326 os << formatv(" return static_cast<{0}>(value);\n", enumName);
327 os << "}\n";
328 }
329
emitEnumDecl(const Record & enumDef,raw_ostream & os)330 static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
331 EnumAttr enumAttr(enumDef);
332 StringRef enumName = enumAttr.getEnumClassName();
333 StringRef cppNamespace = enumAttr.getCppNamespace();
334 std::string underlyingType = enumAttr.getUnderlyingType();
335 StringRef description = enumAttr.getDescription();
336 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
337 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
338 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
339 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
340 auto enumerants = enumAttr.getAllCases();
341
342 llvm::SmallVector<StringRef, 2> namespaces;
343 llvm::SplitString(cppNamespace, namespaces, "::");
344
345 for (auto ns : namespaces)
346 os << "namespace " << ns << " {\n";
347
348 // Emit the enum class definition
349 emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
350
351 // Emit conversion function declarations
352 if (llvm::all_of(enumerants, [](EnumAttrCase enumerant) {
353 return enumerant.getValue() >= 0;
354 })) {
355 os << formatv(
356 "llvm::Optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName,
357 underlyingType.empty() ? std::string("unsigned") : underlyingType);
358 }
359 os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, symToStrFnRetType);
360 os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName,
361 strToSymFnName);
362
363 if (enumAttr.isBitEnum()) {
364 emitOperators(enumDef, os);
365 } else {
366 emitMaxValueFn(enumDef, os);
367 }
368
369 for (auto ns : llvm::reverse(namespaces))
370 os << "} // namespace " << ns << "\n";
371
372 // Emit DenseMapInfo for this enum class
373 emitDenseMapInfo(enumName, underlyingType, cppNamespace, os);
374 }
375
emitEnumDecls(const RecordKeeper & recordKeeper,raw_ostream & os)376 static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
377 llvm::emitSourceFileHeader("Enum Utility Declarations", os);
378
379 auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
380 for (const auto *def : defs)
381 emitEnumDecl(*def, os);
382
383 return false;
384 }
385
emitEnumDef(const Record & enumDef,raw_ostream & os)386 static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
387 EnumAttr enumAttr(enumDef);
388 StringRef cppNamespace = enumAttr.getCppNamespace();
389
390 llvm::SmallVector<StringRef, 2> namespaces;
391 llvm::SplitString(cppNamespace, namespaces, "::");
392
393 for (auto ns : namespaces)
394 os << "namespace " << ns << " {\n";
395
396 if (enumAttr.isBitEnum()) {
397 emitSymToStrFnForBitEnum(enumDef, os);
398 emitStrToSymFnForBitEnum(enumDef, os);
399 emitUnderlyingToSymFnForBitEnum(enumDef, os);
400 } else {
401 emitSymToStrFnForIntEnum(enumDef, os);
402 emitStrToSymFnForIntEnum(enumDef, os);
403 emitUnderlyingToSymFnForIntEnum(enumDef, os);
404 }
405
406 for (auto ns : llvm::reverse(namespaces))
407 os << "} // namespace " << ns << "\n";
408 os << "\n";
409 }
410
emitEnumDefs(const RecordKeeper & recordKeeper,raw_ostream & os)411 static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
412 llvm::emitSourceFileHeader("Enum Utility Definitions", os);
413
414 auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
415 for (const auto *def : defs)
416 emitEnumDef(*def, os);
417
418 return false;
419 }
420
421 // Registers the enum utility generator to mlir-tblgen.
422 static mlir::GenRegistration
423 genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
__anone8e83dda0302(const RecordKeeper &records, raw_ostream &os) 424 [](const RecordKeeper &records, raw_ostream &os) {
425 return emitEnumDecls(records, os);
426 });
427
428 // Registers the enum utility generator to mlir-tblgen.
429 static mlir::GenRegistration
430 genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
__anone8e83dda0402(const RecordKeeper &records, raw_ostream &os) 431 [](const RecordKeeper &records, raw_ostream &os) {
432 return emitEnumDefs(records, os);
433 });
434