1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10 // binding classes wrapping a generic operation API.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "llvm/ADT/StringSet.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 using namespace mlir;
23 using namespace mlir::tblgen;
24 
/// Header emitted once at the top of every generated Python module. It imports
/// the native extension module and the ODS helper functions under
/// `_ods_`-prefixed aliases, and tries to import an optional extension module
/// `_{0}`, falling back to `None` when the import fails:
///   {0} is the dialect namespace.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from . import _cext as _ods_cext
from . import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context
_ods_ir = _ods_cext.ir

try:
  from . import _{0} as _ods_ext_module
except ImportError:
  _ods_ext_module = None

)Py";

/// Template for the dialect class. It is registered through
/// `_ods_cext.register_dialect` and is referred to as `_Dialect` by every
/// generated op class:
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template for the head of an operation class. The class derives from
/// `OpView`, is registered as an operation of `_Dialect`, and is passed through
/// `_ods_extend_opview_class` together with the (possibly None) extension
/// module:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";
60 
/// Template for the class-level declaration of the operand or result segment
/// spec. It is emitted only for ops with attribute-sized segments:
///   {0} is either "OPERAND" or "RESULT";
///   {1} is the segment spec.
/// The segment spec is either None (the default) or a list with one integer
/// per operand/result group, where:
///   1 = single element (expects a non-sequence operand/result);
///   0 = optional element (may be absent);
///   -1 = variadic group (the operand/result is a sequence).
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for the class-level declaration of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions (variadic regions excluded);
///   {1} is the Python bool literal for hasNoVariadicRegions.
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
79 
/// Template for a single-element accessor used when no variable-length group
/// precedes the element, so its position is fixed:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for a single-element accessor placed after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// The length of the variable-length group is the actual element count minus
/// the number of groups, plus one. The element index is then shifted by that
/// length minus one. This works for both a single variadic group (non-negative
/// length) and a single optional element (zero length if the element is
/// absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";
103 
/// Template for an optional element accessor, used when the optional element is
/// the only variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// When the optional element is present, every group has exactly one element,
/// so the element count equals {2}. When it is absent, the count is one less.
/// The element is therefore absent exactly when the count is smaller than {2}.
/// Testing `> {2}` would never hold and the accessor would always return None.
constexpr const char *opOneOptionalTemplate = R"Py(
  @property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";
114 
/// Template for the variadic group accessor when the op has a single
/// variable-length group. The group's length is the actual element count minus
/// the number of groups, plus one:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list. All preceding
///       groups hold one element each, so this is also the index of the
///       group's first element.
constexpr const char *opOneVariadicTemplate = R"Py(
  @property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
