1 //===- OpClass.cpp - Helper classes for Op C++ code emission --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/TableGen/OpClass.h"
10
11 #include "mlir/TableGen/Format.h"
12 #include "llvm/ADT/Sequence.h"
13 #include "llvm/ADT/Twine.h"
14 #include "llvm/Support/Debug.h"
15 #include "llvm/Support/raw_ostream.h"
16 #include <unordered_set>
17
18 #define DEBUG_TYPE "mlir-tblgen-opclass"
19
20 using namespace mlir;
21 using namespace mlir::tblgen;
22
23 namespace {
24
25 // Returns space to be emitted after the given C++ `type`. return "" if the
26 // ends with '&' or '*', or is empty, else returns " ".
getSpaceAfterType(StringRef type)27 StringRef getSpaceAfterType(StringRef type) {
28 return (type.empty() || type.endswith("&") || type.endswith("*")) ? "" : " ";
29 }
30
31 } // namespace
32
33 //===----------------------------------------------------------------------===//
34 // OpMethodParameter definitions
35 //===----------------------------------------------------------------------===//
36
writeTo(raw_ostream & os,bool emitDefault) const37 void OpMethodParameter::writeTo(raw_ostream &os, bool emitDefault) const {
38 if (properties & PP_Optional)
39 os << "/*optional*/";
40 os << type << getSpaceAfterType(type) << name;
41 if (emitDefault && !defaultValue.empty())
42 os << " = " << defaultValue;
43 }
44
45 //===----------------------------------------------------------------------===//
46 // OpMethodParameters definitions
47 //===----------------------------------------------------------------------===//
48
49 // Factory methods to construct the correct type of `OpMethodParameters`
50 // object based on the arguments.
create()51 std::unique_ptr<OpMethodParameters> OpMethodParameters::create() {
52 return std::make_unique<OpMethodResolvedParameters>();
53 }
54
55 std::unique_ptr<OpMethodParameters>
create(StringRef params)56 OpMethodParameters::create(StringRef params) {
57 return std::make_unique<OpMethodUnresolvedParameters>(params);
58 }
59
60 std::unique_ptr<OpMethodParameters>
create(llvm::SmallVectorImpl<OpMethodParameter> && params)61 OpMethodParameters::create(llvm::SmallVectorImpl<OpMethodParameter> &¶ms) {
62 return std::make_unique<OpMethodResolvedParameters>(std::move(params));
63 }
64
65 std::unique_ptr<OpMethodParameters>
create(StringRef type,StringRef name,StringRef defaultValue)66 OpMethodParameters::create(StringRef type, StringRef name,
67 StringRef defaultValue) {
68 return std::make_unique<OpMethodResolvedParameters>(type, name, defaultValue);
69 }
70
71 //===----------------------------------------------------------------------===//
72 // OpMethodUnresolvedParameters definitions
73 //===----------------------------------------------------------------------===//
writeDeclTo(raw_ostream & os) const74 void OpMethodUnresolvedParameters::writeDeclTo(raw_ostream &os) const {
75 os << parameters;
76 }
77
writeDefTo(raw_ostream & os) const78 void OpMethodUnresolvedParameters::writeDefTo(raw_ostream &os) const {
79 // We need to remove the default values for parameters in method definition.
80 // TODO: We are using '=' and ',' as delimiters for parameter
81 // initializers. This is incorrect for initializer list with more than one
82 // element. Change to a more robust approach.
83 llvm::SmallVector<StringRef, 4> tokens;
84 StringRef params = parameters;
85 while (!params.empty()) {
86 std::pair<StringRef, StringRef> parts = params.split("=");
87 tokens.push_back(parts.first);
88 params = parts.second.split(',').second;
89 }
90 llvm::interleaveComma(tokens, os, [&](StringRef token) { os << token; });
91 }
92
93 //===----------------------------------------------------------------------===//
94 // OpMethodResolvedParameters definitions
95 //===----------------------------------------------------------------------===//
96
97 // Returns true if a method with these parameters makes a method with parameters
98 // `other` redundant. This should return true only if all possible calls to the
99 // other method can be replaced by calls to this method.
makesRedundant(const OpMethodResolvedParameters & other) const100 bool OpMethodResolvedParameters::makesRedundant(
101 const OpMethodResolvedParameters &other) const {
102 const size_t otherNumParams = other.getNumParameters();
103 const size_t thisNumParams = getNumParameters();
104
105 // All calls to the other method can be replaced this method only if this
106 // method has the same or more arguments number of arguments as the other, and
107 // the common arguments have the same type.
108 if (thisNumParams < otherNumParams)
109 return false;
110 for (int idx : llvm::seq<int>(0, otherNumParams))
111 if (parameters[idx].getType() != other.parameters[idx].getType())
112 return false;
113
114 // If all the common arguments have the same type, we can elide the other
115 // method if this method has the same number of arguments as other or the
116 // first argument after the common ones has a default value (and by C++
117 // requirement, all the later ones will also have a default value).
118 return thisNumParams == otherNumParams ||
119 parameters[otherNumParams].hasDefaultValue();
120 }
121
writeDeclTo(raw_ostream & os) const122 void OpMethodResolvedParameters::writeDeclTo(raw_ostream &os) const {
123 llvm::interleaveComma(parameters, os, [&](const OpMethodParameter ¶m) {
124 param.writeDeclTo(os);
125 });
126 }
127
writeDefTo(raw_ostream & os) const128 void OpMethodResolvedParameters::writeDefTo(raw_ostream &os) const {
129 llvm::interleaveComma(parameters, os, [&](const OpMethodParameter ¶m) {
130 param.writeDefTo(os);
131 });
132 }
133
134 //===----------------------------------------------------------------------===//
135 // OpMethodSignature definitions
136 //===----------------------------------------------------------------------===//
137
138 // Returns if a method with this signature makes a method with `other` signature
139 // redundant. Only supports resolved parameters.
makesRedundant(const OpMethodSignature & other) const140 bool OpMethodSignature::makesRedundant(const OpMethodSignature &other) const {
141 if (methodName != other.methodName)
142 return false;
143 auto *resolvedThis = dyn_cast<OpMethodResolvedParameters>(parameters.get());
144 auto *resolvedOther =
145 dyn_cast<OpMethodResolvedParameters>(other.parameters.get());
146 if (resolvedThis && resolvedOther)
147 return resolvedThis->makesRedundant(*resolvedOther);
148 return false;
149 }
150
writeDeclTo(raw_ostream & os) const151 void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
152 os << returnType << getSpaceAfterType(returnType) << methodName << "(";
153 parameters->writeDeclTo(os);
154 os << ")";
155 }
156
writeDefTo(raw_ostream & os,StringRef namePrefix) const157 void OpMethodSignature::writeDefTo(raw_ostream &os,
158 StringRef namePrefix) const {
159 os << returnType << getSpaceAfterType(returnType) << namePrefix
160 << (namePrefix.empty() ? "" : "::") << methodName << "(";
161 parameters->writeDefTo(os);
162 os << ")";
163 }
164
165 //===----------------------------------------------------------------------===//
166 // OpMethodBody definitions
167 //===----------------------------------------------------------------------===//
168
OpMethodBody(bool declOnly)169 OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
170
operator <<(Twine content)171 OpMethodBody &OpMethodBody::operator<<(Twine content) {
172 if (isEffective)
173 body.append(content.str());
174 return *this;
175 }
176
operator <<(int content)177 OpMethodBody &OpMethodBody::operator<<(int content) {
178 if (isEffective)
179 body.append(std::to_string(content));
180 return *this;
181 }
182
operator <<(const FmtObjectBase & content)183 OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
184 if (isEffective)
185 body.append(content.str());
186 return *this;
187 }
188
writeTo(raw_ostream & os) const189 void OpMethodBody::writeTo(raw_ostream &os) const {
190 auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
191 os << bodyRef;
192 if (bodyRef.empty() || bodyRef.back() != '\n')
193 os << "\n";
194 }
195
196 //===----------------------------------------------------------------------===//
197 // OpMethod definitions
198 //===----------------------------------------------------------------------===//
199
writeDeclTo(raw_ostream & os) const200 void OpMethod::writeDeclTo(raw_ostream &os) const {
201 os.indent(2);
202 if (isStatic())
203 os << "static ";
204 if ((properties & MP_Constexpr) == MP_Constexpr)
205 os << "constexpr ";
206 methodSignature.writeDeclTo(os);
207 if (!isInline()) {
208 os << ";";
209 } else {
210 os << " {\n";
211 methodBody.writeTo(os.indent(2));
212 os.indent(2) << "}";
213 }
214 }
215
writeDefTo(raw_ostream & os,StringRef namePrefix) const216 void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
217 // Do not write definition if the method is decl only.
218 if (properties & MP_Declaration)
219 return;
220 // Do not generate separate definition for inline method
221 if (isInline())
222 return;
223 methodSignature.writeDefTo(os, namePrefix);
224 os << " {\n";
225 methodBody.writeTo(os);
226 os << "}";
227 }
228
229 //===----------------------------------------------------------------------===//
230 // OpConstructor definitions
231 //===----------------------------------------------------------------------===//
232
addMemberInitializer(StringRef name,StringRef value)233 void OpConstructor::addMemberInitializer(StringRef name, StringRef value) {
234 memberInitializers.append(std::string(llvm::formatv(
235 "{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value)));
236 }
237
writeDefTo(raw_ostream & os,StringRef namePrefix) const238 void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
239 // Do not write definition if the method is decl only.
240 if (properties & MP_Declaration)
241 return;
242
243 methodSignature.writeDefTo(os, namePrefix);
244 os << " " << memberInitializers << " {\n";
245 methodBody.writeTo(os);
246 os << "}";
247 }
248
249 //===----------------------------------------------------------------------===//
250 // Class definitions
251 //===----------------------------------------------------------------------===//
252
Class(StringRef name)253 Class::Class(StringRef name) : className(name) {}
254
newField(StringRef type,StringRef name,StringRef defaultValue)255 void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
256 std::string varName = formatv("{0} {1}", type, name).str();
257 std::string field = defaultValue.empty()
258 ? varName
259 : formatv("{0} = {1}", varName, defaultValue).str();
260 fields.push_back(std::move(field));
261 }
writeDeclTo(raw_ostream & os) const262 void Class::writeDeclTo(raw_ostream &os) const {
263 bool hasPrivateMethod = false;
264 os << "class " << className << " {\n";
265 os << "public:\n";
266
267 forAllMethods([&](const OpMethod &method) {
268 if (!method.isPrivate()) {
269 method.writeDeclTo(os);
270 os << '\n';
271 } else {
272 hasPrivateMethod = true;
273 }
274 });
275
276 os << '\n';
277 os << "private:\n";
278 if (hasPrivateMethod) {
279 forAllMethods([&](const OpMethod &method) {
280 if (method.isPrivate()) {
281 method.writeDeclTo(os);
282 os << '\n';
283 }
284 });
285 os << '\n';
286 }
287
288 for (const auto &field : fields)
289 os.indent(2) << field << ";\n";
290 os << "};\n";
291 }
292
writeDefTo(raw_ostream & os) const293 void Class::writeDefTo(raw_ostream &os) const {
294 forAllMethods([&](const OpMethod &method) {
295 method.writeDefTo(os, className);
296 os << "\n\n";
297 });
298 }
299
300 //===----------------------------------------------------------------------===//
301 // OpClass definitions
302 //===----------------------------------------------------------------------===//
303
OpClass(StringRef name,StringRef extraClassDeclaration)304 OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
305 : Class(name), extraClassDeclaration(extraClassDeclaration) {}
306
addTrait(Twine trait)307 void OpClass::addTrait(Twine trait) {
308 auto traitStr = trait.str();
309 if (traitsSet.insert(traitStr).second)
310 traitsVec.push_back(std::move(traitStr));
311 }
312
writeDeclTo(raw_ostream & os) const313 void OpClass::writeDeclTo(raw_ostream &os) const {
314 os << "class " << className << " : public ::mlir::Op<" << className;
315 for (const auto &trait : traitsVec)
316 os << ", " << trait;
317 os << "> {\npublic:\n"
318 << " using Op::Op;\n"
319 << " using Op::print;\n"
320 << " using Adaptor = " << className << "Adaptor;\n";
321
322 bool hasPrivateMethod = false;
323 forAllMethods([&](const OpMethod &method) {
324 if (!method.isPrivate()) {
325 method.writeDeclTo(os);
326 os << "\n";
327 } else {
328 hasPrivateMethod = true;
329 }
330 });
331
332 // TODO: Add line control markers to make errors easier to debug.
333 if (!extraClassDeclaration.empty())
334 os << extraClassDeclaration << "\n";
335
336 if (hasPrivateMethod) {
337 os << "\nprivate:\n";
338 forAllMethods([&](const OpMethod &method) {
339 if (method.isPrivate()) {
340 method.writeDeclTo(os);
341 os << "\n";
342 }
343 });
344 }
345
346 os << "};\n";
347 }
348