1 //===- Attribute.cpp - Attribute wrapper class ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Attribute wrapper that simplifies working with a TableGen Record defining an
// MLIR Attribute.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "llvm/TableGen/Record.h"
17 
18 using namespace mlir;
19 using namespace mlir::tblgen;
20 
21 using llvm::DefInit;
22 using llvm::Init;
23 using llvm::Record;
24 using llvm::StringInit;
25 
26 // Returns the initializer's value as string if the given TableGen initializer
27 // is a code or string initializer. Returns the empty StringRef otherwise.
getValueAsString(const Init * init)28 static StringRef getValueAsString(const Init *init) {
29   if (const auto *str = dyn_cast<StringInit>(init))
30     return str->getValue().trim();
31   return {};
32 }
33 
AttrConstraint(const Record * record)34 AttrConstraint::AttrConstraint(const Record *record)
35     : Constraint(Constraint::CK_Attr, record) {
36   assert(isSubClassOf("AttrConstraint") &&
37          "must be subclass of TableGen 'AttrConstraint' class");
38 }
39 
isSubClassOf(StringRef className) const40 bool AttrConstraint::isSubClassOf(StringRef className) const {
41   return def->isSubClassOf(className);
42 }
43 
Attribute(const Record * record)44 Attribute::Attribute(const Record *record) : AttrConstraint(record) {
45   assert(record->isSubClassOf("Attr") &&
46          "must be subclass of TableGen 'Attr' class");
47 }
48 
Attribute(const DefInit * init)49 Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {}
50 
isDerivedAttr() const51 bool Attribute::isDerivedAttr() const { return isSubClassOf("DerivedAttr"); }
52 
isTypeAttr() const53 bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); }
54 
isSymbolRefAttr() const55 bool Attribute::isSymbolRefAttr() const {
56   StringRef defName = def->getName();
57   if (defName == "SymbolRefAttr" || defName == "FlatSymbolRefAttr")
58     return true;
59   return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr");
60 }
61 
isEnumAttr() const62 bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
63 
getStorageType() const64 StringRef Attribute::getStorageType() const {
65   const auto *init = def->getValueInit("storageType");
66   auto type = getValueAsString(init);
67   if (type.empty())
68     return "Attribute";
69   return type;
70 }
71 
getReturnType() const72 StringRef Attribute::getReturnType() const {
73   const auto *init = def->getValueInit("returnType");
74   return getValueAsString(init);
75 }
76 
77 // Return the type constraint corresponding to the type of this attribute, or
78 // None if this is not a TypedAttr.
getValueType() const79 llvm::Optional<Type> Attribute::getValueType() const {
80   if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType")))
81     return Type(defInit->getDef());
82   return llvm::None;
83 }
84 
getConvertFromStorageCall() const85 StringRef Attribute::getConvertFromStorageCall() const {
86   const auto *init = def->getValueInit("convertFromStorage");
87   return getValueAsString(init);
88 }
89 
isConstBuildable() const90 bool Attribute::isConstBuildable() const {
91   const auto *init = def->getValueInit("constBuilderCall");
92   return !getValueAsString(init).empty();
93 }
94 
getConstBuilderTemplate() const95 StringRef Attribute::getConstBuilderTemplate() const {
96   const auto *init = def->getValueInit("constBuilderCall");
97   return getValueAsString(init);
98 }
99 
getBaseAttr() const100 Attribute Attribute::getBaseAttr() const {
101   if (const auto *defInit =
102           llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) {
103     return Attribute(defInit).getBaseAttr();
104   }
105   return *this;
106 }
107 
hasDefaultValue() const108 bool Attribute::hasDefaultValue() const {
109   const auto *init = def->getValueInit("defaultValue");
110   return !getValueAsString(init).empty();
111 }
112 
getDefaultValue() const113 StringRef Attribute::getDefaultValue() const {
114   const auto *init = def->getValueInit("defaultValue");
115   return getValueAsString(init);
116 }
117 
isOptional() const118 bool Attribute::isOptional() const { return def->getValueAsBit("isOptional"); }
119 
getAttrDefName() const120 StringRef Attribute::getAttrDefName() const {
121   if (def->isAnonymous()) {
122     return getBaseAttr().def->getName();
123   }
124   return def->getName();
125 }
126 
getDerivedCodeBody() const127 StringRef Attribute::getDerivedCodeBody() const {
128   assert(isDerivedAttr() && "only derived attribute has 'body' field");
129   return def->getValueAsString("body");
130 }
131 
getDialect() const132 Dialect Attribute::getDialect() const {
133   const llvm::RecordVal *record = def->getValue("dialect");
134   if (record && record->getValue()) {
135     if (DefInit *init = dyn_cast<DefInit>(record->getValue()))
136       return Dialect(init->getDef());
137   }
138   return Dialect(nullptr);
139 }
140 
ConstantAttr(const DefInit * init)141 ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
142   assert(def->isSubClassOf("ConstantAttr") &&
143          "must be subclass of TableGen 'ConstantAttr' class");
144 }
145 
getAttribute() const146 Attribute ConstantAttr::getAttribute() const {
147   return Attribute(def->getValueAsDef("attr"));
148 }
149 
getConstantValue() const150 StringRef ConstantAttr::getConstantValue() const {
151   return def->getValueAsString("value");
152 }
153 
EnumAttrCase(const llvm::Record * record)154 EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) {
155   assert(isSubClassOf("EnumAttrCaseInfo") &&
156          "must be subclass of TableGen 'EnumAttrInfo' class");
157 }
158 
EnumAttrCase(const llvm::DefInit * init)159 EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
160     : EnumAttrCase(init->getDef()) {}
161 
isStrCase() const162 bool EnumAttrCase::isStrCase() const { return isSubClassOf("StrEnumAttrCase"); }
163 
getSymbol() const164 StringRef EnumAttrCase::getSymbol() const {
165   return def->getValueAsString("symbol");
166 }
167 
getStr() const168 StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }
169 
getValue() const170 int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }
171 
getDef() const172 const llvm::Record &EnumAttrCase::getDef() const { return *def; }
173 
EnumAttr(const llvm::Record * record)174 EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
175   assert(isSubClassOf("EnumAttrInfo") &&
176          "must be subclass of TableGen 'EnumAttr' class");
177 }
178 
EnumAttr(const llvm::Record & record)179 EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
180 
EnumAttr(const llvm::DefInit * init)181 EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {}
182 
classof(const Attribute * attr)183 bool EnumAttr::classof(const Attribute *attr) {
184   return attr->isSubClassOf("EnumAttrInfo");
185 }
186 
isBitEnum() const187 bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
188 
getEnumClassName() const189 StringRef EnumAttr::getEnumClassName() const {
190   return def->getValueAsString("className");
191 }
192 
getCppNamespace() const193 StringRef EnumAttr::getCppNamespace() const {
194   return def->getValueAsString("cppNamespace");
195 }
196 
getUnderlyingType() const197 StringRef EnumAttr::getUnderlyingType() const {
198   return def->getValueAsString("underlyingType");
199 }
200 
getUnderlyingToSymbolFnName() const201 StringRef EnumAttr::getUnderlyingToSymbolFnName() const {
202   return def->getValueAsString("underlyingToSymbolFnName");
203 }
204 
getStringToSymbolFnName() const205 StringRef EnumAttr::getStringToSymbolFnName() const {
206   return def->getValueAsString("stringToSymbolFnName");
207 }
208 
getSymbolToStringFnName() const209 StringRef EnumAttr::getSymbolToStringFnName() const {
210   return def->getValueAsString("symbolToStringFnName");
211 }
212 
getSymbolToStringFnRetType() const213 StringRef EnumAttr::getSymbolToStringFnRetType() const {
214   return def->getValueAsString("symbolToStringFnRetType");
215 }
216 
getMaxEnumValFnName() const217 StringRef EnumAttr::getMaxEnumValFnName() const {
218   return def->getValueAsString("maxEnumValFnName");
219 }
220 
getAllCases() const221 std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
222   const auto *inits = def->getValueAsListInit("enumerants");
223 
224   std::vector<EnumAttrCase> cases;
225   cases.reserve(inits->size());
226 
227   for (const llvm::Init *init : *inits) {
228     cases.push_back(EnumAttrCase(cast<llvm::DefInit>(init)));
229   }
230 
231   return cases;
232 }
233 
genSpecializedAttr() const234 bool EnumAttr::genSpecializedAttr() const {
235   return def->getValueAsBit("genSpecializedAttr");
236 }
237 
getBaseAttrClass() const238 llvm::Record *EnumAttr::getBaseAttrClass() const {
239   return def->getValueAsDef("baseAttrClass");
240 }
241 
getSpecializedAttrClassName() const242 StringRef EnumAttr::getSpecializedAttrClassName() const {
243   return def->getValueAsString("specializedAttrClassName");
244 }
245 
StructFieldAttr(const llvm::Record * record)246 StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) {
247   assert(def->isSubClassOf("StructFieldAttr") &&
248          "must be subclass of TableGen 'StructFieldAttr' class");
249 }
250 
StructFieldAttr(const llvm::Record & record)251 StructFieldAttr::StructFieldAttr(const llvm::Record &record)
252     : StructFieldAttr(&record) {}
253 
StructFieldAttr(const llvm::DefInit * init)254 StructFieldAttr::StructFieldAttr(const llvm::DefInit *init)
255     : StructFieldAttr(init->getDef()) {}
256 
getName() const257 StringRef StructFieldAttr::getName() const {
258   return def->getValueAsString("name");
259 }
260 
getType() const261 Attribute StructFieldAttr::getType() const {
262   auto init = def->getValueInit("type");
263   return Attribute(cast<llvm::DefInit>(init));
264 }
265 
StructAttr(const llvm::Record * record)266 StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) {
267   assert(isSubClassOf("StructAttr") &&
268          "must be subclass of TableGen 'StructAttr' class");
269 }
270 
StructAttr(const llvm::DefInit * init)271 StructAttr::StructAttr(const llvm::DefInit *init)
272     : StructAttr(init->getDef()) {}
273 
getStructClassName() const274 StringRef StructAttr::getStructClassName() const {
275   return def->getValueAsString("className");
276 }
277 
getCppNamespace() const278 StringRef StructAttr::getCppNamespace() const {
279   Dialect dialect(def->getValueAsDef("dialect"));
280   return dialect.getCppNamespace();
281 }
282 
getAllFields() const283 std::vector<StructFieldAttr> StructAttr::getAllFields() const {
284   std::vector<StructFieldAttr> attributes;
285 
286   const auto *inits = def->getValueAsListInit("fields");
287   attributes.reserve(inits->size());
288 
289   for (const llvm::Init *init : *inits) {
290     attributes.emplace_back(cast<llvm::DefInit>(init));
291   }
292 
293   return attributes;
294 }
295 
296 const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface";
297