126 
/// First part of the template for an equally-sized variadic group accessor. It
/// binds `start` (the index of the group's first element) and `pg` (the slice
/// length used for a variadic group) for the second part to use:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The elements must be read through `self.operation`. A bare `operation` is
/// not defined inside the generated property and would raise NameError.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
137 
/// Second part of the template for the equally-sized case. It returns the
/// single element found at `start`:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for the equally-sized case. It returns the
/// slice of `pg` elements beginning at `start`:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Template for a group accessor when group sizes are recorded in the
/// `{1}_segment_sizes` attribute. `_ods_segmented_accessor` presumably returns
/// the elements of group {2}:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a suffix applied to the returned range: `[0]` for a single
///       element, empty for a variadic group, and the expansion of
///       opVariadicSegmentOptionalTrailingTemplate for an optional element.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Suffix used when accessing an optional element in the attribute-sized case.
/// It yields the first element of the range, or None if the range is empty:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
172 
/// Template for a mandatory attribute getter. It wraps the raw attribute in its
/// Python class:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for an optional attribute getter. It returns None when the
/// attribute is absent and otherwise wraps it in its Python class:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for a unit attribute getter. It returns True if the unit attribute
/// is present and False otherwise (unit attributes carry meaning by mere
/// presence):
///    {0} is the name of the attribute sanitized for Python,
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";
205 
/// Template for a mandatory attribute setter. It raises ValueError when given
/// None, because a mandatory attribute cannot be removed:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for an optional attribute setter. Setting None removes the
/// attribute if it is present:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a unit attribute setter. A truthy value stores a fresh
/// UnitAttr. A falsy value (including None) removes the attribute if present:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute. It
/// removes the attribute from the operation. Unlike the setters above, it does
/// not check for presence first.
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";
252 
253 static llvm::cl::OptionCategory
254     clOpPythonBindingCat("Options for -gen-python-op-bindings");
255 
256 static llvm::cl::opt<std::string>
257     clDialectName("bind-dialect",
258                   llvm::cl::desc("The dialect to run the generator for"),
259                   llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
260 
261 using AttributeClasses = DenseMap<StringRef, StringRef>;
262 
263 /// Checks whether `str` is a Python keyword.
isPythonKeyword(StringRef str)264 static bool isPythonKeyword(StringRef str) {
265   static llvm::StringSet<> keywords(
266       {"and",   "as",     "assert",   "break", "class",  "continue",
267        "def",   "del",    "elif",     "else",  "except", "finally",
268        "for",   "from",   "global",   "if",    "import", "in",
269        "is",    "lambda", "nonlocal", "not",   "or",     "pass",
270        "raise", "return", "try",      "while", "with",   "yield"});
271   return keywords.contains(str);
272 }
273 
274 /// Checks whether `str` would shadow a generated variable or attribute
275 /// part of the OpView API.
isODSReserved(StringRef str)276 static bool isODSReserved(StringRef str) {
277   static llvm::StringSet<> reserved(
278       {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
279        "loc", "verify", "regions", "results", "self", "operation",
280        "DIALECT_NAMESPACE", "OPERATION_NAME"});
281   return str.startswith("_ods_") || str.endswith("_ods") ||
282          reserved.contains(str);
283 }
284 
285 /// Modifies the `name` in a way that it becomes suitable for Python bindings
286 /// (does not change the `name` if it already is suitable) and returns the
287 /// modified version.
sanitizeName(StringRef name)288 static std::string sanitizeName(StringRef name) {
289   if (isPythonKeyword(name) || isODSReserved(name))
290     return (name + "_").str();
291   return name.str();
292 }
293 
attrSizedTraitForKind(const char * kind)294 static std::string attrSizedTraitForKind(const char *kind) {
295   return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
296                        llvm::StringRef(kind).take_front().upper(),
297                        llvm::StringRef(kind).drop_front());
298 }
299 
300 /// Emits accessors to "elements" of an Op definition. Currently, the supported
301 /// elements are operands and results, indicated by `kind`, which must be either
302 /// `operand` or `result` and is used verbatim in the emitted code.
emitElementAccessors(const Operator & op,raw_ostream & os,const char * kind,llvm::function_ref<unsigned (const Operator &)> getNumVariadic,llvm::function_ref<int (const Operator &)> getNumElements,llvm::function_ref<const NamedTypeConstraint & (const Operator &,int)> getElement)303 static void emitElementAccessors(
304     const Operator &op, raw_ostream &os, const char *kind,
305     llvm::function_ref<unsigned(const Operator &)> getNumVariadic,
306     llvm::function_ref<int(const Operator &)> getNumElements,
307     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
308         getElement) {
309   assert(llvm::is_contained(
310              llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
311          "unsupported kind");
312 
313   // Traits indicating how to process variadic elements.
314   std::string sameSizeTrait =
315       llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
316                     llvm::StringRef(kind).take_front().upper(),
317                     llvm::StringRef(kind).drop_front());
318   std::string attrSizedTrait = attrSizedTraitForKind(kind);
319 
320   unsigned numVariadic = getNumVariadic(op);
321 
322   // If there is only one variadic element group, its size can be inferred from
323   // the total number of elements. If there are none, the generation is
324   // straightforward.
325   if (numVariadic <= 1) {
326     bool seenVariableLength = false;
327     for (int i = 0, e = getNumElements(op); i < e; ++i) {
328       const NamedTypeConstraint &element = getElement(op, i);
329       if (element.isVariableLength())
330         seenVariableLength = true;
331       if (element.name.empty())
332         continue;
333       if (element.isVariableLength()) {
334         os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
335                                                  : opOneVariadicTemplate,
336                             sanitizeName(element.name), kind,
337                             getNumElements(op), i);
338       } else if (seenVariableLength) {
339         os << llvm::formatv(opSingleAfterVariableTemplate,
340                             sanitizeName(element.name), kind,
341                             getNumElements(op), i);
342       } else {
343         os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
344                             i);
345       }
346     }
347     return;
348   }
349 
350   // Handle the operations where variadic groups have the same size.
351   if (op.getTrait(sameSizeTrait)) {
352     int numPrecedingSimple = 0;
353     int numPrecedingVariadic = 0;
354     for (int i = 0, e = getNumElements(op); i < e; ++i) {
355       const NamedTypeConstraint &element = getElement(op, i);
356       if (!element.name.empty()) {
357         os << llvm::formatv(opVariadicEqualPrefixTemplate,
358                             sanitizeName(element.name), kind, numVariadic,
359                             numPrecedingSimple, numPrecedingVariadic);
360         os << llvm::formatv(element.isVariableLength()
361                                 ? opVariadicEqualVariadicTemplate
362                                 : opVariadicEqualSimpleTemplate,
363                             kind);
364       }
365       if (element.isVariableLength())
366         ++numPrecedingVariadic;
367       else
368         ++numPrecedingSimple;
369     }
370     return;
371   }
372 
373   // Handle the operations where the size of groups (variadic or not) is
374   // provided as an attribute. For non-variadic elements, make sure to return
375   // an element rather than a singleton container.
376   if (op.getTrait(attrSizedTrait)) {
377     for (int i = 0, e = getNumElements(op); i < e; ++i) {
378       const NamedTypeConstraint &element = getElement(op, i);
379       if (element.name.empty())
380         continue;
381       std::string trailing;
382       if (!element.isVariableLength())
383         trailing = "[0]";
384       else if (element.isOptional())
385         trailing = std::string(
386             llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
387       os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
388                           kind, i, trailing);
389     }
390     return;
391   }
392 
393   llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
394 }
395 
396 /// Free function helpers accessing Operator components.
getNumOperands(const Operator & op)397 static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
getOperand(const Operator & op,int i)398 static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
399   return op.getOperand(i);
400 }
getNumResults(const Operator & op)401 static int getNumResults(const Operator &op) { return op.getNumResults(); }
getResult(const Operator & op,int i)402 static const NamedTypeConstraint &getResult(const Operator &op, int i) {
403   return op.getResult(i);
404 }
405 
406 /// Emits accessors to Op operands.
emitOperandAccessors(const Operator & op,raw_ostream & os)407 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
408   auto getNumVariadic = [](const Operator &oper) {
409     return oper.getNumVariableLengthOperands();
410   };
411   emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands,
412                        getOperand);
413 }
414 
415 /// Emits accessors Op results.
emitResultAccessors(const Operator & op,raw_ostream & os)416 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
417   auto getNumVariadic = [](const Operator &oper) {
418     return oper.getNumVariableLengthResults();
419   };
420   emitElementAccessors(op, os, "result", getNumVariadic, getNumResults,
421                        getResult);
422 }
423 
424 /// Emits accessors to Op attributes.
emitAttributeAccessors(const Operator & op,const AttributeClasses & attributeClasses,raw_ostream & os)425 static void emitAttributeAccessors(const Operator &op,
426                                    const AttributeClasses &attributeClasses,
427                                    raw_ostream &os) {
428   for (const auto &namedAttr : op.getAttributes()) {
429     // Skip "derived" attributes because they are just C++ functions that we
430     // don't currently expose.
431     if (namedAttr.attr.isDerivedAttr())
432       continue;
433 
434     if (namedAttr.name.empty())
435       continue;
436 
437     std::string sanitizedName = sanitizeName(namedAttr.name);
438 
439     // Unit attributes are handled specially.
440     if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
441       os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
442                           namedAttr.name);
443       os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
444                           namedAttr.name);
445       os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
446                           namedAttr.name);
447       continue;
448     }
449 
450     // Other kinds of attributes need a mapping to a Python type.
451     if (!attributeClasses.count(namedAttr.attr.getStorageType().trim()))
452       continue;
453 
454     StringRef pythonType =
455         attributeClasses.lookup(namedAttr.attr.getStorageType());
456     if (namedAttr.attr.isOptional()) {
457       os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
458                           pythonType, namedAttr.name);
459       os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
460                           namedAttr.name);
461       os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
462                           namedAttr.name);
463     } else {
464       os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType,
465                           namedAttr.name);
466       os << llvm::formatv(attributeSetterTemplate, sanitizedName,
467                           namedAttr.name);
468       // Non-optional attributes cannot be deleted.
469     }
470   }
471 }
472 
/// Template for the default auto-generated builder (`__init__`). It collects
/// operands, results and attributes into local containers, passes them to
/// `self.build_generic`, and hands the result to `super().__init__`:
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating the `operands`, `results` and `attributes`
///       locals.
/// `{{` is the formatv escape for a literal `{`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    {1}
    super().__init__(self.build_generic(
      attributes=attributes, results=results, operands=operands,
      loc=loc, ip=ip))
)Py";
487 
/// Template for appending a single element to the operand/result list.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *singleElementAppendTemplate = "{0}s.append({1})";

