1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10 // binding classes wrapping a generic operation API.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "llvm/ADT/StringSet.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21
22 using namespace mlir;
23 using namespace mlir::tblgen;
24
/// File header and includes, emitted once at the top of the generated module.
/// It binds the helpers used by the other templates under `_ods_`-prefixed
/// aliases and tries to import the optional sibling module `_{0}` as
/// `_ods_ext_module` (set to None when that import fails).
///   {0} is the dialect namespace.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from . import _cext as _ods_cext
from . import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context
_ods_ir = _ods_cext.ir

try:
  from . import _{0} as _ods_ext_module
except ImportError:
  _ods_ext_module = None

)Py";
40
/// Template for the dialect class, registered via
/// `_ods_cext.register_dialect`; op classes below register against it.
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";
50
/// Template for the header of an operation class. The class is registered
/// with `_Dialect` and passed through `_ods_extend_opview_class` together with
/// the (possibly None) `_ods_ext_module`. Class-level specs and accessors are
/// appended after it by the other templates.
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";
60
/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT"
///   {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers, one per
/// element group, where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional single element (may be absent)
///   -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";
72
/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions (total minus variadic ones);
///   {1} is the Python bool literal for hasNoVariadicRegions.
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
79
/// Template for a single-element accessor placed before any variable-length
/// group, so its position is fixed:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";
89
/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// The length of the single variable group is recovered as
/// `len(elements) - {2} + 1`. The element is shifted by that length minus the
/// one slot its group would occupy in a fixed-size layout.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";
103
/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This is only used when the optional element is the sole variable-length
/// group. The element list then has exactly {2} entries when the element is
/// present and {2} - 1 when it is absent, so the element is absent iff the
/// list is shorter than {2}. (Testing `len > {2}` would never be true.)
constexpr const char *opOneOptionalTemplate = R"Py(
  @property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";
114
/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// Every other group has exactly one element, so the variadic group's length
/// is `len(elements) - {2} + 1` and it starts at position {3}.
constexpr const char *opOneVariadicTemplate = R"Py(
  @property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
126
/// First part of the template for equally-sized variadic group accessor. It
/// binds `start` (first index of the group) and `pg` (elements per variadic
/// group) via `_ods_equally_sized_accessor`. It is always followed by either
/// opVariadicEqualSimpleTemplate or opVariadicEqualVariadicTemplate.
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The element list must be reached through `self.operation`: a bare
/// `operation` is not defined inside the generated property.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
137
/// Second part of the template for equally-sized case, accessing a single
/// element located at the `start` index bound by the first part:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";
144
/// Second part of the template for equally-sized case, accessing a variadic
/// group as the slice of `pg` elements starting at `start` (both bound by the
/// first part):
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";
151
/// Template for an attribute-sized group accessor. Group boundaries are read
/// from the `{1}_segment_sizes` attribute by `_ods_segmented_accessor`:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
        self.operation.{1}s,
        self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";
