1 //===- EnumsGen.cpp - MLIR enum utility generator -------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // EnumsGen generates common utility functions for enums.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/TableGen/Attribute.h"
14 #include "mlir/TableGen/GenInfo.h"
15 #include "llvm/ADT/SmallVector.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/Support/raw_ostream.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 #include "llvm/TableGen/TableGenBackend.h"
22 
23 using llvm::formatv;
24 using llvm::isDigit;
25 using llvm::raw_ostream;
26 using llvm::Record;
27 using llvm::RecordKeeper;
28 using llvm::StringRef;
29 using mlir::tblgen::EnumAttr;
30 using mlir::tblgen::EnumAttrCase;
31 
makeIdentifier(StringRef str)32 static std::string makeIdentifier(StringRef str) {
33   if (!str.empty() && isDigit(static_cast<unsigned char>(str.front()))) {
34     std::string newStr = std::string("_") + str.str();
35     return newStr;
36   }
37   return str.str();
38 }
39 
emitEnumClass(const Record & enumDef,StringRef enumName,StringRef underlyingType,StringRef description,const std::vector<EnumAttrCase> & enumerants,raw_ostream & os)40 static void emitEnumClass(const Record &enumDef, StringRef enumName,
41                           StringRef underlyingType, StringRef description,
42                           const std::vector<EnumAttrCase> &enumerants,
43                           raw_ostream &os) {
44   os << "// " << description << "\n";
45   os << "enum class " << enumName;
46 
47   if (!underlyingType.empty())
48     os << " : " << underlyingType;
49   os << " {\n";
50 
51   for (const auto &enumerant : enumerants) {
52     auto symbol = makeIdentifier(enumerant.getSymbol());
53     auto value = enumerant.getValue();
54     if (value >= 0) {
55       os << formatv("  {0} = {1},\n", symbol, value);
56     } else {
57       os << formatv("  {0},\n", symbol);
58     }
59   }
60   os << "};\n\n";
61 }
62 
emitDenseMapInfo(StringRef enumName,std::string underlyingType,StringRef cppNamespace,raw_ostream & os)63 static void emitDenseMapInfo(StringRef enumName, std::string underlyingType,
64                              StringRef cppNamespace, raw_ostream &os) {
65   std::string qualName = formatv("{0}::{1}", cppNamespace, enumName);
66   if (underlyingType.empty())
67     underlyingType = formatv("std::underlying_type<{0}>::type", qualName);
68 
69   const char *const mapInfo = R"(
70 namespace llvm {
71 template<> struct DenseMapInfo<{0}> {{
72   using StorageInfo = llvm::DenseMapInfo<{1}>;
73 
74   static inline {0} getEmptyKey() {{
75     return static_cast<{0}>(StorageInfo::getEmptyKey());
76   }
77 
78   static inline {0} getTombstoneKey() {{
79     return static_cast<{0}>(StorageInfo::getTombstoneKey());
80   }
81 
82   static unsigned getHashValue(const {0} &val) {{
83     return StorageInfo::getHashValue(static_cast<{1}>(val));
84   }
85 
86   static bool isEqual(const {0} &lhs, const {0} &rhs) {{
87     return lhs == rhs;
88   }
89 };
90 })";
91   os << formatv(mapInfo, qualName, underlyingType);
92   os << "\n\n";
93 }
94 
emitMaxValueFn(const Record & enumDef,raw_ostream & os)95 static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
96   EnumAttr enumAttr(enumDef);
97   StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName();
98   auto enumerants = enumAttr.getAllCases();
99 
100   unsigned maxEnumVal = 0;
101   for (const auto &enumerant : enumerants) {
102     int64_t value = enumerant.getValue();
103     // Avoid generating the max value function if there is an enumerant without
104     // explicit value.
105     if (value < 0)
106       return;
107 
108     maxEnumVal = std::max(maxEnumVal, static_cast<unsigned>(value));
109   }
110 
111   // Emit the function to return the max enum value
112   os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName);
113   os << formatv("  return {0};\n", maxEnumVal);
114   os << "}\n\n";
115 }
116 
117 // Returns the EnumAttrCase whose value is zero if exists; returns llvm::None
118 // otherwise.
119 static llvm::Optional<EnumAttrCase>
getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases)120 getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
121   for (auto attrCase : cases) {
122     if (attrCase.getValue() == 0)
123       return attrCase;
124   }
125   return llvm::None;
126 }
127 
128 // Emits the following inline function for bit enums:
129 //
130 // inline <enum-type> operator|(<enum-type> a, <enum-type> b);
131 // inline <enum-type> operator&(<enum-type> a, <enum-type> b);
132 // inline <enum-type> bitEnumContains(<enum-type> a, <enum-type> b);
emitOperators(const Record & enumDef,raw_ostream & os)133 static void emitOperators(const Record &enumDef, raw_ostream &os) {
134   EnumAttr enumAttr(enumDef);
135   StringRef enumName = enumAttr.getEnumClassName();
136   std::string underlyingType = enumAttr.getUnderlyingType();
137   os << formatv("inline {0} operator|({0} lhs, {0} rhs) {{\n", enumName)
138      << formatv("  return static_cast<{0}>("
139                 "static_cast<{1}>(lhs) | static_cast<{1}>(rhs));\n",
140                 enumName, underlyingType)
141      << "}\n";
142   os << formatv("inline {0} operator&({0} lhs, {0} rhs) {{\n", enumName)
143      << formatv("  return static_cast<{0}>("
144                 "static_cast<{1}>(lhs) & static_cast<{1}>(rhs));\n",
145                 enumName, underlyingType)
146      << "}\n";
147   os << formatv(
148             "inline bool bitEnumContains({0} bits, {0} bit) {{\n"
149             "  return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;\n",
150             enumName, underlyingType)
151      << "}\n";
152 }
153 
emitSymToStrFnForIntEnum(const Record & enumDef,raw_ostream & os)154 static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
155   EnumAttr enumAttr(enumDef);
156   StringRef enumName = enumAttr.getEnumClassName();
157   StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
158   StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
159   auto enumerants = enumAttr.getAllCases();
160 
161   os << formatv("{2} {1}({0} val) {{\n", enumName, symToStrFnName,
162                 symToStrFnRetType);
163   os << "  switch (val) {\n";
164   for (const auto &enumerant : enumerants) {
165     auto symbol = enumerant.getSymbol();
166     os << formatv("    case {0}::{1}: return \"{2}\";\n", enumName,
167                   makeIdentifier(symbol), symbol);
168   }
169   os << "  }\n";
170   os << "  return \"\";\n";
171   os << "}\n\n";
172 }
173 
emitSymToStrFnForBitEnum(const Record & enumDef,raw_ostream & os)174 static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
175   EnumAttr enumAttr(enumDef);
176   StringRef enumName = enumAttr.getEnumClassName();
177   StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
178   StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
179   StringRef separator = enumDef.getValueAsString("separator");
180   auto enumerants = enumAttr.getAllCases();
181   auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
182 
183   os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName,
184                 symToStrFnRetType);
185 
186   os << formatv("  auto val = static_cast<{0}>(symbol);\n",
187                 enumAttr.getUnderlyingType());
188   if (allBitsUnsetCase) {
189     os << "  // Special case for all bits unset.\n";
190     os << formatv("  if (val == 0) return \"{0}\";\n\n",
191                   allBitsUnsetCase->getSymbol());
192   }
193   os << "  llvm::SmallVector<llvm::StringRef, 2> strs;\n";
194   for (const auto &enumerant : enumerants) {
195     // Skip the special enumerant for None.
196     if (auto val = enumerant.getValue())
197       os << formatv("  if ({0}u & val) {{ strs.push_back(\"{1}\"); "
198                     "val &= ~{0}u; }\n",
199                     val, enumerant.getSymbol());
200   }
201   // If we have unknown bit set, return an empty string to signal errors.
202   os << "\n  if (val) return \"\";\n";
203   os << formatv("  return llvm::join(strs, \"{0}\");\n", separator);
204 
205   os << "}\n\n";
206 }
207 
emitStrToSymFnForIntEnum(const Record & enumDef,raw_ostream & os)208 static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
209   EnumAttr enumAttr(enumDef);
210   StringRef enumName = enumAttr.getEnumClassName();
211   StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
212   auto enumerants = enumAttr.getAllCases();
213 
214   os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
215                 strToSymFnName);
216   os << formatv("  return llvm::StringSwitch<llvm::Optional<{0}>>(str)\n",
217                 enumName);
218   for (const auto &enumerant : enumerants) {
219     auto symbol = enumerant.getSymbol();
220     os << formatv("      .Case(\"{1}\", {0}::{2})\n", enumName, symbol,
221                   makeIdentifier(symbol));
222   }
223   os << "      .Default(llvm::None);\n";
224   os << "}\n";
225 }
226 
emitStrToSymFnForBitEnum(const Record & enumDef,raw_ostream & os)227 static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
228   EnumAttr enumAttr(enumDef);
229   StringRef enumName = enumAttr.getEnumClassName();
230   std::string underlyingType = enumAttr.getUnderlyingType();
231   StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
232   StringRef separator = enumDef.getValueAsString("separator");
233   auto enumerants = enumAttr.getAllCases();
234   auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
235 
236   os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
237                 strToSymFnName);
238 
239   if (allBitsUnsetCase) {
240     os << "  // Special case for all bits unset.\n";
241     StringRef caseSymbol = allBitsUnsetCase->getSymbol();
242     os << formatv("  if (str == \"{1}\") return {0}::{2};\n\n", enumName,
243                   caseSymbol, makeIdentifier(caseSymbol));
244   }
245 
246   // Split the string to get symbols for all the bits.
247   os << "  llvm::SmallVector<llvm::StringRef, 2> symbols;\n";
248   os << formatv("  str.split(symbols, \"{0}\");\n\n", separator);
249 
250   os << formatv("  {0} val = 0;\n", underlyingType);
251   os << "  for (auto symbol : symbols) {\n";
252 
253   // Convert each symbol to the bit ordinal and set the corresponding bit.
254   os << formatv(
255       "    auto bit = llvm::StringSwitch<llvm::Optional<{0}>>(symbol)\n",
256       underlyingType);
257   for (const auto &enumerant : enumerants) {
258     // Skip the special enumerant for None.
259     if (auto val = enumerant.getValue())
260       os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getSymbol(),
261                               val);
262   }
263   os.indent(6) << ".Default(llvm::None);\n";
264 
265   os << "    if (bit) { val |= *bit; } else { return llvm::None; }\n";
266   os << "  }\n";
267 
268   os << formatv("  return static_cast<{0}>(val);\n", enumName);
269   os << "}\n\n";
270 }
271 
emitUnderlyingToSymFnForIntEnum(const Record & enumDef,raw_ostream & os)272 static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
273                                             raw_ostream &os) {
274   EnumAttr enumAttr(enumDef);
275   StringRef enumName = enumAttr.getEnumClassName();
276   std::string underlyingType = enumAttr.getUnderlyingType();
277   StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
278   auto enumerants = enumAttr.getAllCases();
279 
280   // Avoid generating the underlying value to symbol conversion function if
281   // there is an enumerant without explicit value.
282   if (llvm::any_of(enumerants, [](EnumAttrCase enumerant) {
283         return enumerant.getValue() < 0;
284       }))
285     return;
286 
287   os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
288                 underlyingToSymFnName,
289                 underlyingType.empty() ? std::string("unsigned")
290                                        : underlyingType)
291      << "  switch (value) {\n";
292   for (const auto &enumerant : enumerants) {
293     auto symbol = enumerant.getSymbol();
294     auto value = enumerant.getValue();
295     os << formatv("  case {0}: return {1}::{2};\n", value, enumName,
296                   makeIdentifier(symbol));
297   }
298   os << "  default: return llvm::None;\n"
299      << "  }\n"
300      << "}\n\n";
301 }
302 
emitUnderlyingToSymFnForBitEnum(const Record & enumDef,raw_ostream & os)303 static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
304                                             raw_ostream &os) {
305   EnumAttr enumAttr(enumDef);
306   StringRef enumName = enumAttr.getEnumClassName();
307   std::string underlyingType = enumAttr.getUnderlyingType();
308   StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
309   auto enumerants = enumAttr.getAllCases();
310   auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
311 
312   os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
313                 underlyingToSymFnName, underlyingType);
314   if (allBitsUnsetCase) {
315     os << "  // Special case for all bits unset.\n";
316     os << formatv("  if (value == 0) return {0}::{1};\n\n", enumName,
317                   makeIdentifier(allBitsUnsetCase->getSymbol()));
318   }
319   llvm::SmallVector<std::string, 8> values;
320   for (const auto &enumerant : enumerants) {
321     if (auto val = enumerant.getValue())
322       values.push_back(formatv("{0}u", val));
323   }
324   os << formatv("  if (value & ~({0})) return llvm::None;\n",
325                 llvm::join(values, " | "));
326   os << formatv("  return static_cast<{0}>(value);\n", enumName);
327   os << "}\n";
328 }
329 
emitEnumDecl(const Record & enumDef,raw_ostream & os)330 static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
331   EnumAttr enumAttr(enumDef);
332   StringRef enumName = enumAttr.getEnumClassName();
333   StringRef cppNamespace = enumAttr.getCppNamespace();
334   std::string underlyingType = enumAttr.getUnderlyingType();
335   StringRef description = enumAttr.getDescription();
336   StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
337   StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
338   StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
339   StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
340   auto enumerants = enumAttr.getAllCases();
341 
342   llvm::SmallVector<StringRef, 2> namespaces;
343   llvm::SplitString(cppNamespace, namespaces, "::");
344 
345   for (auto ns : namespaces)
346     os << "namespace " << ns << " {\n";
347 
348   // Emit the enum class definition
349   emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
350 
351   // Emit conversion function declarations
352   if (llvm::all_of(enumerants, [](EnumAttrCase enumerant) {
353         return enumerant.getValue() >= 0;
354       })) {
355     os << formatv(
356         "llvm::Optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName,
357         underlyingType.empty() ? std::string("unsigned") : underlyingType);
358   }
359   os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, symToStrFnRetType);
360   os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName,
361                 strToSymFnName);
362 
363   if (enumAttr.isBitEnum()) {
364     emitOperators(enumDef, os);
365   } else {
366     emitMaxValueFn(enumDef, os);
367   }
368 
369   for (auto ns : llvm::reverse(namespaces))
370     os << "} // namespace " << ns << "\n";
371 
372   // Emit DenseMapInfo for this enum class
373   emitDenseMapInfo(enumName, underlyingType, cppNamespace, os);
374 }
375 
emitEnumDecls(const RecordKeeper & recordKeeper,raw_ostream & os)376 static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
377   llvm::emitSourceFileHeader("Enum Utility Declarations", os);
378 
379   auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
380   for (const auto *def : defs)
381     emitEnumDecl(*def, os);
382 
383   return false;
384 }
385 
emitEnumDef(const Record & enumDef,raw_ostream & os)386 static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
387   EnumAttr enumAttr(enumDef);
388   StringRef cppNamespace = enumAttr.getCppNamespace();
389 
390   llvm::SmallVector<StringRef, 2> namespaces;
391   llvm::SplitString(cppNamespace, namespaces, "::");
392 
393   for (auto ns : namespaces)
394     os << "namespace " << ns << " {\n";
395 
396   if (enumAttr.isBitEnum()) {
397     emitSymToStrFnForBitEnum(enumDef, os);
398     emitStrToSymFnForBitEnum(enumDef, os);
399     emitUnderlyingToSymFnForBitEnum(enumDef, os);
400   } else {
401     emitSymToStrFnForIntEnum(enumDef, os);
402     emitStrToSymFnForIntEnum(enumDef, os);
403     emitUnderlyingToSymFnForIntEnum(enumDef, os);
404   }
405 
406   for (auto ns : llvm::reverse(namespaces))
407     os << "} // namespace " << ns << "\n";
408   os << "\n";
409 }
410 
emitEnumDefs(const RecordKeeper & recordKeeper,raw_ostream & os)411 static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
412   llvm::emitSourceFileHeader("Enum Utility Definitions", os);
413 
414   auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
415   for (const auto *def : defs)
416     emitEnumDef(*def, os);
417 
418   return false;
419 }
420 
421 // Registers the enum utility generator to mlir-tblgen.
422 static mlir::GenRegistration
423     genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
__anone8e83dda0302(const RecordKeeper &records, raw_ostream &os) 424                  [](const RecordKeeper &records, raw_ostream &os) {
425                    return emitEnumDecls(records, os);
426                  });
427 
428 // Registers the enum utility generator to mlir-tblgen.
429 static mlir::GenRegistration
430     genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
__anone8e83dda0402(const RecordKeeper &records, raw_ostream &os) 431                 [](const RecordKeeper &records, raw_ostream &os) {
432                   return emitEnumDefs(records, os);
433                 });
434