/// Template for appending an optional element to the operand/result list only
/// when it is not None.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *optionalAppendTemplate =
    "if {1} is not None: {0}s.append({1})";

/// Template for appending a list of elements to the operand/result list.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *multiElementAppendTemplate = "{0}s.extend({1})";

/// Template for setting an attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder, only
/// when the builder argument is not None.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting a unit attribute in the operation builder when the
/// builder argument is truthy. The UnitAttr is created with the context
/// returned by `_ods_get_default_loc_context(loc)`.
///   {0} is the attribute name;
///   {1} is the builder argument name.
/// The second physical line continues inside open parentheses, so its
/// indentation does not matter to Python.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";
518 
519 /// Populates `builderArgs` with the Python-compatible names of builder function
520 /// arguments, first the results, then the intermixed attributes and operands in
521 /// the same order as they appear in the `arguments` field of the op definition.
522 /// Additionally, `operandNames` is populated with names of operands in their
523 /// order of appearance.
524 static void
populateBuilderArgs(const Operator & op,llvm::SmallVectorImpl<std::string> & builderArgs,llvm::SmallVectorImpl<std::string> & operandNames)525 populateBuilderArgs(const Operator &op,
526                     llvm::SmallVectorImpl<std::string> &builderArgs,
527                     llvm::SmallVectorImpl<std::string> &operandNames) {
528   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
529     std::string name = op.getResultName(i).str();
530     if (name.empty()) {
531       if (op.getNumResults() == 1) {
532         // Special case for one result, make the default name be 'result'
533         // to properly match the built-in result accessor.
534         name = "result";
535       } else {
536         name = llvm::formatv("_gen_res_{0}", i);
537       }
538     }
539     name = sanitizeName(name);
540     builderArgs.push_back(name);
541   }
542   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
543     std::string name = op.getArgName(i).str();
544     if (name.empty())
545       name = llvm::formatv("_gen_arg_{0}", i);
546     name = sanitizeName(name);
547     builderArgs.push_back(name);
548     if (!op.getArg(i).is<NamedAttribute *>())
549       operandNames.push_back(name);
550   }
551 }
552 
553 /// Populates `builderLines` with additional lines that are required in the
554 /// builder to set up operation attributes. `argNames` is expected to contain
555 /// the names of builder arguments that correspond to op arguments, i.e. to the
556 /// operands and attributes in the same order as they appear in the `arguments`
557 /// field.
558 static void
populateBuilderLinesAttr(const Operator & op,llvm::ArrayRef<std::string> argNames,llvm::SmallVectorImpl<std::string> & builderLines)559 populateBuilderLinesAttr(const Operator &op,
560                          llvm::ArrayRef<std::string> argNames,
561                          llvm::SmallVectorImpl<std::string> &builderLines) {
562   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
563     Argument arg = op.getArg(i);
564     auto *attribute = arg.dyn_cast<NamedAttribute *>();
565     if (!attribute)
566       continue;
567 
568     // Unit attributes are handled specially.
569     if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
570       builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
571                                            attribute->name, argNames[i]));
572       continue;
573     }
574 
575     builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
576                                              ? initOptionalAttributeTemplate
577                                              : initAttributeTemplate,
578                                          attribute->name, argNames[i]));
579   }
580 }
581 
582 /// Populates `builderLines` with additional lines that are required in the
583 /// builder. `kind` must be either "operand" or "result". `names` contains the
584 /// names of init arguments that correspond to the elements.
populateBuilderLines(const Operator & op,const char * kind,llvm::ArrayRef<std::string> names,llvm::SmallVectorImpl<std::string> & builderLines,llvm::function_ref<int (const Operator &)> getNumElements,llvm::function_ref<const NamedTypeConstraint & (const Operator &,int)> getElement)585 static void populateBuilderLines(
586     const Operator &op, const char *kind, llvm::ArrayRef<std::string> names,
587     llvm::SmallVectorImpl<std::string> &builderLines,
588     llvm::function_ref<int(const Operator &)> getNumElements,
589     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
590         getElement) {
591   bool sizedSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr;
592 
593   // For each element, find or generate a name.
594   for (int i = 0, e = getNumElements(op); i < e; ++i) {
595     const NamedTypeConstraint &element = getElement(op, i);
596     std::string name = names[i];
597 
598     // Choose the formatting string based on the element kind.
599     llvm::StringRef formatString;
600     if (!element.isVariableLength()) {
601       formatString = singleElementAppendTemplate;
602     } else if (element.isOptional()) {
603       formatString = optionalAppendTemplate;
604     } else {
605       assert(element.isVariadic() && "unhandled element group type");
606       // If emitting with sizedSegments, then we add the actual list typed
607       // element using the singleElementAppendTemplate. Otherwise, we extend
608       // the actual operands.
609       if (sizedSegments) {
610         // Append the list as is.
611         formatString = singleElementAppendTemplate;
612       } else {
613         // Append the list elements.
614         formatString = multiElementAppendTemplate;
615       }
616     }
617 
618     // Add the lines.
619     builderLines.push_back(llvm::formatv(formatString.data(), kind, name));
620   }
621 }
622 
623 /// Emits a default builder constructing an operation from the list of its
624 /// result types, followed by a list of its operands.
emitDefaultOpBuilder(const Operator & op,raw_ostream & os)625 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
626   // If we are asked to skip default builders, comply.
627   if (op.skipDefaultBuilders())
628     return;
629 
630   llvm::SmallVector<std::string, 8> builderArgs;
631   llvm::SmallVector<std::string, 8> builderLines;
632   llvm::SmallVector<std::string, 4> operandArgNames;
633   builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
634                       op.getNumNativeAttributes());
635   populateBuilderArgs(op, builderArgs, operandArgNames);
636   populateBuilderLines(
637       op, "result",
638       llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
639       builderLines, getNumResults, getResult);
640   populateBuilderLines(op, "operand", operandArgNames, builderLines,
641                        getNumOperands, getOperand);
642   populateBuilderLinesAttr(
643       op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()),
644       builderLines);
645 
646   builderArgs.push_back("*");
647   builderArgs.push_back("loc=None");
648   builderArgs.push_back("ip=None");
649   os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "),
650                       llvm::join(builderLines, "\n    "));
651 }
652 
constructAttributeMapping(const llvm::RecordKeeper & records,AttributeClasses & attributeClasses)653 static void constructAttributeMapping(const llvm::RecordKeeper &records,
654                                       AttributeClasses &attributeClasses) {
655   for (const llvm::Record *rec :
656        records.getAllDerivedDefinitions("PythonAttr")) {
657     attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(),
658                                  rec->getValueAsString("pythonType").trim());
659   }
660 }
661 
emitSegmentSpec(const Operator & op,const char * kind,llvm::function_ref<int (const Operator &)> getNumElements,llvm::function_ref<const NamedTypeConstraint & (const Operator &,int)> getElement,raw_ostream & os)662 static void emitSegmentSpec(
663     const Operator &op, const char *kind,
664     llvm::function_ref<int(const Operator &)> getNumElements,
665     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
666         getElement,
667     raw_ostream &os) {
668   std::string segmentSpec("[");
669   for (int i = 0, e = getNumElements(op); i < e; ++i) {
670     const NamedTypeConstraint &element = getElement(op, i);
671     if (element.isVariableLength()) {
672       segmentSpec.append("-1,");
673     } else if (element.isOptional()) {
674       segmentSpec.append("0,");
675     } else {
676       segmentSpec.append("1,");
677     }
678   }
679   segmentSpec.append("]");
680 
681   os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
682 }
683 
emitRegionAttributes(const Operator & op,raw_ostream & os)684 static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
685   // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
686   // Note that the base OpView class defines this as (0, True).
687   unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
688   os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
689                       op.hasNoVariadicRegions() ? "True" : "False");
690 }
691 
692 /// Emits bindings for a specific Op to the given output stream.
emitOpBindings(const Operator & op,const AttributeClasses & attributeClasses,raw_ostream & os)693 static void emitOpBindings(const Operator &op,
694                            const AttributeClasses &attributeClasses,
695                            raw_ostream &os) {
696   os << llvm::formatv(opClassTemplate, op.getCppClassName(),
697                       op.getOperationName());
698 
699   // Sized segments.
700   if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
701     emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
702   }
703   if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
704     emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
705   }
706 
707   emitRegionAttributes(op, os);
708   emitDefaultOpBuilder(op, os);
709   emitOperandAccessors(op, os);
710   emitAttributeAccessors(op, attributeClasses, os);
711   emitResultAccessors(op, os);
712 }
713 
714 /// Emits bindings for the dialect specified in the command line, including file
715 /// headers and utilities. Returns `false` on success to comply with Tablegen
716 /// registration requirements.
emitAllOps(const llvm::RecordKeeper & records,raw_ostream & os)717 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
718   if (clDialectName.empty())
719     llvm::PrintFatalError("dialect name not provided");
720 
721   AttributeClasses attributeClasses;
722   constructAttributeMapping(records, attributeClasses);
723 
724   os << llvm::formatv(fileHeader, clDialectName.getValue());
725   os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
726 
727   if (clDialectName == "builtin")
728     clDialectName = "";
729 
730   for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
731     Operator op(rec);
732     if (op.getDialectName() == clDialectName.getValue())
733       emitOpBindings(op, attributeClasses, os);
734   }
735   return false;
736 }
737 
738 static GenRegistration
739     genPythonBindings("gen-python-op-bindings",
740                       "Generate Python bindings for MLIR Ops", &emitAllOps);
741