166
/// Template for a suffix when accessing an optional element in the
/// attribute-sized case: yields the single element, or None if its segment is
/// empty.
///   {0} is either 'operand' or 'result';
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
172
/// Template for an operation attribute getter; wraps the raw attribute in its
/// registered Python class:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";
182
/// Template for an optional operation attribute getter; returns None when the
/// attribute is not set on the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";
194
/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///   {0} is the name of the attribute sanitized for Python,
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";
205
/// Template for an operation attribute setter for mandatory attributes;
/// assigning None raises ValueError instead of removing the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";
216
/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute (a no-op if it is already absent):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";
229
/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute; any truthy value sets a fresh UnitAttr:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";
242
/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation (no presence check is done here):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";
252
253 static llvm::cl::OptionCategory
254 clOpPythonBindingCat("Options for -gen-python-op-bindings");
255
256 static llvm::cl::opt<std::string>
257 clDialectName("bind-dialect",
258 llvm::cl::desc("The dialect to run the generator for"),
259 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
260
261 using AttributeClasses = DenseMap<StringRef, StringRef>;
262
263 /// Checks whether `str` is a Python keyword.
isPythonKeyword(StringRef str)264 static bool isPythonKeyword(StringRef str) {
265 static llvm::StringSet<> keywords(
266 {"and", "as", "assert", "break", "class", "continue",
267 "def", "del", "elif", "else", "except", "finally",
268 "for", "from", "global", "if", "import", "in",
269 "is", "lambda", "nonlocal", "not", "or", "pass",
270 "raise", "return", "try", "while", "with", "yield"});
271 return keywords.contains(str);
272 }
273
274 /// Checks whether `str` would shadow a generated variable or attribute
275 /// part of the OpView API.
isODSReserved(StringRef str)276 static bool isODSReserved(StringRef str) {
277 static llvm::StringSet<> reserved(
278 {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
279 "loc", "verify", "regions", "results", "self", "operation",
280 "DIALECT_NAMESPACE", "OPERATION_NAME"});
281 return str.startswith("_ods_") || str.endswith("_ods") ||
282 reserved.contains(str);
283 }
284
285 /// Modifies the `name` in a way that it becomes suitable for Python bindings
286 /// (does not change the `name` if it already is suitable) and returns the
287 /// modified version.
sanitizeName(StringRef name)288 static std::string sanitizeName(StringRef name) {
289 if (isPythonKeyword(name) || isODSReserved(name))
290 return (name + "_").str();
291 return name.str();
292 }
293
attrSizedTraitForKind(const char * kind)294 static std::string attrSizedTraitForKind(const char *kind) {
295 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
296 llvm::StringRef(kind).take_front().upper(),
297 llvm::StringRef(kind).drop_front());
298 }
299
300 /// Emits accessors to "elements" of an Op definition. Currently, the supported
301 /// elements are operands and results, indicated by `kind`, which must be either
302 /// `operand` or `result` and is used verbatim in the emitted code.
emitElementAccessors(const Operator & op,raw_ostream & os,const char * kind,llvm::function_ref<unsigned (const Operator &)> getNumVariadic,llvm::function_ref<int (const Operator &)> getNumElements,llvm::function_ref<const NamedTypeConstraint & (const Operator &,int)> getElement)303 static void emitElementAccessors(
304 const Operator &op, raw_ostream &os, const char *kind,
305 llvm::function_ref<unsigned(const Operator &)> getNumVariadic,
306 llvm::function_ref<int(const Operator &)> getNumElements,
307 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
308 getElement) {
309 assert(llvm::is_contained(
310 llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
311 "unsupported kind");
312
313 // Traits indicating how to process variadic elements.
314 std::string sameSizeTrait =
315 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
316 llvm::StringRef(kind).take_front().upper(),
317 llvm::StringRef(kind).drop_front());
318 std::string attrSizedTrait = attrSizedTraitForKind(kind);
319
320 unsigned numVariadic = getNumVariadic(op);
321
322 // If there is only one variadic element group, its size can be inferred from
323 // the total number of elements. If there are none, the generation is
324 // straightforward.
325 if (numVariadic <= 1) {
326 bool seenVariableLength = false;
327 for (int i = 0, e = getNumElements(op); i < e; ++i) {
328 const NamedTypeConstraint &element = getElement(op, i);
329 if (element.isVariableLength())
330 seenVariableLength = true;
331 if (element.name.empty())
332 continue;
333 if (element.isVariableLength()) {
334 os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
335 : opOneVariadicTemplate,
336 sanitizeName(element.name), kind,
337 getNumElements(op), i);
338 } else if (seenVariableLength) {
339 os << llvm::formatv(opSingleAfterVariableTemplate,
340 sanitizeName(element.name), kind,
341 getNumElements(op), i);
342 } else {
343 os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
344 i);
345 }
346 }
347 return;
348 }
349
350 // Handle the operations where variadic groups have the same size.
351 if (op.getTrait(sameSizeTrait)) {
352 int numPrecedingSimple = 0;
353 int numPrecedingVariadic = 0;
354 for (int i = 0, e = getNumElements(op); i < e; ++i) {
355 const NamedTypeConstraint &element = getElement(op, i);
356 if (!element.name.empty()) {
357 os << llvm::formatv(opVariadicEqualPrefixTemplate,
358 sanitizeName(element.name), kind, numVariadic,
359 numPrecedingSimple, numPrecedingVariadic);
360 os << llvm::formatv(element.isVariableLength()
361 ? opVariadicEqualVariadicTemplate
362 : opVariadicEqualSimpleTemplate,
363 kind);
364 }
365 if (element.isVariableLength())
366 ++numPrecedingVariadic;
367 else
368 ++numPrecedingSimple;
369 }
370 return;
371 }
372
373 // Handle the operations where the size of groups (variadic or not) is
374 // provided as an attribute. For non-variadic elements, make sure to return
375 // an element rather than a singleton container.
376 if (op.getTrait(attrSizedTrait)) {
377 for (int i = 0, e = getNumElements(op); i < e; ++i) {
378 const NamedTypeConstraint &element = getElement(op, i);
379 if (element.name.empty())
380 continue;
381 std::string trailing;
382 if (!element.isVariableLength())
383 trailing = "[0]";
384 else if (element.isOptional())
385 trailing = std::string(
386 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
387 os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
388 kind, i, trailing);
389 }
390 return;
391 }
392
393 llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
394 }
395
396 /// Free function helpers accessing Operator components.
getNumOperands(const Operator & op)397 static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
getOperand(const Operator & op,int i)398 static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
399 return op.getOperand(i);
400 }
getNumResults(const Operator & op)401 static int getNumResults(const Operator &op) { return op.getNumResults(); }
getResult(const Operator & op,int i)402 static const NamedTypeConstraint &getResult(const Operator &op, int i) {
403 return op.getResult(i);
404 }
405
406 /// Emits accessors to Op operands.
emitOperandAccessors(const Operator & op,raw_ostream & os)407 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
408 auto getNumVariadic = [](const Operator &oper) {
409 return oper.getNumVariableLengthOperands();
410 };
411 emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands,
412 getOperand);
413 }
414
415 /// Emits accessors Op results.
emitResultAccessors(const Operator & op,raw_ostream & os)416 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
417 auto getNumVariadic = [](const Operator &oper) {
418 return oper.getNumVariableLengthResults();
419 };
420 emitElementAccessors(op, os, "result", getNumVariadic, getNumResults,
421 getResult);
422 }
423
424 /// Emits accessors to Op attributes.
emitAttributeAccessors(const Operator & op,const AttributeClasses & attributeClasses,raw_ostream & os)425 static void emitAttributeAccessors(const Operator &op,
426 const AttributeClasses &attributeClasses,
427 raw_ostream &os) {
428 for (const auto &namedAttr : op.getAttributes()) {
429 // Skip "derived" attributes because they are just C++ functions that we
430 // don't currently expose.
431 if (namedAttr.attr.isDerivedAttr())
432 continue;
433
434 if (namedAttr.name.empty())
435 continue;
436
437 std::string sanitizedName = sanitizeName(namedAttr.name);
438
439 // Unit attributes are handled specially.
440 if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
441 os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
442 namedAttr.name);
443 os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
444 namedAttr.name);
445 os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
446 namedAttr.name);
447 continue;
448 }
449
450 // Other kinds of attributes need a mapping to a Python type.
451 if (!attributeClasses.count(namedAttr.attr.getStorageType().trim()))
452 continue;
453
454 StringRef pythonType =
455 attributeClasses.lookup(namedAttr.attr.getStorageType());
456 if (namedAttr.attr.isOptional()) {
457 os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
458 pythonType, namedAttr.name);
459 os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
460 namedAttr.name);
461 os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
462 namedAttr.name);
463 } else {
464 os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType,
465 namedAttr.name);
466 os << llvm::formatv(attributeSetterTemplate, sanitizedName,
467 namedAttr.name);
468 // Non-optional attributes cannot be deleted.
469 }
470 }
471 }
472
/// Template for the default auto-generated builder. It fills local `operands`,
/// `results` and `attributes` containers and forwards them, with `loc` and
/// `ip`, to `self.build_generic`.
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating `operands`, `results` and `attributes` fields.
/// `{{}` is the formatv escape for a literal `{}` (an empty dict).
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    {1}
    super().__init__(self.build_generic(
      attributes=attributes, results=results, operands=operands,
      loc=loc, ip=ip))
)Py";
487
/// Template for appending a single element to the operand/result list.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *singleElementAppendTemplate = "{0}s.append({1})";

