1 //===- OpClass.cpp - Helper classes for Op C++ code emission --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/TableGen/OpClass.h"

#include "mlir/TableGen/Format.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cctype>
#include <unordered_set>
17 
18 #define DEBUG_TYPE "mlir-tblgen-opclass"
19 
20 using namespace mlir;
21 using namespace mlir::tblgen;
22 
23 namespace {
24 
25 // Returns space to be emitted after the given C++ `type`. return "" if the
26 // ends with '&' or '*', or is empty, else returns " ".
getSpaceAfterType(StringRef type)27 StringRef getSpaceAfterType(StringRef type) {
28   return (type.empty() || type.endswith("&") || type.endswith("*")) ? "" : " ";
29 }
30 
31 } // namespace
32 
33 //===----------------------------------------------------------------------===//
34 // OpMethodParameter definitions
35 //===----------------------------------------------------------------------===//
36 
writeTo(raw_ostream & os,bool emitDefault) const37 void OpMethodParameter::writeTo(raw_ostream &os, bool emitDefault) const {
38   if (properties & PP_Optional)
39     os << "/*optional*/";
40   os << type << getSpaceAfterType(type) << name;
41   if (emitDefault && !defaultValue.empty())
42     os << " = " << defaultValue;
43 }
44 
45 //===----------------------------------------------------------------------===//
46 // OpMethodParameters definitions
47 //===----------------------------------------------------------------------===//
48 
49 // Factory methods to construct the correct type of `OpMethodParameters`
50 // object based on the arguments.
create()51 std::unique_ptr<OpMethodParameters> OpMethodParameters::create() {
52   return std::make_unique<OpMethodResolvedParameters>();
53 }
54 
55 std::unique_ptr<OpMethodParameters>
create(StringRef params)56 OpMethodParameters::create(StringRef params) {
57   return std::make_unique<OpMethodUnresolvedParameters>(params);
58 }
59 
60 std::unique_ptr<OpMethodParameters>
create(llvm::SmallVectorImpl<OpMethodParameter> && params)61 OpMethodParameters::create(llvm::SmallVectorImpl<OpMethodParameter> &&params) {
62   return std::make_unique<OpMethodResolvedParameters>(std::move(params));
63 }
64 
65 std::unique_ptr<OpMethodParameters>
create(StringRef type,StringRef name,StringRef defaultValue)66 OpMethodParameters::create(StringRef type, StringRef name,
67                            StringRef defaultValue) {
68   return std::make_unique<OpMethodResolvedParameters>(type, name, defaultValue);
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // OpMethodUnresolvedParameters definitions
73 //===----------------------------------------------------------------------===//
writeDeclTo(raw_ostream & os) const74 void OpMethodUnresolvedParameters::writeDeclTo(raw_ostream &os) const {
75   os << parameters;
76 }
77 
writeDefTo(raw_ostream & os) const78 void OpMethodUnresolvedParameters::writeDefTo(raw_ostream &os) const {
79   // We need to remove the default values for parameters in method definition.
80   // TODO: We are using '=' and ',' as delimiters for parameter
81   // initializers. This is incorrect for initializer list with more than one
82   // element. Change to a more robust approach.
83   llvm::SmallVector<StringRef, 4> tokens;
84   StringRef params = parameters;
85   while (!params.empty()) {
86     std::pair<StringRef, StringRef> parts = params.split("=");
87     tokens.push_back(parts.first);
88     params = parts.second.split(',').second;
89   }
90   llvm::interleaveComma(tokens, os, [&](StringRef token) { os << token; });
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // OpMethodResolvedParameters definitions
95 //===----------------------------------------------------------------------===//
96 
97 // Returns true if a method with these parameters makes a method with parameters
98 // `other` redundant. This should return true only if all possible calls to the
99 // other method can be replaced by calls to this method.
makesRedundant(const OpMethodResolvedParameters & other) const100 bool OpMethodResolvedParameters::makesRedundant(
101     const OpMethodResolvedParameters &other) const {
102   const size_t otherNumParams = other.getNumParameters();
103   const size_t thisNumParams = getNumParameters();
104 
105   // All calls to the other method can be replaced this method only if this
106   // method has the same or more arguments number of arguments as the other, and
107   // the common arguments have the same type.
108   if (thisNumParams < otherNumParams)
109     return false;
110   for (int idx : llvm::seq<int>(0, otherNumParams))
111     if (parameters[idx].getType() != other.parameters[idx].getType())
112       return false;
113 
114   // If all the common arguments have the same type, we can elide the other
115   // method if this method has the same number of arguments as other or the
116   // first argument after the common ones has a default value (and by C++
117   // requirement, all the later ones will also have a default value).
118   return thisNumParams == otherNumParams ||
119          parameters[otherNumParams].hasDefaultValue();
120 }
121 
writeDeclTo(raw_ostream & os) const122 void OpMethodResolvedParameters::writeDeclTo(raw_ostream &os) const {
123   llvm::interleaveComma(parameters, os, [&](const OpMethodParameter &param) {
124     param.writeDeclTo(os);
125   });
126 }
127 
writeDefTo(raw_ostream & os) const128 void OpMethodResolvedParameters::writeDefTo(raw_ostream &os) const {
129   llvm::interleaveComma(parameters, os, [&](const OpMethodParameter &param) {
130     param.writeDefTo(os);
131   });
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // OpMethodSignature definitions
136 //===----------------------------------------------------------------------===//
137 
138 // Returns if a method with this signature makes a method with `other` signature
139 // redundant. Only supports resolved parameters.
makesRedundant(const OpMethodSignature & other) const140 bool OpMethodSignature::makesRedundant(const OpMethodSignature &other) const {
141   if (methodName != other.methodName)
142     return false;
143   auto *resolvedThis = dyn_cast<OpMethodResolvedParameters>(parameters.get());
144   auto *resolvedOther =
145       dyn_cast<OpMethodResolvedParameters>(other.parameters.get());
146   if (resolvedThis && resolvedOther)
147     return resolvedThis->makesRedundant(*resolvedOther);
148   return false;
149 }
150 
writeDeclTo(raw_ostream & os) const151 void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
152   os << returnType << getSpaceAfterType(returnType) << methodName << "(";
153   parameters->writeDeclTo(os);
154   os << ")";
155 }
156 
writeDefTo(raw_ostream & os,StringRef namePrefix) const157 void OpMethodSignature::writeDefTo(raw_ostream &os,
158                                    StringRef namePrefix) const {
159   os << returnType << getSpaceAfterType(returnType) << namePrefix
160      << (namePrefix.empty() ? "" : "::") << methodName << "(";
161   parameters->writeDefTo(os);
162   os << ")";
163 }
164 
165 //===----------------------------------------------------------------------===//
166 // OpMethodBody definitions
167 //===----------------------------------------------------------------------===//
168 
OpMethodBody(bool declOnly)169 OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
170 
operator <<(Twine content)171 OpMethodBody &OpMethodBody::operator<<(Twine content) {
172   if (isEffective)
173     body.append(content.str());
174   return *this;
175 }
176 
operator <<(int content)177 OpMethodBody &OpMethodBody::operator<<(int content) {
178   if (isEffective)
179     body.append(std::to_string(content));
180   return *this;
181 }
182 
operator <<(const FmtObjectBase & content)183 OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
184   if (isEffective)
185     body.append(content.str());
186   return *this;
187 }
188 
writeTo(raw_ostream & os) const189 void OpMethodBody::writeTo(raw_ostream &os) const {
190   auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
191   os << bodyRef;
192   if (bodyRef.empty() || bodyRef.back() != '\n')
193     os << "\n";
194 }
195 
196 //===----------------------------------------------------------------------===//
197 // OpMethod definitions
198 //===----------------------------------------------------------------------===//
199 
writeDeclTo(raw_ostream & os) const200 void OpMethod::writeDeclTo(raw_ostream &os) const {
201   os.indent(2);
202   if (isStatic())
203     os << "static ";
204   if ((properties & MP_Constexpr) == MP_Constexpr)
205     os << "constexpr ";
206   methodSignature.writeDeclTo(os);
207   if (!isInline()) {
208     os << ";";
209   } else {
210     os << " {\n";
211     methodBody.writeTo(os.indent(2));
212     os.indent(2) << "}";
213   }
214 }
215 
writeDefTo(raw_ostream & os,StringRef namePrefix) const216 void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
217   // Do not write definition if the method is decl only.
218   if (properties & MP_Declaration)
219     return;
220   // Do not generate separate definition for inline method
221   if (isInline())
222     return;
223   methodSignature.writeDefTo(os, namePrefix);
224   os << " {\n";
225   methodBody.writeTo(os);
226   os << "}";
227 }
228 
229 //===----------------------------------------------------------------------===//
230 // OpConstructor definitions
231 //===----------------------------------------------------------------------===//
232 
addMemberInitializer(StringRef name,StringRef value)233 void OpConstructor::addMemberInitializer(StringRef name, StringRef value) {
234   memberInitializers.append(std::string(llvm::formatv(
235       "{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value)));
236 }
237 
writeDefTo(raw_ostream & os,StringRef namePrefix) const238 void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
239   // Do not write definition if the method is decl only.
240   if (properties & MP_Declaration)
241     return;
242 
243   methodSignature.writeDefTo(os, namePrefix);
244   os << " " << memberInitializers << " {\n";
245   methodBody.writeTo(os);
246   os << "}";
247 }
248 
249 //===----------------------------------------------------------------------===//
250 // Class definitions
251 //===----------------------------------------------------------------------===//
252 
Class(StringRef name)253 Class::Class(StringRef name) : className(name) {}
254 
newField(StringRef type,StringRef name,StringRef defaultValue)255 void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
256   std::string varName = formatv("{0} {1}", type, name).str();
257   std::string field = defaultValue.empty()
258                           ? varName
259                           : formatv("{0} = {1}", varName, defaultValue).str();
260   fields.push_back(std::move(field));
261 }
writeDeclTo(raw_ostream & os) const262 void Class::writeDeclTo(raw_ostream &os) const {
263   bool hasPrivateMethod = false;
264   os << "class " << className << " {\n";
265   os << "public:\n";
266 
267   forAllMethods([&](const OpMethod &method) {
268     if (!method.isPrivate()) {
269       method.writeDeclTo(os);
270       os << '\n';
271     } else {
272       hasPrivateMethod = true;
273     }
274   });
275 
276   os << '\n';
277   os << "private:\n";
278   if (hasPrivateMethod) {
279     forAllMethods([&](const OpMethod &method) {
280       if (method.isPrivate()) {
281         method.writeDeclTo(os);
282         os << '\n';
283       }
284     });
285     os << '\n';
286   }
287 
288   for (const auto &field : fields)
289     os.indent(2) << field << ";\n";
290   os << "};\n";
291 }
292 
writeDefTo(raw_ostream & os) const293 void Class::writeDefTo(raw_ostream &os) const {
294   forAllMethods([&](const OpMethod &method) {
295     method.writeDefTo(os, className);
296     os << "\n\n";
297   });
298 }
299 
300 //===----------------------------------------------------------------------===//
301 // OpClass definitions
302 //===----------------------------------------------------------------------===//
303 
OpClass(StringRef name,StringRef extraClassDeclaration)304 OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
305     : Class(name), extraClassDeclaration(extraClassDeclaration) {}
306 
addTrait(Twine trait)307 void OpClass::addTrait(Twine trait) {
308   auto traitStr = trait.str();
309   if (traitsSet.insert(traitStr).second)
310     traitsVec.push_back(std::move(traitStr));
311 }
312 
writeDeclTo(raw_ostream & os) const313 void OpClass::writeDeclTo(raw_ostream &os) const {
314   os << "class " << className << " : public ::mlir::Op<" << className;
315   for (const auto &trait : traitsVec)
316     os << ", " << trait;
317   os << "> {\npublic:\n"
318      << "  using Op::Op;\n"
319      << "  using Op::print;\n"
320      << "  using Adaptor = " << className << "Adaptor;\n";
321 
322   bool hasPrivateMethod = false;
323   forAllMethods([&](const OpMethod &method) {
324     if (!method.isPrivate()) {
325       method.writeDeclTo(os);
326       os << "\n";
327     } else {
328       hasPrivateMethod = true;
329     }
330   });
331 
332   // TODO: Add line control markers to make errors easier to debug.
333   if (!extraClassDeclaration.empty())
334     os << extraClassDeclaration << "\n";
335 
336   if (hasPrivateMethod) {
337     os << "\nprivate:\n";
338     forAllMethods([&](const OpMethod &method) {
339       if (method.isPrivate()) {
340         method.writeDeclTo(os);
341         os << "\n";
342       }
343     });
344   }
345 
346   os << "};\n";
347 }
348