1 //===- EnumsGen.cpp - MLIR enum utility generator -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // EnumsGen generates common utility functions for enums.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/TableGen/Attribute.h"
14 #include "mlir/TableGen/GenInfo.h"
15 #include "llvm/ADT/SmallVector.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/Support/raw_ostream.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 #include "llvm/TableGen/TableGenBackend.h"
22 
23 using llvm::formatv;
24 using llvm::isDigit;
25 using llvm::raw_ostream;
26 using llvm::Record;
27 using llvm::RecordKeeper;
28 using llvm::StringRef;
29 using mlir::tblgen::EnumAttr;
30 using mlir::tblgen::EnumAttrCase;
31 
makeIdentifier(StringRef str)32 static std::string makeIdentifier(StringRef str) {
33   if (!str.empty() && isDigit(static_cast<unsigned char>(str.front()))) {
34     std::string newStr = std::string("_") + str.str();
35     return newStr;
36   }
37   return str.str();
38 }
39 
emitEnumClass(const Record & enumDef,StringRef enumName,StringRef underlyingType,StringRef description,const std::vector<EnumAttrCase> & enumerants,raw_ostream & os)40 static void emitEnumClass(const Record &enumDef, StringRef enumName,
41                           StringRef underlyingType, StringRef description,
42                           const std::vector<EnumAttrCase> &enumerants,
43                           raw_ostream &os) {
44   os << "// " << description << "\n";
45   os << "enum class " << enumName;
46 
47   if (!underlyingType.empty())
48     os << " : " << underlyingType;
49   os << " {\n";
50 
51   for (const auto &enumerant : enumerants) {
52     auto symbol = makeIdentifier(enumerant.getSymbol());
53     auto value = enumerant.getValue();
54     if (value >= 0) {
55       os << formatv("  {0} = {1},\n", symbol, value);
56     } else {
57       os << formatv("  {0},\n", symbol);
58     }
59   }
60   os << "};\n\n";
61 }
62 
emitDenseMapInfo(StringRef enumName,std::string underlyingType,StringRef cppNamespace,raw_ostream & os)63 static void emitDenseMapInfo(StringRef enumName, std::string underlyingType,
64                              StringRef cppNamespace, raw_ostream &os) {
65   std::string qualName =
66       std::string(formatv("{0}::{1}", cppNamespace, enumName));
67   if (underlyingType.empty())
68     underlyingType =
69         std::string(formatv("std::underlying_type<{0}>::type", qualName));
70 
71   const char *const mapInfo = R"(
72 namespace llvm {
73 template<> struct DenseMapInfo<{0}> {{
74   using StorageInfo = ::llvm::DenseMapInfo<{1}>;
75 
76   static inline {0} getEmptyKey() {{
77     return static_cast<{0}>(StorageInfo::getEmptyKey());
78   }
79 
80   static inline {0} getTombstoneKey() {{
81     return static_cast<{0}>(StorageInfo::getTombstoneKey());
82   }
83 
84   static unsigned getHashValue(const {0} &val) {{
85     return StorageInfo::getHashValue(static_cast<{1}>(val));
86   }
87 
88   static bool isEqual(const {0} &lhs, const {0} &rhs) {{
89     return lhs == rhs;
90   }
91 };
92 })";
93   os << formatv(mapInfo, qualName, underlyingType);
94   os << "\n\n";
95 }
96 
emitMaxValueFn(const Record & enumDef,raw_ostream & os)97 static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
98   EnumAttr enumAttr(enumDef);
99   StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName();
100   auto enumerants = enumAttr.getAllCases();
101 
102   unsigned maxEnumVal = 0;
103   for (const auto &enumerant : enumerants) {
104     int64_t value = enumerant.getValue();
105     // Avoid generating the max value function if there is an enumerant without
106     // explicit value.
107     if (value < 0)
108       return;
109 
110     maxEnumVal = std::max(maxEnumVal, static_cast<unsigned>(value));
111   }
112 
113   // Emit the function to return the max enum value
114   os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName);
115   os << formatv("  return {0};\n", maxEnumVal);
116   os << "}\n\n";
117 }
118 
119 // Returns the EnumAttrCase whose value is zero if exists; returns llvm::None
120 // otherwise.
121 static llvm::Optional<EnumAttrCase>
getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases)122 getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
123   for (auto attrCase : cases) {
124     if (attrCase.getValue() == 0)
125       return attrCase;
126   }
127   return llvm::None;
128 }
129 
130 // Emits the following inline function for bit enums:
131 //
132 // inline <enum-type> operator|(<enum-type> a, <enum-type> b);
133 // inline <enum-type> operator&(<enum-type> a, <enum-type> b);
134 // inline <enum-type> bitEnumContains(<enum-type> a, <enum-type> b);
emitOperators(const Record & enumDef,raw_ostream & os)135 static void emitOperators(const Record &enumDef, raw_ostream &os) {
136   EnumAttr enumAttr(enumDef);
137   StringRef enumName = enumAttr.getEnumClassName();
138   std::string underlyingType = std::string(enumAttr.getUnderlyingType());
139   os << formatv("inline {0} operator|({0} lhs, {0} rhs) {{\n", enumName)
140      << formatv("  return static_cast<{0}>("
141                 "static_cast<{1}>(lhs) | static_cast<{1}>(rhs));\n",
142                 enumName, underlyingType)
143      << "}\n";
144   os << formatv("inline {0} operator&({0} lhs, {0} rhs) {{\n", enumName)
145      << formatv("  return static_cast<{0}>("
146                 "static_cast<{1}>(lhs) & static_cast<{1}>(rhs));\n",
147                 enumName, underlyingType)
148      << "}\n";
149   os << formatv(
150             "inline bool bitEnumContains({0} bits, {0} bit) {{\n"
151             "  return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;\n",
152             enumName, underlyingType)
153      << "}\n";
154 }
155 
emitSymToStrFnForIntEnum(const Record & enumDef,raw_ostream & os)156 static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
157   EnumAttr enumAttr(enumDef);
158   StringRef enumName = enumAttr.getEnumClassName();
159   StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
160   StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
161   auto enumerants = enumAttr.getAllCases();
162 
163   os << formatv("{2} {1}({0} val) {{\n", enumName, symToStrFnName,
164                 symToStrFnRetType);
165   os << "  switch (val) {\n";
166   for (const auto &enumerant : enumerants) {
167     auto symbol = enumerant.getSymbol();
168     auto str = enumerant.getStr();
169     os << formatv("    case {0}::{1}: return \"{2}\";\n", enumName,
170                   makeIdentifier(symbol), str);
171   }
172   os << "  }\n";
173   os << "  return \"\";\n";
174   os << "}\n\n";
175 }
176 
emitSymToStrFnForBitEnum(const Record & enumDef,raw_ostream & os)177 static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
178   EnumAttr enumAttr(enumDef);
179   StringRef enumName = enumAttr.getEnumClassName();
180   StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
181   StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
182   StringRef separator = enumDef.getValueAsString("separator");
183   auto enumerants = enumAttr.getAllCases();
184   auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
185 
186   os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName,
187                 symToStrFnRetType);
188 
189   os << formatv("  auto val = static_cast<{0}>(symbol);\n",
190                 enumAttr.getUnderlyingType());
191   if (allBitsUnsetCase) {
192     os << "  // Special case for all bits unset.\n";
193     os << formatv("  if (val == 0) return \"{0}\";\n\n",
194                   allBitsUnsetCase->getSymbol());
195   }
196   os << "  ::llvm::SmallVector<::llvm::StringRef, 2> strs;\n";
197   for (const auto &enumerant : enumerants) {
198     // Skip the special enumerant for None.
199     if (auto val = enumerant.getValue())
200       os << formatv("  if ({0}u & val) {{ strs.push_back(\"{1}\"); "
201                     "val &= ~{0}u; }\n",
202                     val, enumerant.getSymbol());
203   }
204   // If we have unknown bit set, return an empty string to signal errors.
205   os << "\n  if (val) return \"\";\n";
206   os << formatv("  return ::llvm::join(strs, \"{0}\");\n", separator);
207 
208   os << "}\n\n";
209 }
210 
emitStrToSymFnForIntEnum(const Record & enumDef,raw_ostream & os)211 static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
212   EnumAttr enumAttr(enumDef);
213   StringRef enumName = enumAttr.getEnumClassName();
214   StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
215   auto enumerants = enumAttr.getAllCases();
216 
217   os << formatv("::llvm::Optional<{0}> {1}(::llvm::StringRef str) {{\n",
218                 enumName, strToSymFnName);
219   os << formatv("  return ::llvm::StringSwitch<::llvm::Optional<{0}>>(str)\n",
220                 enumName);
221   for (const auto &enumerant : enumerants) {
222     auto symbol = enumerant.getSymbol();
223     auto str = enumerant.getStr();
224     os << formatv("      .Case(\"{1}\", {0}::{2})\n", enumName, str,
225                   makeIdentifier(symbol));
226   }
227   os << "      .Default(::llvm::None);\n";
228   os << "}\n";
229 }
230 
emitStrToSymFnForBitEnum(const Record & enumDef,raw_ostream & os)231 static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
232   EnumAttr enumAttr(enumDef);
233   StringRef enumName = enumAttr.getEnumClassName();
234   std::string underlyingType = std::string(enumAttr.getUnderlyingType());
235   StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
236   StringRef separator = enumDef.getValueAsString("separator");
237   auto enumerants = enumAttr.getAllCases();
238   auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
239 
240   os << formatv("::llvm::Optional<{0}> {1}(::llvm::StringRef str) {{\n",
241                 enumName, strToSymFnName);
242 
243   if (allBitsUnsetCase) {
244     os << "  // Special case for all bits unset.\n";
245     StringRef caseSymbol = allBitsUnsetCase->getSymbol();
246     os << formatv("  if (str == \"{1}\") return {0}::{2};\n\n", enumName,
247                   caseSymbol, makeIdentifier(caseSymbol));
248   }
249 
250   // Split the string to get symbols for all the bits.
251   os << "  ::llvm::SmallVector<::llvm::StringRef, 2> symbols;\n";
252   os << formatv("  str.split(symbols, \"{0}\");\n\n", separator);
253 
254   os << formatv("  {0} val = 0;\n", underlyingType);
255   os << "  for (auto symbol : symbols) {\n";
256 
257   // Convert each symbol to the bit ordinal and set the corresponding bit.
258   os << formatv(
259       "    auto bit = llvm::StringSwitch<::llvm::Optional<{0}>>(symbol)\n",
260       underlyingType);
261   for (const auto &enumerant : enumerants) {
262     // Skip the special enumerant for None.
263     if (auto val = enumerant.getValue())
264       os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getSymbol(),
265                               val);
266   }
267   os.indent(6) << ".Default(::llvm::None);\n";
268 
269   os << "    if (bit) { val |= *bit; } else { return ::llvm::None; }\n";
270   os << "  }\n";
271 
272   os << formatv("  return static_cast<{0}>(val);\n", enumName);
273   os << "}\n\n";
274 }
275 
emitUnderlyingToSymFnForIntEnum(const Record & enumDef,raw_ostream & os)276 static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
277                                             raw_ostream &os) {
278   EnumAttr enumAttr(enumDef);
279   StringRef enumName = enumAttr.getEnumClassName();
280   std::string underlyingType = std::string(enumAttr.getUnderlyingType());
281   StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
282   auto enumerants = enumAttr.getAllCases();
283 
284   // Avoid generating the underlying value to symbol conversion function if
285   // there is an enumerant without explicit value.
286   if (llvm::any_of(enumerants, [](EnumAttrCase enumerant) {
287         return enumerant.getValue() < 0;
288       }))
289     return;
290 
291   os << formatv("::llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
292                 underlyingToSymFnName,
293                 underlyingType.empty() ? std::string("unsigned")
294                                        : underlyingType)
295      << "  switch (value) {\n";
296   for (const auto &enumerant : enumerants) {
297     auto symbol = enumerant.getSymbol();
298     auto value = enumerant.getValue();
299     os << formatv("  case {0}: return {1}::{2};\n", value, enumName,
300                   makeIdentifier(symbol));
301   }
302   os << "  default: return ::llvm::None;\n"
303      << "  }\n"
304      << "}\n\n";
305 }
306 
emitUnderlyingToSymFnForBitEnum(const Record & enumDef,raw_ostream & os)307 static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
308                                             raw_ostream &os) {
309   EnumAttr enumAttr(enumDef);
310   StringRef enumName = enumAttr.getEnumClassName();
311   std::string underlyingType = std::string(enumAttr.getUnderlyingType());
312   StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
313   auto enumerants = enumAttr.getAllCases();
314   auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
315 
316   os << formatv("::llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
317                 underlyingToSymFnName, underlyingType);
318   if (allBitsUnsetCase) {
319     os << "  // Special case for all bits unset.\n";
320     os << formatv("  if (value == 0) return {0}::{1};\n\n", enumName,
321                   makeIdentifier(allBitsUnsetCase->getSymbol()));
322   }
323   llvm::SmallVector<std::string, 8> values;
324   for (const auto &enumerant : enumerants) {
325     if (auto val = enumerant.getValue())
326       values.push_back(std::string(formatv("{0}u", val)));
327   }
328   os << formatv("  if (value & ~({0})) return llvm::None;\n",
329                 llvm::join(values, " | "));
330   os << formatv("  return static_cast<{0}>(value);\n", enumName);
331   os << "}\n";
332 }
333 
emitEnumDecl(const Record & enumDef,raw_ostream & os)334 static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
335   EnumAttr enumAttr(enumDef);
336   StringRef enumName = enumAttr.getEnumClassName();
337   StringRef cppNamespace = enumAttr.getCppNamespace();
338   std::string underlyingType = std::string(enumAttr.getUnderlyingType());
339   StringRef description = enumAttr.getDescription();
340   StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
341   StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
342   StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
343   StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
344   auto enumerants = enumAttr.getAllCases();
345 
346   llvm::SmallVector<StringRef, 2> namespaces;
347   llvm::SplitString(cppNamespace, namespaces, "::");
348 
349   for (auto ns : namespaces)
350     os << "namespace " << ns << " {\n";
351 
352   // Emit the enum class definition
353   emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
354 
355   // Emit conversion function declarations
356   if (llvm::all_of(enumerants, [](EnumAttrCase enumerant) {
357         return enumerant.getValue() >= 0;
358       })) {
359     os << formatv(
360         "::llvm::Optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName,
361         underlyingType.empty() ? std::string("unsigned") : underlyingType);
362   }
363   os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, symToStrFnRetType);
364   os << formatv("::llvm::Optional<{0}> {1}(::llvm::StringRef);\n", enumName,
365                 strToSymFnName);
366 
367   if (enumAttr.isBitEnum()) {
368     emitOperators(enumDef, os);
369   } else {
370     emitMaxValueFn(enumDef, os);
371   }
372 
373   // Generate a generic `stringifyEnum` function that forwards to the method
374   // specified by the user.
375   const char *const stringifyEnumStr = R"(
376 inline {0} stringifyEnum({1} enumValue) {{
377   return {2}(enumValue);
378 }
379 )";
380   os << formatv(stringifyEnumStr, symToStrFnRetType, enumName, symToStrFnName);
381 
382   // Generate a generic `symbolizeEnum` function that forwards to the method
383   // specified by the user.
384   const char *const symbolizeEnumStr = R"(
385 template <typename EnumType>
386 ::llvm::Optional<EnumType> symbolizeEnum(::llvm::StringRef);
387 
388 template <>
389 inline ::llvm::Optional<{0}> symbolizeEnum<{0}>(::llvm::StringRef str) {
390   return {1}(str);
391 }
392 )";
393   os << formatv(symbolizeEnumStr, enumName, strToSymFnName);
394 
395   for (auto ns : llvm::reverse(namespaces))
396     os << "} // namespace " << ns << "\n";
397 
398   // Emit DenseMapInfo for this enum class
399   emitDenseMapInfo(enumName, underlyingType, cppNamespace, os);
400 }
401 
emitEnumDecls(const RecordKeeper & recordKeeper,raw_ostream & os)402 static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
403   llvm::emitSourceFileHeader("Enum Utility Declarations", os);
404 
405   auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
406   for (const auto *def : defs)
407     emitEnumDecl(*def, os);
408 
409   return false;
410 }
411 
emitEnumDef(const Record & enumDef,raw_ostream & os)412 static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
413   EnumAttr enumAttr(enumDef);
414   StringRef cppNamespace = enumAttr.getCppNamespace();
415 
416   llvm::SmallVector<StringRef, 2> namespaces;
417   llvm::SplitString(cppNamespace, namespaces, "::");
418 
419   for (auto ns : namespaces)
420     os << "namespace " << ns << " {\n";
421 
422   if (enumAttr.isBitEnum()) {
423     emitSymToStrFnForBitEnum(enumDef, os);
424     emitStrToSymFnForBitEnum(enumDef, os);
425     emitUnderlyingToSymFnForBitEnum(enumDef, os);
426   } else {
427     emitSymToStrFnForIntEnum(enumDef, os);
428     emitStrToSymFnForIntEnum(enumDef, os);
429     emitUnderlyingToSymFnForIntEnum(enumDef, os);
430   }
431 
432   for (auto ns : llvm::reverse(namespaces))
433     os << "} // namespace " << ns << "\n";
434   os << "\n";
435 }
436 
emitEnumDefs(const RecordKeeper & recordKeeper,raw_ostream & os)437 static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
438   llvm::emitSourceFileHeader("Enum Utility Definitions", os);
439 
440   auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
441   for (const auto *def : defs)
442     emitEnumDef(*def, os);
443 
444   return false;
445 }
446 
447 // Registers the enum utility generator to mlir-tblgen.
448 static mlir::GenRegistration
449     genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
__anon51e48abf0302(const RecordKeeper &records, raw_ostream &os) 450                  [](const RecordKeeper &records, raw_ostream &os) {
451                    return emitEnumDecls(records, os);
452                  });
453 
454 // Registers the enum utility generator to mlir-tblgen.
455 static mlir::GenRegistration
456     genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
__anon51e48abf0402(const RecordKeeper &records, raw_ostream &os) 457                 [](const RecordKeeper &records, raw_ostream &os) {
458                   return emitEnumDefs(records, os);
459                 });
460