/// Template for appending an optional element to the operand/result list;
/// nothing is appended when the argument is None.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *optionalAppendTemplate =
    "if {1} is not None: {0}s.append({1})";

/// Template for appending a list of elements to the operand/result list
/// (the elements are spliced in, not appended as one list).
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *multiElementAppendTemplate = "{0}s.extend({1})";
503
/// Template for setting an attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder; the
/// attribute is left unset when the argument is None.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting a unit attribute in the operation builder: the
/// attribute is added (as a UnitAttr in the context derived from `loc`) only
/// when the argument is truthy.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";
518
519 /// Populates `builderArgs` with the Python-compatible names of builder function
520 /// arguments, first the results, then the intermixed attributes and operands in
521 /// the same order as they appear in the `arguments` field of the op definition.
522 /// Additionally, `operandNames` is populated with names of operands in their
523 /// order of appearance.
524 static void
populateBuilderArgs(const Operator & op,llvm::SmallVectorImpl<std::string> & builderArgs,llvm::SmallVectorImpl<std::string> & operandNames)525 populateBuilderArgs(const Operator &op,
526 llvm::SmallVectorImpl<std::string> &builderArgs,
527 llvm::SmallVectorImpl<std::string> &operandNames) {
528 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
529 std::string name = op.getResultName(i).str();
530 if (name.empty()) {
531 if (op.getNumResults() == 1) {
532 // Special case for one result, make the default name be 'result'
533 // to properly match the built-in result accessor.
534 name = "result";
535 } else {
536 name = llvm::formatv("_gen_res_{0}", i);
537 }
538 }
539 name = sanitizeName(name);
540 builderArgs.push_back(name);
541 }
542 for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
543 std::string name = op.getArgName(i).str();
544 if (name.empty())
545 name = llvm::formatv("_gen_arg_{0}", i);
546 name = sanitizeName(name);
547 builderArgs.push_back(name);
548 if (!op.getArg(i).is<NamedAttribute *>())
549 operandNames.push_back(name);
550 }
551 }
552
553 /// Populates `builderLines` with additional lines that are required in the
554 /// builder to set up operation attributes. `argNames` is expected to contain
555 /// the names of builder arguments that correspond to op arguments, i.e. to the
556 /// operands and attributes in the same order as they appear in the `arguments`
557 /// field.
558 static void
populateBuilderLinesAttr(const Operator & op,llvm::ArrayRef<std::string> argNames,llvm::SmallVectorImpl<std::string> & builderLines)559 populateBuilderLinesAttr(const Operator &op,
560 llvm::ArrayRef<std::string> argNames,
561 llvm::SmallVectorImpl<std::string> &builderLines) {
562 for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
563 Argument arg = op.getArg(i);
564 auto *attribute = arg.dyn_cast<NamedAttribute *>();
565 if (!attribute)
566 continue;
567
568 // Unit attributes are handled specially.
569 if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
570 builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
571 attribute->name, argNames[i]));
572 continue;
573 }
574
575 builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
576 ? initOptionalAttributeTemplate
577 : initAttributeTemplate,
578 attribute->name, argNames[i]));
579 }
580 }
581
582 /// Populates `builderLines` with additional lines that are required in the
583 /// builder. `kind` must be either "operand" or "result". `names` contains the
584 /// names of init arguments that correspond to the elements.
populateBuilderLines(const Operator & op,const char * kind,llvm::ArrayRef<std::string> names,llvm::SmallVectorImpl<std::string> & builderLines,llvm::function_ref<int (const Operator &)> getNumElements,llvm::function_ref<const NamedTypeConstraint & (const Operator &,int)> getElement)585 static void populateBuilderLines(
586 const Operator &op, const char *kind, llvm::ArrayRef<std::string> names,
587 llvm::SmallVectorImpl<std::string> &builderLines,
588 llvm::function_ref<int(const Operator &)> getNumElements,
589 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
590 getElement) {
591 bool sizedSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr;
592
593 // For each element, find or generate a name.
594 for (int i = 0, e = getNumElements(op); i < e; ++i) {
595 const NamedTypeConstraint &element = getElement(op, i);
596 std::string name = names[i];
597
598 // Choose the formatting string based on the element kind.
599 llvm::StringRef formatString;
600 if (!element.isVariableLength()) {
601 formatString = singleElementAppendTemplate;
602 } else if (element.isOptional()) {
603 formatString = optionalAppendTemplate;
604 } else {
605 assert(element.isVariadic() && "unhandled element group type");
606 // If emitting with sizedSegments, then we add the actual list typed
607 // element using the singleElementAppendTemplate. Otherwise, we extend
608 // the actual operands.
609 if (sizedSegments) {
610 // Append the list as is.
611 formatString = singleElementAppendTemplate;
612 } else {
613 // Append the list elements.
614 formatString = multiElementAppendTemplate;
615 }
616 }
617
618 // Add the lines.
619 builderLines.push_back(llvm::formatv(formatString.data(), kind, name));
620 }
621 }
622
623 /// Emits a default builder constructing an operation from the list of its
624 /// result types, followed by a list of its operands.
emitDefaultOpBuilder(const Operator & op,raw_ostream & os)625 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
626 // If we are asked to skip default builders, comply.
627 if (op.skipDefaultBuilders())
628 return;
629
630 llvm::SmallVector<std::string, 8> builderArgs;
631 llvm::SmallVector<std::string, 8> builderLines;
632 llvm::SmallVector<std::string, 4> operandArgNames;
633 builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
634 op.getNumNativeAttributes());
635 populateBuilderArgs(op, builderArgs, operandArgNames);
636 populateBuilderLines(
637 op, "result",
638 llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
639 builderLines, getNumResults, getResult);
640 populateBuilderLines(op, "operand", operandArgNames, builderLines,
641 getNumOperands, getOperand);
642 populateBuilderLinesAttr(
643 op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()),
644 builderLines);
645
646 builderArgs.push_back("*");
647 builderArgs.push_back("loc=None");
648 builderArgs.push_back("ip=None");
649 os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "),
650 llvm::join(builderLines, "\n "));
651 }
652
constructAttributeMapping(const llvm::RecordKeeper & records,AttributeClasses & attributeClasses)653 static void constructAttributeMapping(const llvm::RecordKeeper &records,
654 AttributeClasses &attributeClasses) {
655 for (const llvm::Record *rec :
656 records.getAllDerivedDefinitions("PythonAttr")) {
657 attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(),
658 rec->getValueAsString("pythonType").trim());
659 }
660 }
661
emitSegmentSpec(const Operator & op,const char * kind,llvm::function_ref<int (const Operator &)> getNumElements,llvm::function_ref<const NamedTypeConstraint & (const Operator &,int)> getElement,raw_ostream & os)662 static void emitSegmentSpec(
663 const Operator &op, const char *kind,
664 llvm::function_ref<int(const Operator &)> getNumElements,
665 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
666 getElement,
667 raw_ostream &os) {
668 std::string segmentSpec("[");
669 for (int i = 0, e = getNumElements(op); i < e; ++i) {
670 const NamedTypeConstraint &element = getElement(op, i);
671 if (element.isVariableLength()) {
672 segmentSpec.append("-1,");
673 } else if (element.isOptional()) {
674 segmentSpec.append("0,");
675 } else {
676 segmentSpec.append("1,");
677 }
678 }
679 segmentSpec.append("]");
680
681 os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
682 }
683
emitRegionAttributes(const Operator & op,raw_ostream & os)684 static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
685 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
686 // Note that the base OpView class defines this as (0, True).
687 unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
688 os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
689 op.hasNoVariadicRegions() ? "True" : "False");
690 }
691
692 /// Emits bindings for a specific Op to the given output stream.
emitOpBindings(const Operator & op,const AttributeClasses & attributeClasses,raw_ostream & os)693 static void emitOpBindings(const Operator &op,
694 const AttributeClasses &attributeClasses,
695 raw_ostream &os) {
696 os << llvm::formatv(opClassTemplate, op.getCppClassName(),
697 op.getOperationName());
698
699 // Sized segments.
700 if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
701 emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
702 }
703 if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
704 emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
705 }
706
707 emitRegionAttributes(op, os);
708 emitDefaultOpBuilder(op, os);
709 emitOperandAccessors(op, os);
710 emitAttributeAccessors(op, attributeClasses, os);
711 emitResultAccessors(op, os);
712 }
713
714 /// Emits bindings for the dialect specified in the command line, including file
715 /// headers and utilities. Returns `false` on success to comply with Tablegen
716 /// registration requirements.
emitAllOps(const llvm::RecordKeeper & records,raw_ostream & os)717 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
718 if (clDialectName.empty())
719 llvm::PrintFatalError("dialect name not provided");
720
721 AttributeClasses attributeClasses;
722 constructAttributeMapping(records, attributeClasses);
723
724 os << llvm::formatv(fileHeader, clDialectName.getValue());
725 os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
726
727 if (clDialectName == "builtin")
728 clDialectName = "";
729
730 for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
731 Operator op(rec);
732 if (op.getDialectName() == clDialectName.getValue())
733 emitOpBindings(op, attributeClasses, os);
734 }
735 return false;
736 }
737
738 static GenRegistration
739 genPythonBindings("gen-python-op-bindings",
740 "Generate Python bindings for MLIR Ops", &emitAllOps);
741