1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "IRModule.h"
10
11 #include "Globals.h"
12 #include "PybindUtils.h"
13
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include <pybind11/stl.h>
22
23 namespace py = pybind11;
24 using namespace mlir;
25 using namespace mlir::python;
26
27 using llvm::SmallVector;
28 using llvm::StringRef;
29 using llvm::Twine;
30
31 //------------------------------------------------------------------------------
32 // Docstrings (trivial, non-duplicated docstrings are included inline).
33 //------------------------------------------------------------------------------
34
// Python docstring for parsing a type from its assembly form; documents that
// unparseable input raises ValueError.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Python docstring for constructing a file/line/column Location.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// Python docstring for parsing a module from its textual assembly.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Python docstring for Operation creation; lists the accepted keyword
// arguments and states that the result starts out "detached".
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
name: Operation name (e.g. "dialect.operation").
results: Sequence of Type representing op result types.
attributes: Dict of str:Attribute.
successors: List of Block for the operation's successors.
regions: Number of regions to create.
location: A Location object (defaults to resolve from context manager).
ip: An InsertionPoint (defaults to resolve from context manager or set to
False to disable insertion, even with an insertion point set in the
context manager).
Returns:
A new "detached" Operation object. Detached operations can be added
to blocks, which causes them to become "attached."
)";
71
// Python docstring for the Operation print method. Its argument list mirrors
// the parameters of PyOperationBase::print (file, binary, largeElementsLimit,
// enableDebugInfo, prettyDebugInfo, printGenericOpForm, useLocalScope).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file-like object.

Args:
file: The file-like object to write to. Defaults to sys.stdout.
binary: Whether to write bytes (True) or str (False). Defaults to False.
large_elements_limit: If set, elide elements attributes that have more than
this number of elements. Defaults to None (no limit).
enable_debug_info: Whether to print debug/location information. Defaults
to False.
pretty_debug_info: Whether to format debug information for easier reading
by a human (warning: the result is unparseable).
print_generic_op_form: Whether to print the generic assembly forms of all
ops. Defaults to False.
use_local_scope: Whether to print in a way that is more optimized for
multi-threaded access but may not be consistent with how the overall
module prints.
)";
90
// Python docstring for get_asm; defers to the print() docstring for the shared
// keyword arguments (getAsm below forwards them to PyOperationBase::print).
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
binary: Whether to return a bytes (True) or str (False) object. Defaults to
False.
... others ...: See the print() method for common keyword arguments for
configuring the printout.
Returns:
Either a bytes or str object, depending on the setting of the 'binary'
argument.
)";

// Python docstring for Operation.__str__ (default printing options only).
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Shared docstring for dump() methods.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Docstring for BlockList.append (bound in PyBlockList::bind).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
The created block.
)";

// Python docstring for Value.__str__.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
129
130 //------------------------------------------------------------------------------
131 // Utilities.
132 //------------------------------------------------------------------------------
133
134 /// Helper for creating an @classmethod.
135 template <class Func, typename... Args>
classmethod(Func f,Args...args)136 py::object classmethod(Func f, Args... args) {
137 py::object cf = py::cpp_function(f, args...);
138 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
139 }
140
141 static py::object
createCustomDialectWrapper(const std::string & dialectNamespace,py::object dialectDescriptor)142 createCustomDialectWrapper(const std::string &dialectNamespace,
143 py::object dialectDescriptor) {
144 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
145 if (!dialectClass) {
146 // Use the base class.
147 return py::cast(PyDialect(std::move(dialectDescriptor)));
148 }
149
150 // Create the custom implementation.
151 return (*dialectClass)(std::move(dialectDescriptor));
152 }
153
toMlirStringRef(const std::string & s)154 static MlirStringRef toMlirStringRef(const std::string &s) {
155 return mlirStringRefCreate(s.data(), s.size());
156 }
157
158 /// Wrapper for the global LLVM debugging flag.
159 struct PyGlobalDebugFlag {
setPyGlobalDebugFlag160 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
161
getPyGlobalDebugFlag162 static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
163
bindPyGlobalDebugFlag164 static void bind(py::module &m) {
165 // Debug flags.
166 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
167 .def_property_static("flag", &PyGlobalDebugFlag::get,
168 &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
169 }
170 };
171
172 //------------------------------------------------------------------------------
173 // Collections.
174 //------------------------------------------------------------------------------
175
176 namespace {
177
178 class PyRegionIterator {
179 public:
PyRegionIterator(PyOperationRef operation)180 PyRegionIterator(PyOperationRef operation)
181 : operation(std::move(operation)) {}
182
dunderIter()183 PyRegionIterator &dunderIter() { return *this; }
184
dunderNext()185 PyRegion dunderNext() {
186 operation->checkValid();
187 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
188 throw py::stop_iteration();
189 }
190 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
191 return PyRegion(operation, region);
192 }
193
bind(py::module & m)194 static void bind(py::module &m) {
195 py::class_<PyRegionIterator>(m, "RegionIterator")
196 .def("__iter__", &PyRegionIterator::dunderIter)
197 .def("__next__", &PyRegionIterator::dunderNext);
198 }
199
200 private:
201 PyOperationRef operation;
202 int nextIndex = 0;
203 };
204
205 /// Regions of an op are fixed length and indexed numerically so are represented
206 /// with a sequence-like container.
207 class PyRegionList {
208 public:
PyRegionList(PyOperationRef operation)209 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
210
dunderLen()211 intptr_t dunderLen() {
212 operation->checkValid();
213 return mlirOperationGetNumRegions(operation->get());
214 }
215
dunderGetItem(intptr_t index)216 PyRegion dunderGetItem(intptr_t index) {
217 // dunderLen checks validity.
218 if (index < 0 || index >= dunderLen()) {
219 throw SetPyError(PyExc_IndexError,
220 "attempt to access out of bounds region");
221 }
222 MlirRegion region = mlirOperationGetRegion(operation->get(), index);
223 return PyRegion(operation, region);
224 }
225
bind(py::module & m)226 static void bind(py::module &m) {
227 py::class_<PyRegionList>(m, "RegionSequence")
228 .def("__len__", &PyRegionList::dunderLen)
229 .def("__getitem__", &PyRegionList::dunderGetItem);
230 }
231
232 private:
233 PyOperationRef operation;
234 };
235
236 class PyBlockIterator {
237 public:
PyBlockIterator(PyOperationRef operation,MlirBlock next)238 PyBlockIterator(PyOperationRef operation, MlirBlock next)
239 : operation(std::move(operation)), next(next) {}
240
dunderIter()241 PyBlockIterator &dunderIter() { return *this; }
242
dunderNext()243 PyBlock dunderNext() {
244 operation->checkValid();
245 if (mlirBlockIsNull(next)) {
246 throw py::stop_iteration();
247 }
248
249 PyBlock returnBlock(operation, next);
250 next = mlirBlockGetNextInRegion(next);
251 return returnBlock;
252 }
253
bind(py::module & m)254 static void bind(py::module &m) {
255 py::class_<PyBlockIterator>(m, "BlockIterator")
256 .def("__iter__", &PyBlockIterator::dunderIter)
257 .def("__next__", &PyBlockIterator::dunderNext);
258 }
259
260 private:
261 PyOperationRef operation;
262 MlirBlock next;
263 };
264
265 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
266 /// we present them as a more full-featured list-like container but optimize
267 /// it for forward iteration. Blocks are always owned by a region.
268 class PyBlockList {
269 public:
PyBlockList(PyOperationRef operation,MlirRegion region)270 PyBlockList(PyOperationRef operation, MlirRegion region)
271 : operation(std::move(operation)), region(region) {}
272
dunderIter()273 PyBlockIterator dunderIter() {
274 operation->checkValid();
275 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
276 }
277
dunderLen()278 intptr_t dunderLen() {
279 operation->checkValid();
280 intptr_t count = 0;
281 MlirBlock block = mlirRegionGetFirstBlock(region);
282 while (!mlirBlockIsNull(block)) {
283 count += 1;
284 block = mlirBlockGetNextInRegion(block);
285 }
286 return count;
287 }
288
dunderGetItem(intptr_t index)289 PyBlock dunderGetItem(intptr_t index) {
290 operation->checkValid();
291 if (index < 0) {
292 throw SetPyError(PyExc_IndexError,
293 "attempt to access out of bounds block");
294 }
295 MlirBlock block = mlirRegionGetFirstBlock(region);
296 while (!mlirBlockIsNull(block)) {
297 if (index == 0) {
298 return PyBlock(operation, block);
299 }
300 block = mlirBlockGetNextInRegion(block);
301 index -= 1;
302 }
303 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
304 }
305
appendBlock(py::args pyArgTypes)306 PyBlock appendBlock(py::args pyArgTypes) {
307 operation->checkValid();
308 llvm::SmallVector<MlirType, 4> argTypes;
309 argTypes.reserve(pyArgTypes.size());
310 for (auto &pyArg : pyArgTypes) {
311 argTypes.push_back(pyArg.cast<PyType &>());
312 }
313
314 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
315 mlirRegionAppendOwnedBlock(region, block);
316 return PyBlock(operation, block);
317 }
318
bind(py::module & m)319 static void bind(py::module &m) {
320 py::class_<PyBlockList>(m, "BlockList")
321 .def("__getitem__", &PyBlockList::dunderGetItem)
322 .def("__iter__", &PyBlockList::dunderIter)
323 .def("__len__", &PyBlockList::dunderLen)
324 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
325 }
326
327 private:
328 PyOperationRef operation;
329 MlirRegion region;
330 };
331
332 class PyOperationIterator {
333 public:
PyOperationIterator(PyOperationRef parentOperation,MlirOperation next)334 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
335 : parentOperation(std::move(parentOperation)), next(next) {}
336
dunderIter()337 PyOperationIterator &dunderIter() { return *this; }
338
dunderNext()339 py::object dunderNext() {
340 parentOperation->checkValid();
341 if (mlirOperationIsNull(next)) {
342 throw py::stop_iteration();
343 }
344
345 PyOperationRef returnOperation =
346 PyOperation::forOperation(parentOperation->getContext(), next);
347 next = mlirOperationGetNextInBlock(next);
348 return returnOperation->createOpView();
349 }
350
bind(py::module & m)351 static void bind(py::module &m) {
352 py::class_<PyOperationIterator>(m, "OperationIterator")
353 .def("__iter__", &PyOperationIterator::dunderIter)
354 .def("__next__", &PyOperationIterator::dunderNext);
355 }
356
357 private:
358 PyOperationRef parentOperation;
359 MlirOperation next;
360 };
361
362 /// Operations are exposed by the C-API as a forward-only linked list. In
363 /// Python, we present them as a more full-featured list-like container but
364 /// optimize it for forward iteration. Iterable operations are always owned
365 /// by a block.
366 class PyOperationList {
367 public:
PyOperationList(PyOperationRef parentOperation,MlirBlock block)368 PyOperationList(PyOperationRef parentOperation, MlirBlock block)
369 : parentOperation(std::move(parentOperation)), block(block) {}
370
dunderIter()371 PyOperationIterator dunderIter() {
372 parentOperation->checkValid();
373 return PyOperationIterator(parentOperation,
374 mlirBlockGetFirstOperation(block));
375 }
376
dunderLen()377 intptr_t dunderLen() {
378 parentOperation->checkValid();
379 intptr_t count = 0;
380 MlirOperation childOp = mlirBlockGetFirstOperation(block);
381 while (!mlirOperationIsNull(childOp)) {
382 count += 1;
383 childOp = mlirOperationGetNextInBlock(childOp);
384 }
385 return count;
386 }
387
dunderGetItem(intptr_t index)388 py::object dunderGetItem(intptr_t index) {
389 parentOperation->checkValid();
390 if (index < 0) {
391 throw SetPyError(PyExc_IndexError,
392 "attempt to access out of bounds operation");
393 }
394 MlirOperation childOp = mlirBlockGetFirstOperation(block);
395 while (!mlirOperationIsNull(childOp)) {
396 if (index == 0) {
397 return PyOperation::forOperation(parentOperation->getContext(), childOp)
398 ->createOpView();
399 }
400 childOp = mlirOperationGetNextInBlock(childOp);
401 index -= 1;
402 }
403 throw SetPyError(PyExc_IndexError,
404 "attempt to access out of bounds operation");
405 }
406
bind(py::module & m)407 static void bind(py::module &m) {
408 py::class_<PyOperationList>(m, "OperationList")
409 .def("__getitem__", &PyOperationList::dunderGetItem)
410 .def("__iter__", &PyOperationList::dunderIter)
411 .def("__len__", &PyOperationList::dunderLen);
412 }
413
414 private:
415 PyOperationRef parentOperation;
416 MlirBlock block;
417 };
418
419 } // namespace
420
421 //------------------------------------------------------------------------------
422 // PyMlirContext
423 //------------------------------------------------------------------------------
424
PyMlirContext(MlirContext context)425 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
426 py::gil_scoped_acquire acquire;
427 auto &liveContexts = getLiveContexts();
428 liveContexts[context.ptr] = this;
429 }
430
~PyMlirContext()431 PyMlirContext::~PyMlirContext() {
432 // Note that the only public way to construct an instance is via the
433 // forContext method, which always puts the associated handle into
434 // liveContexts.
435 py::gil_scoped_acquire acquire;
436 getLiveContexts().erase(context.ptr);
437 mlirContextDestroy(context);
438 }
439
getCapsule()440 py::object PyMlirContext::getCapsule() {
441 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
442 }
443
createFromCapsule(py::object capsule)444 py::object PyMlirContext::createFromCapsule(py::object capsule) {
445 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
446 if (mlirContextIsNull(rawContext))
447 throw py::error_already_set();
448 return forContext(rawContext).releaseObject();
449 }
450
createNewContextForInit()451 PyMlirContext *PyMlirContext::createNewContextForInit() {
452 MlirContext context = mlirContextCreate();
453 mlirRegisterAllDialects(context);
454 return new PyMlirContext(context);
455 }
456
/// Returns a reference to the Python wrapper for `context`, creating one if
/// no wrapper is currently registered in liveContexts.
///
/// NOTE(review): a newly created wrapper is cast with pybind11's default
/// policy for raw pointers (automatic_reference -> reference), so Python does
/// not take ownership of it and ~PyMlirContext (which calls
/// mlirContextDestroy) is not run when the Python object dies. Presumably
/// intentional because this path wraps a context it does not own; confirm.
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
  py::gil_scoped_acquire acquire;
  auto &liveContexts = getLiveContexts();
  auto it = liveContexts.find(context.ptr);
  if (it == liveContexts.end()) {
    // Create.
    PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
    py::object pyRef = py::cast(unownedContextWrapper);
    assert(pyRef && "cast to py::object failed");
    // The constructor already registered this pointer; this stores it again.
    liveContexts[context.ptr] = unownedContextWrapper;
    return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
  }
  // Use existing.
  py::object pyRef = py::cast(it->second);
  return PyMlirContextRef(it->second, std::move(pyRef));
}
473
getLiveContexts()474 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
475 static LiveContextMap liveContexts;
476 return liveContexts;
477 }
478
getLiveCount()479 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
480
getLiveOperationCount()481 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
482
getLiveModuleCount()483 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
484
contextEnter()485 pybind11::object PyMlirContext::contextEnter() {
486 return PyThreadContextEntry::pushContext(*this);
487 }
488
contextExit(pybind11::object excType,pybind11::object excVal,pybind11::object excTb)489 void PyMlirContext::contextExit(pybind11::object excType,
490 pybind11::object excVal,
491 pybind11::object excTb) {
492 PyThreadContextEntry::popContext(*this);
493 }
494
resolve()495 PyMlirContext &DefaultingPyMlirContext::resolve() {
496 PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
497 if (!context) {
498 throw SetPyError(
499 PyExc_RuntimeError,
500 "An MLIR function requires a Context but none was provided in the call "
501 "or from the surrounding environment. Either pass to the function with "
502 "a 'context=' argument or establish a default using 'with Context():'");
503 }
504 return *context;
505 }
506
507 //------------------------------------------------------------------------------
508 // PyThreadContextEntry management
509 //------------------------------------------------------------------------------
510
getStack()511 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
512 static thread_local std::vector<PyThreadContextEntry> stack;
513 return stack;
514 }
515
getTopOfStack()516 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
517 auto &stack = getStack();
518 if (stack.empty())
519 return nullptr;
520 return &stack.back();
521 }
522
push(FrameKind frameKind,py::object context,py::object insertionPoint,py::object location)523 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
524 py::object insertionPoint,
525 py::object location) {
526 auto &stack = getStack();
527 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
528 std::move(location));
529 // If the new stack has more than one entry and the context of the new top
530 // entry matches the previous, copy the insertionPoint and location from the
531 // previous entry if missing from the new top entry.
532 if (stack.size() > 1) {
533 auto &prev = *(stack.rbegin() + 1);
534 auto ¤t = stack.back();
535 if (current.context.is(prev.context)) {
536 // Default non-context objects from the previous entry.
537 if (!current.insertionPoint)
538 current.insertionPoint = prev.insertionPoint;
539 if (!current.location)
540 current.location = prev.location;
541 }
542 }
543 }
544
getContext()545 PyMlirContext *PyThreadContextEntry::getContext() {
546 if (!context)
547 return nullptr;
548 return py::cast<PyMlirContext *>(context);
549 }
550
getInsertionPoint()551 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
552 if (!insertionPoint)
553 return nullptr;
554 return py::cast<PyInsertionPoint *>(insertionPoint);
555 }
556
getLocation()557 PyLocation *PyThreadContextEntry::getLocation() {
558 if (!location)
559 return nullptr;
560 return py::cast<PyLocation *>(location);
561 }
562
getDefaultContext()563 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
564 auto *tos = getTopOfStack();
565 return tos ? tos->getContext() : nullptr;
566 }
567
getDefaultInsertionPoint()568 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
569 auto *tos = getTopOfStack();
570 return tos ? tos->getInsertionPoint() : nullptr;
571 }
572
getDefaultLocation()573 PyLocation *PyThreadContextEntry::getDefaultLocation() {
574 auto *tos = getTopOfStack();
575 return tos ? tos->getLocation() : nullptr;
576 }
577
pushContext(PyMlirContext & context)578 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
579 py::object contextObj = py::cast(context);
580 push(FrameKind::Context, /*context=*/contextObj,
581 /*insertionPoint=*/py::object(),
582 /*location=*/py::object());
583 return contextObj;
584 }
585
popContext(PyMlirContext & context)586 void PyThreadContextEntry::popContext(PyMlirContext &context) {
587 auto &stack = getStack();
588 if (stack.empty())
589 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
590 auto &tos = stack.back();
591 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
592 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
593 stack.pop_back();
594 }
595
596 py::object
pushInsertionPoint(PyInsertionPoint & insertionPoint)597 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
598 py::object contextObj =
599 insertionPoint.getBlock().getParentOperation()->getContext().getObject();
600 py::object insertionPointObj = py::cast(insertionPoint);
601 push(FrameKind::InsertionPoint,
602 /*context=*/contextObj,
603 /*insertionPoint=*/insertionPointObj,
604 /*location=*/py::object());
605 return insertionPointObj;
606 }
607
popInsertionPoint(PyInsertionPoint & insertionPoint)608 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
609 auto &stack = getStack();
610 if (stack.empty())
611 throw SetPyError(PyExc_RuntimeError,
612 "Unbalanced InsertionPoint enter/exit");
613 auto &tos = stack.back();
614 if (tos.frameKind != FrameKind::InsertionPoint &&
615 tos.getInsertionPoint() != &insertionPoint)
616 throw SetPyError(PyExc_RuntimeError,
617 "Unbalanced InsertionPoint enter/exit");
618 stack.pop_back();
619 }
620
pushLocation(PyLocation & location)621 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
622 py::object contextObj = location.getContext().getObject();
623 py::object locationObj = py::cast(location);
624 push(FrameKind::Location, /*context=*/contextObj,
625 /*insertionPoint=*/py::object(),
626 /*location=*/locationObj);
627 return locationObj;
628 }
629
popLocation(PyLocation & location)630 void PyThreadContextEntry::popLocation(PyLocation &location) {
631 auto &stack = getStack();
632 if (stack.empty())
633 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
634 auto &tos = stack.back();
635 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
636 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
637 stack.pop_back();
638 }
639
640 //------------------------------------------------------------------------------
641 // PyDialect, PyDialectDescriptor, PyDialects
642 //------------------------------------------------------------------------------
643
getDialectForKey(const std::string & key,bool attrError)644 MlirDialect PyDialects::getDialectForKey(const std::string &key,
645 bool attrError) {
646 // If the "std" dialect was asked for, substitute the empty namespace :(
647 static const std::string emptyKey;
648 const std::string *canonKey = key == "std" ? &emptyKey : &key;
649 MlirDialect dialect = mlirContextGetOrLoadDialect(
650 getContext()->get(), {canonKey->data(), canonKey->size()});
651 if (mlirDialectIsNull(dialect)) {
652 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
653 Twine("Dialect '") + key + "' not found");
654 }
655 return dialect;
656 }
657
658 //------------------------------------------------------------------------------
659 // PyLocation
660 //------------------------------------------------------------------------------
661
getCapsule()662 py::object PyLocation::getCapsule() {
663 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
664 }
665
createFromCapsule(py::object capsule)666 PyLocation PyLocation::createFromCapsule(py::object capsule) {
667 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
668 if (mlirLocationIsNull(rawLoc))
669 throw py::error_already_set();
670 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
671 rawLoc);
672 }
673
contextEnter()674 py::object PyLocation::contextEnter() {
675 return PyThreadContextEntry::pushLocation(*this);
676 }
677
contextExit(py::object excType,py::object excVal,py::object excTb)678 void PyLocation::contextExit(py::object excType, py::object excVal,
679 py::object excTb) {
680 PyThreadContextEntry::popLocation(*this);
681 }
682
resolve()683 PyLocation &DefaultingPyLocation::resolve() {
684 auto *location = PyThreadContextEntry::getDefaultLocation();
685 if (!location) {
686 throw SetPyError(
687 PyExc_RuntimeError,
688 "An MLIR function requires a Location but none was provided in the "
689 "call or from the surrounding environment. Either pass to the function "
690 "with a 'loc=' argument or establish a default using 'with loc:'");
691 }
692 return *location;
693 }
694
695 //------------------------------------------------------------------------------
696 // PyModule
697 //------------------------------------------------------------------------------
698
PyModule(PyMlirContextRef contextRef,MlirModule module)699 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
700 : BaseContextObject(std::move(contextRef)), module(module) {}
701
~PyModule()702 PyModule::~PyModule() {
703 py::gil_scoped_acquire acquire;
704 auto &liveModules = getContext()->liveModules;
705 assert(liveModules.count(module.ptr) == 1 &&
706 "destroying module not in live map");
707 liveModules.erase(module.ptr);
708 mlirModuleDestroy(module);
709 }
710
/// Returns the Python wrapper for `module`, creating one and registering it in
/// its context's liveModules map if none exists yet.
PyModuleRef PyModule::forModule(MlirModule module) {
  MlirContext context = mlirModuleGetContext(module);
  PyMlirContextRef contextRef = PyMlirContext::forContext(context);

  py::gil_scoped_acquire acquire;
  // Grab the map before contextRef is moved into the new wrapper below; the
  // wrapper then keeps the context (and so this map) alive.
  auto &liveModules = contextRef->liveModules;
  auto it = liveModules.find(module.ptr);
  if (it == liveModules.end()) {
    // Create.
    PyModule *unownedModule = new PyModule(std::move(contextRef), module);
    // The default policy for casting a raw pointer (automatic_reference)
    // would not transfer ownership; take_ownership makes Python responsible
    // for deleting the wrapper (running ~PyModule).
    py::object pyRef =
        py::cast(unownedModule, py::return_value_policy::take_ownership);
    unownedModule->handle = pyRef;
    liveModules[module.ptr] =
        std::make_pair(unownedModule->handle, unownedModule);
    return PyModuleRef(unownedModule, std::move(pyRef));
  }
  // Use existing: the map entry holds the Python handle, so borrow a new
  // reference to it.
  PyModule *existing = it->second.second;
  py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
  return PyModuleRef(existing, std::move(pyRef));
}
736
createFromCapsule(py::object capsule)737 py::object PyModule::createFromCapsule(py::object capsule) {
738 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
739 if (mlirModuleIsNull(rawModule))
740 throw py::error_already_set();
741 return forModule(rawModule).releaseObject();
742 }
743
getCapsule()744 py::object PyModule::getCapsule() {
745 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
746 }
747
748 //------------------------------------------------------------------------------
749 // PyOperation
750 //------------------------------------------------------------------------------
751
PyOperation(PyMlirContextRef contextRef,MlirOperation operation)752 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
753 : BaseContextObject(std::move(contextRef)), operation(operation) {}
754
~PyOperation()755 PyOperation::~PyOperation() {
756 // If the operation has already been invalidated there is nothing to do.
757 if (!valid)
758 return;
759 auto &liveOperations = getContext()->liveOperations;
760 assert(liveOperations.count(operation.ptr) == 1 &&
761 "destroying operation not in live map");
762 liveOperations.erase(operation.ptr);
763 if (!isAttached()) {
764 mlirOperationDestroy(operation);
765 }
766 }
767
createInstance(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)768 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
769 MlirOperation operation,
770 py::object parentKeepAlive) {
771 auto &liveOperations = contextRef->liveOperations;
772 // Create.
773 PyOperation *unownedOperation =
774 new PyOperation(std::move(contextRef), operation);
775 // Note that the default return value policy on cast is automatic_reference,
776 // which does not take ownership (delete will not be called).
777 // Just be explicit.
778 py::object pyRef =
779 py::cast(unownedOperation, py::return_value_policy::take_ownership);
780 unownedOperation->handle = pyRef;
781 if (parentKeepAlive) {
782 unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
783 }
784 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
785 return PyOperationRef(unownedOperation, std::move(pyRef));
786 }
787
forOperation(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)788 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
789 MlirOperation operation,
790 py::object parentKeepAlive) {
791 auto &liveOperations = contextRef->liveOperations;
792 auto it = liveOperations.find(operation.ptr);
793 if (it == liveOperations.end()) {
794 // Create.
795 return createInstance(std::move(contextRef), operation,
796 std::move(parentKeepAlive));
797 }
798 // Use existing.
799 PyOperation *existing = it->second.second;
800 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
801 return PyOperationRef(existing, std::move(pyRef));
802 }
803
createDetached(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)804 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
805 MlirOperation operation,
806 py::object parentKeepAlive) {
807 auto &liveOperations = contextRef->liveOperations;
808 assert(liveOperations.count(operation.ptr) == 0 &&
809 "cannot create detached operation that already exists");
810 (void)liveOperations;
811
812 PyOperationRef created = createInstance(std::move(contextRef), operation,
813 std::move(parentKeepAlive));
814 created->attached = false;
815 return created;
816 }
817
checkValid() const818 void PyOperation::checkValid() const {
819 if (!valid) {
820 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
821 }
822 }
823
print(py::object fileObject,bool binary,llvm::Optional<int64_t> largeElementsLimit,bool enableDebugInfo,bool prettyDebugInfo,bool printGenericOpForm,bool useLocalScope)824 void PyOperationBase::print(py::object fileObject, bool binary,
825 llvm::Optional<int64_t> largeElementsLimit,
826 bool enableDebugInfo, bool prettyDebugInfo,
827 bool printGenericOpForm, bool useLocalScope) {
828 PyOperation &operation = getOperation();
829 operation.checkValid();
830 if (fileObject.is_none())
831 fileObject = py::module::import("sys").attr("stdout");
832
833 if (!printGenericOpForm && !mlirOperationVerify(operation)) {
834 fileObject.attr("write")("// Verification failed, printing generic form\n");
835 printGenericOpForm = true;
836 }
837
838 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
839 if (largeElementsLimit)
840 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
841 if (enableDebugInfo)
842 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
843 if (printGenericOpForm)
844 mlirOpPrintingFlagsPrintGenericOpForm(flags);
845
846 PyFileAccumulator accum(fileObject, binary);
847 py::gil_scoped_release();
848 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
849 accum.getUserData());
850 mlirOpPrintingFlagsDestroy(flags);
851 }
852
getAsm(bool binary,llvm::Optional<int64_t> largeElementsLimit,bool enableDebugInfo,bool prettyDebugInfo,bool printGenericOpForm,bool useLocalScope)853 py::object PyOperationBase::getAsm(bool binary,
854 llvm::Optional<int64_t> largeElementsLimit,
855 bool enableDebugInfo, bool prettyDebugInfo,
856 bool printGenericOpForm,
857 bool useLocalScope) {
858 py::object fileObject;
859 if (binary) {
860 fileObject = py::module::import("io").attr("BytesIO")();
861 } else {
862 fileObject = py::module::import("io").attr("StringIO")();
863 }
864 print(fileObject, /*binary=*/binary,
865 /*largeElementsLimit=*/largeElementsLimit,
866 /*enableDebugInfo=*/enableDebugInfo,
867 /*prettyDebugInfo=*/prettyDebugInfo,
868 /*printGenericOpForm=*/printGenericOpForm,
869 /*useLocalScope=*/useLocalScope);
870
871 return fileObject.attr("getvalue")();
872 }
873
getParentOperation()874 PyOperationRef PyOperation::getParentOperation() {
875 checkValid();
876 if (!isAttached())
877 throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
878 MlirOperation operation = mlirOperationGetParentOperation(get());
879 if (mlirOperationIsNull(operation))
880 throw SetPyError(PyExc_ValueError, "Operation has no parent.");
881 return PyOperation::forOperation(getContext(), operation);
882 }
883
getBlock()884 PyBlock PyOperation::getBlock() {
885 checkValid();
886 PyOperationRef parentOperation = getParentOperation();
887 MlirBlock block = mlirOperationGetBlock(get());
888 assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
889 return PyBlock{std::move(parentOperation), block};
890 }
891
getCapsule()892 py::object PyOperation::getCapsule() {
893 checkValid();
894 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
895 }
896
createFromCapsule(py::object capsule)897 py::object PyOperation::createFromCapsule(py::object capsule) {
898 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
899 if (mlirOperationIsNull(rawOperation))
900 throw py::error_already_set();
901 MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
902 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
903 .releaseObject();
904 }
905
create(std::string name,llvm::Optional<std::vector<PyType * >> results,llvm::Optional<std::vector<PyValue * >> operands,llvm::Optional<py::dict> attributes,llvm::Optional<std::vector<PyBlock * >> successors,int regions,DefaultingPyLocation location,py::object maybeIp)906 py::object PyOperation::create(
907 std::string name, llvm::Optional<std::vector<PyType *>> results,
908 llvm::Optional<std::vector<PyValue *>> operands,
909 llvm::Optional<py::dict> attributes,
910 llvm::Optional<std::vector<PyBlock *>> successors, int regions,
911 DefaultingPyLocation location, py::object maybeIp) {
912 llvm::SmallVector<MlirValue, 4> mlirOperands;
913 llvm::SmallVector<MlirType, 4> mlirResults;
914 llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
915 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
916
917 // General parameter validation.
918 if (regions < 0)
919 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
920
921 // Unpack/validate operands.
922 if (operands) {
923 mlirOperands.reserve(operands->size());
924 for (PyValue *operand : *operands) {
925 if (!operand)
926 throw SetPyError(PyExc_ValueError, "operand value cannot be None");
927 mlirOperands.push_back(operand->get());
928 }
929 }
930
931 // Unpack/validate results.
932 if (results) {
933 mlirResults.reserve(results->size());
934 for (PyType *result : *results) {
935 // TODO: Verify result type originate from the same context.
936 if (!result)
937 throw SetPyError(PyExc_ValueError, "result type cannot be None");
938 mlirResults.push_back(*result);
939 }
940 }
941 // Unpack/validate attributes.
942 if (attributes) {
943 mlirAttributes.reserve(attributes->size());
944 for (auto &it : *attributes) {
945 std::string key;
946 try {
947 key = it.first.cast<std::string>();
948 } catch (py::cast_error &err) {
949 std::string msg = "Invalid attribute key (not a string) when "
950 "attempting to create the operation \"" +
951 name + "\" (" + err.what() + ")";
952 throw py::cast_error(msg);
953 }
954 try {
955 auto &attribute = it.second.cast<PyAttribute &>();
956 // TODO: Verify attribute originates from the same context.
957 mlirAttributes.emplace_back(std::move(key), attribute);
958 } catch (py::reference_cast_error &) {
959 // This exception seems thrown when the value is "None".
960 std::string msg =
961 "Found an invalid (`None`?) attribute value for the key \"" + key +
962 "\" when attempting to create the operation \"" + name + "\"";
963 throw py::cast_error(msg);
964 } catch (py::cast_error &err) {
965 std::string msg = "Invalid attribute value for the key \"" + key +
966 "\" when attempting to create the operation \"" +
967 name + "\" (" + err.what() + ")";
968 throw py::cast_error(msg);
969 }
970 }
971 }
972 // Unpack/validate successors.
973 if (successors) {
974 llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
975 mlirSuccessors.reserve(successors->size());
976 for (auto *successor : *successors) {
977 // TODO: Verify successor originate from the same context.
978 if (!successor)
979 throw SetPyError(PyExc_ValueError, "successor block cannot be None");
980 mlirSuccessors.push_back(successor->get());
981 }
982 }
983
984 // Apply unpacked/validated to the operation state. Beyond this
985 // point, exceptions cannot be thrown or else the state will leak.
986 MlirOperationState state =
987 mlirOperationStateGet(toMlirStringRef(name), location);
988 if (!mlirOperands.empty())
989 mlirOperationStateAddOperands(&state, mlirOperands.size(),
990 mlirOperands.data());
991 if (!mlirResults.empty())
992 mlirOperationStateAddResults(&state, mlirResults.size(),
993 mlirResults.data());
994 if (!mlirAttributes.empty()) {
995 // Note that the attribute names directly reference bytes in
996 // mlirAttributes, so that vector must not be changed from here
997 // on.
998 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
999 mlirNamedAttributes.reserve(mlirAttributes.size());
1000 for (auto &it : mlirAttributes)
1001 mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1002 mlirIdentifierGet(mlirAttributeGetContext(it.second),
1003 toMlirStringRef(it.first)),
1004 it.second));
1005 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1006 mlirNamedAttributes.data());
1007 }
1008 if (!mlirSuccessors.empty())
1009 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1010 mlirSuccessors.data());
1011 if (regions) {
1012 llvm::SmallVector<MlirRegion, 4> mlirRegions;
1013 mlirRegions.resize(regions);
1014 for (int i = 0; i < regions; ++i)
1015 mlirRegions[i] = mlirRegionCreate();
1016 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1017 mlirRegions.data());
1018 }
1019
1020 // Construct the operation.
1021 MlirOperation operation = mlirOperationCreate(&state);
1022 PyOperationRef created =
1023 PyOperation::createDetached(location->getContext(), operation);
1024
1025 // InsertPoint active?
1026 if (!maybeIp.is(py::cast(false))) {
1027 PyInsertionPoint *ip;
1028 if (maybeIp.is_none()) {
1029 ip = PyThreadContextEntry::getDefaultInsertionPoint();
1030 } else {
1031 ip = py::cast<PyInsertionPoint *>(maybeIp);
1032 }
1033 if (ip)
1034 ip->insert(*created.get());
1035 }
1036
1037 return created->createOpView();
1038 }
1039
createOpView()1040 py::object PyOperation::createOpView() {
1041 checkValid();
1042 MlirIdentifier ident = mlirOperationGetName(get());
1043 MlirStringRef identStr = mlirIdentifierStr(ident);
1044 auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1045 StringRef(identStr.data, identStr.length));
1046 if (opViewClass)
1047 return (*opViewClass)(getRef().getObject());
1048 return py::cast(PyOpView(getRef().getObject()));
1049 }
1050
erase()1051 void PyOperation::erase() {
1052 checkValid();
1053 // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1054 // Python reference to a child operation is live. All children should also
1055 // have their `valid` bit set to false.
1056 auto &liveOperations = getContext()->liveOperations;
1057 if (liveOperations.count(operation.ptr))
1058 liveOperations.erase(operation.ptr);
1059 mlirOperationDestroy(operation);
1060 valid = false;
1061 }
1062
1063 //------------------------------------------------------------------------------
1064 // PyOpView
1065 //------------------------------------------------------------------------------
1066
1067 py::object
buildGeneric(py::object cls,py::list resultTypeList,py::list operandList,llvm::Optional<py::dict> attributes,llvm::Optional<std::vector<PyBlock * >> successors,llvm::Optional<int> regions,DefaultingPyLocation location,py::object maybeIp)1068 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1069 py::list operandList,
1070 llvm::Optional<py::dict> attributes,
1071 llvm::Optional<std::vector<PyBlock *>> successors,
1072 llvm::Optional<int> regions,
1073 DefaultingPyLocation location, py::object maybeIp) {
1074 PyMlirContextRef context = location->getContext();
1075 // Class level operation construction metadata.
1076 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1077 // Operand and result segment specs are either none, which does no
1078 // variadic unpacking, or a list of ints with segment sizes, where each
1079 // element is either a positive number (typically 1 for a scalar) or -1 to
1080 // indicate that it is derived from the length of the same-indexed operand
1081 // or result (implying that it is a list at that position).
1082 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1083 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1084
1085 std::vector<uint32_t> operandSegmentLengths;
1086 std::vector<uint32_t> resultSegmentLengths;
1087
1088 // Validate/determine region count.
1089 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1090 int opMinRegionCount = std::get<0>(opRegionSpec);
1091 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1092 if (!regions) {
1093 regions = opMinRegionCount;
1094 }
1095 if (*regions < opMinRegionCount) {
1096 throw py::value_error(
1097 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1098 llvm::Twine(opMinRegionCount) +
1099 " regions but was built with regions=" + llvm::Twine(*regions))
1100 .str());
1101 }
1102 if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1103 throw py::value_error(
1104 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1105 llvm::Twine(opMinRegionCount) +
1106 " regions but was built with regions=" + llvm::Twine(*regions))
1107 .str());
1108 }
1109
1110 // Unpack results.
1111 std::vector<PyType *> resultTypes;
1112 resultTypes.reserve(resultTypeList.size());
1113 if (resultSegmentSpecObj.is_none()) {
1114 // Non-variadic result unpacking.
1115 for (auto it : llvm::enumerate(resultTypeList)) {
1116 try {
1117 resultTypes.push_back(py::cast<PyType *>(it.value()));
1118 if (!resultTypes.back())
1119 throw py::cast_error();
1120 } catch (py::cast_error &err) {
1121 throw py::value_error((llvm::Twine("Result ") +
1122 llvm::Twine(it.index()) + " of operation \"" +
1123 name + "\" must be a Type (" + err.what() + ")")
1124 .str());
1125 }
1126 }
1127 } else {
1128 // Sized result unpacking.
1129 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1130 if (resultSegmentSpec.size() != resultTypeList.size()) {
1131 throw py::value_error((llvm::Twine("Operation \"") + name +
1132 "\" requires " +
1133 llvm::Twine(resultSegmentSpec.size()) +
1134 "result segments but was provided " +
1135 llvm::Twine(resultTypeList.size()))
1136 .str());
1137 }
1138 resultSegmentLengths.reserve(resultTypeList.size());
1139 for (auto it :
1140 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1141 int segmentSpec = std::get<1>(it.value());
1142 if (segmentSpec == 1 || segmentSpec == 0) {
1143 // Unpack unary element.
1144 try {
1145 auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
1146 if (resultType) {
1147 resultTypes.push_back(resultType);
1148 resultSegmentLengths.push_back(1);
1149 } else if (segmentSpec == 0) {
1150 // Allowed to be optional.
1151 resultSegmentLengths.push_back(0);
1152 } else {
1153 throw py::cast_error("was None and result is not optional");
1154 }
1155 } catch (py::cast_error &err) {
1156 throw py::value_error((llvm::Twine("Result ") +
1157 llvm::Twine(it.index()) + " of operation \"" +
1158 name + "\" must be a Type (" + err.what() +
1159 ")")
1160 .str());
1161 }
1162 } else if (segmentSpec == -1) {
1163 // Unpack sequence by appending.
1164 try {
1165 if (std::get<0>(it.value()).is_none()) {
1166 // Treat it as an empty list.
1167 resultSegmentLengths.push_back(0);
1168 } else {
1169 // Unpack the list.
1170 auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1171 for (py::object segmentItem : segment) {
1172 resultTypes.push_back(py::cast<PyType *>(segmentItem));
1173 if (!resultTypes.back()) {
1174 throw py::cast_error("contained a None item");
1175 }
1176 }
1177 resultSegmentLengths.push_back(segment.size());
1178 }
1179 } catch (std::exception &err) {
1180 // NOTE: Sloppy to be using a catch-all here, but there are at least
1181 // three different unrelated exceptions that can be thrown in the
1182 // above "casts". Just keep the scope above small and catch them all.
1183 throw py::value_error((llvm::Twine("Result ") +
1184 llvm::Twine(it.index()) + " of operation \"" +
1185 name + "\" must be a Sequence of Types (" +
1186 err.what() + ")")
1187 .str());
1188 }
1189 } else {
1190 throw py::value_error("Unexpected segment spec");
1191 }
1192 }
1193 }
1194
1195 // Unpack operands.
1196 std::vector<PyValue *> operands;
1197 operands.reserve(operands.size());
1198 if (operandSegmentSpecObj.is_none()) {
1199 // Non-sized operand unpacking.
1200 for (auto it : llvm::enumerate(operandList)) {
1201 try {
1202 operands.push_back(py::cast<PyValue *>(it.value()));
1203 if (!operands.back())
1204 throw py::cast_error();
1205 } catch (py::cast_error &err) {
1206 throw py::value_error((llvm::Twine("Operand ") +
1207 llvm::Twine(it.index()) + " of operation \"" +
1208 name + "\" must be a Value (" + err.what() + ")")
1209 .str());
1210 }
1211 }
1212 } else {
1213 // Sized operand unpacking.
1214 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1215 if (operandSegmentSpec.size() != operandList.size()) {
1216 throw py::value_error((llvm::Twine("Operation \"") + name +
1217 "\" requires " +
1218 llvm::Twine(operandSegmentSpec.size()) +
1219 "operand segments but was provided " +
1220 llvm::Twine(operandList.size()))
1221 .str());
1222 }
1223 operandSegmentLengths.reserve(operandList.size());
1224 for (auto it :
1225 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1226 int segmentSpec = std::get<1>(it.value());
1227 if (segmentSpec == 1 || segmentSpec == 0) {
1228 // Unpack unary element.
1229 try {
1230 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1231 if (operandValue) {
1232 operands.push_back(operandValue);
1233 operandSegmentLengths.push_back(1);
1234 } else if (segmentSpec == 0) {
1235 // Allowed to be optional.
1236 operandSegmentLengths.push_back(0);
1237 } else {
1238 throw py::cast_error("was None and operand is not optional");
1239 }
1240 } catch (py::cast_error &err) {
1241 throw py::value_error((llvm::Twine("Operand ") +
1242 llvm::Twine(it.index()) + " of operation \"" +
1243 name + "\" must be a Value (" + err.what() +
1244 ")")
1245 .str());
1246 }
1247 } else if (segmentSpec == -1) {
1248 // Unpack sequence by appending.
1249 try {
1250 if (std::get<0>(it.value()).is_none()) {
1251 // Treat it as an empty list.
1252 operandSegmentLengths.push_back(0);
1253 } else {
1254 // Unpack the list.
1255 auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1256 for (py::object segmentItem : segment) {
1257 operands.push_back(py::cast<PyValue *>(segmentItem));
1258 if (!operands.back()) {
1259 throw py::cast_error("contained a None item");
1260 }
1261 }
1262 operandSegmentLengths.push_back(segment.size());
1263 }
1264 } catch (std::exception &err) {
1265 // NOTE: Sloppy to be using a catch-all here, but there are at least
1266 // three different unrelated exceptions that can be thrown in the
1267 // above "casts". Just keep the scope above small and catch them all.
1268 throw py::value_error((llvm::Twine("Operand ") +
1269 llvm::Twine(it.index()) + " of operation \"" +
1270 name + "\" must be a Sequence of Values (" +
1271 err.what() + ")")
1272 .str());
1273 }
1274 } else {
1275 throw py::value_error("Unexpected segment spec");
1276 }
1277 }
1278 }
1279
1280 // Merge operand/result segment lengths into attributes if needed.
1281 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1282 // Dup.
1283 if (attributes) {
1284 attributes = py::dict(*attributes);
1285 } else {
1286 attributes = py::dict();
1287 }
1288 if (attributes->contains("result_segment_sizes") ||
1289 attributes->contains("operand_segment_sizes")) {
1290 throw py::value_error("Manually setting a 'result_segment_sizes' or "
1291 "'operand_segment_sizes' attribute is unsupported. "
1292 "Use Operation.create for such low-level access.");
1293 }
1294
1295 // Add result_segment_sizes attribute.
1296 if (!resultSegmentLengths.empty()) {
1297 int64_t size = resultSegmentLengths.size();
1298 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1299 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1300 resultSegmentLengths.size(), resultSegmentLengths.data());
1301 (*attributes)["result_segment_sizes"] =
1302 PyAttribute(context, segmentLengthAttr);
1303 }
1304
1305 // Add operand_segment_sizes attribute.
1306 if (!operandSegmentLengths.empty()) {
1307 int64_t size = operandSegmentLengths.size();
1308 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1309 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1310 operandSegmentLengths.size(), operandSegmentLengths.data());
1311 (*attributes)["operand_segment_sizes"] =
1312 PyAttribute(context, segmentLengthAttr);
1313 }
1314 }
1315
1316 // Delegate to create.
1317 return PyOperation::create(std::move(name),
1318 /*results=*/std::move(resultTypes),
1319 /*operands=*/std::move(operands),
1320 /*attributes=*/std::move(attributes),
1321 /*successors=*/std::move(successors),
1322 /*regions=*/*regions, location, maybeIp);
1323 }
1324
PyOpView(py::object operationObject)1325 PyOpView::PyOpView(py::object operationObject)
1326 // Casting through the PyOperationBase base-class and then back to the
1327 // Operation lets us accept any PyOperationBase subclass.
1328 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1329 operationObject(operation.getRef().getObject()) {}
1330
createRawSubclass(py::object userClass)1331 py::object PyOpView::createRawSubclass(py::object userClass) {
1332 // This is... a little gross. The typical pattern is to have a pure python
1333 // class that extends OpView like:
1334 // class AddFOp(_cext.ir.OpView):
1335 // def __init__(self, loc, lhs, rhs):
1336 // operation = loc.context.create_operation(
1337 // "addf", lhs, rhs, results=[lhs.type])
1338 // super().__init__(operation)
1339 //
1340 // I.e. The goal of the user facing type is to provide a nice constructor
1341 // that has complete freedom for the op under construction. This is at odds
1342 // with our other desire to sometimes create this object by just passing an
1343 // operation (to initialize the base class). We could do *arg and **kwargs
1344 // munging to try to make it work, but instead, we synthesize a new class
1345 // on the fly which extends this user class (AddFOp in this example) and
1346 // *give it* the base class's __init__ method, thus bypassing the
1347 // intermediate subclass's __init__ method entirely. While slightly,
1348 // underhanded, this is safe/legal because the type hierarchy has not changed
1349 // (we just added a new leaf) and we aren't mucking around with __new__.
1350 // Typically, this new class will be stored on the original as "_Raw" and will
1351 // be used for casts and other things that need a variant of the class that
1352 // is initialized purely from an operation.
1353 py::object parentMetaclass =
1354 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1355 py::dict attributes;
1356 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1357 // now.
1358 // auto opViewType = py::type::of<PyOpView>();
1359 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1360 attributes["__init__"] = opViewType.attr("__init__");
1361 py::str origName = userClass.attr("__name__");
1362 py::str newName = py::str("_") + origName;
1363 return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1364 }
1365
1366 //------------------------------------------------------------------------------
1367 // PyInsertionPoint.
1368 //------------------------------------------------------------------------------
1369
PyInsertionPoint(PyBlock & block)1370 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1371
PyInsertionPoint(PyOperationBase & beforeOperationBase)1372 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
1373 : refOperation(beforeOperationBase.getOperation().getRef()),
1374 block((*refOperation)->getBlock()) {}
1375
insert(PyOperationBase & operationBase)1376 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1377 PyOperation &operation = operationBase.getOperation();
1378 if (operation.isAttached())
1379 throw SetPyError(PyExc_ValueError,
1380 "Attempt to insert operation that is already attached");
1381 block.getParentOperation()->checkValid();
1382 MlirOperation beforeOp = {nullptr};
1383 if (refOperation) {
1384 // Insert before operation.
1385 (*refOperation)->checkValid();
1386 beforeOp = (*refOperation)->get();
1387 } else {
1388 // Insert at end (before null) is only valid if the block does not
1389 // already end in a known terminator (violating this will cause assertion
1390 // failures later).
1391 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1392 throw py::index_error("Cannot insert operation at the end of a block "
1393 "that already has a terminator. Did you mean to "
1394 "use 'InsertionPoint.at_block_terminator(block)' "
1395 "versus 'InsertionPoint(block)'?");
1396 }
1397 }
1398 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1399 operation.setAttached();
1400 }
1401
atBlockBegin(PyBlock & block)1402 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1403 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1404 if (mlirOperationIsNull(firstOp)) {
1405 // Just insert at end.
1406 return PyInsertionPoint(block);
1407 }
1408
1409 // Insert before first op.
1410 PyOperationRef firstOpRef = PyOperation::forOperation(
1411 block.getParentOperation()->getContext(), firstOp);
1412 return PyInsertionPoint{block, std::move(firstOpRef)};
1413 }
1414
atBlockTerminator(PyBlock & block)1415 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1416 MlirOperation terminator = mlirBlockGetTerminator(block.get());
1417 if (mlirOperationIsNull(terminator))
1418 throw SetPyError(PyExc_ValueError, "Block has no terminator");
1419 PyOperationRef terminatorOpRef = PyOperation::forOperation(
1420 block.getParentOperation()->getContext(), terminator);
1421 return PyInsertionPoint{block, std::move(terminatorOpRef)};
1422 }
1423
contextEnter()1424 py::object PyInsertionPoint::contextEnter() {
1425 return PyThreadContextEntry::pushInsertionPoint(*this);
1426 }
1427
contextExit(pybind11::object excType,pybind11::object excVal,pybind11::object excTb)1428 void PyInsertionPoint::contextExit(pybind11::object excType,
1429 pybind11::object excVal,
1430 pybind11::object excTb) {
1431 PyThreadContextEntry::popInsertionPoint(*this);
1432 }
1433
1434 //------------------------------------------------------------------------------
1435 // PyAttribute.
1436 //------------------------------------------------------------------------------
1437
operator ==(const PyAttribute & other)1438 bool PyAttribute::operator==(const PyAttribute &other) {
1439 return mlirAttributeEqual(attr, other.attr);
1440 }
1441
getCapsule()1442 py::object PyAttribute::getCapsule() {
1443 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1444 }
1445
createFromCapsule(py::object capsule)1446 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1447 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1448 if (mlirAttributeIsNull(rawAttr))
1449 throw py::error_already_set();
1450 return PyAttribute(
1451 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1452 }
1453
1454 //------------------------------------------------------------------------------
1455 // PyNamedAttribute.
1456 //------------------------------------------------------------------------------
1457
PyNamedAttribute(MlirAttribute attr,std::string ownedName)1458 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1459 : ownedName(new std::string(std::move(ownedName))) {
1460 namedAttr = mlirNamedAttributeGet(
1461 mlirIdentifierGet(mlirAttributeGetContext(attr),
1462 toMlirStringRef(*this->ownedName)),
1463 attr);
1464 }
1465
1466 //------------------------------------------------------------------------------
1467 // PyType.
1468 //------------------------------------------------------------------------------
1469
operator ==(const PyType & other)1470 bool PyType::operator==(const PyType &other) {
1471 return mlirTypeEqual(type, other.type);
1472 }
1473
getCapsule()1474 py::object PyType::getCapsule() {
1475 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1476 }
1477
createFromCapsule(py::object capsule)1478 PyType PyType::createFromCapsule(py::object capsule) {
1479 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1480 if (mlirTypeIsNull(rawType))
1481 throw py::error_already_set();
1482 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1483 rawType);
1484 }
1485
1486 //------------------------------------------------------------------------------
1487 // PyValue and subclases.
1488 //------------------------------------------------------------------------------
1489
getCapsule()1490 pybind11::object PyValue::getCapsule() {
1491 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1492 }
1493
createFromCapsule(pybind11::object capsule)1494 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1495 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1496 if (mlirValueIsNull(value))
1497 throw py::error_already_set();
1498 MlirOperation owner;
1499 if (mlirValueIsAOpResult(value))
1500 owner = mlirOpResultGetOwner(value);
1501 if (mlirValueIsABlockArgument(value))
1502 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1503 if (mlirOperationIsNull(owner))
1504 throw py::error_already_set();
1505 MlirContext ctx = mlirOperationGetContext(owner);
1506 PyOperationRef ownerRef =
1507 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1508 return PyValue(ownerRef, value);
1509 }
1510
1511 namespace {
1512 /// CRTP base class for Python MLIR values that subclass Value and should be
1513 /// castable from it. The value hierarchy is one level deep and is not supposed
1514 /// to accommodate other levels unless core MLIR changes.
1515 template <typename DerivedTy>
1516 class PyConcreteValue : public PyValue {
1517 public:
1518 // Derived classes must define statics for:
1519 // IsAFunctionTy isaFunction
1520 // const char *pyClassName
1521 // and redefine bindDerived.
1522 using ClassTy = py::class_<DerivedTy, PyValue>;
1523 using IsAFunctionTy = bool (*)(MlirValue);
1524
1525 PyConcreteValue() = default;
PyConcreteValue(PyOperationRef operationRef,MlirValue value)1526 PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1527 : PyValue(operationRef, value) {}
PyConcreteValue(PyValue & orig)1528 PyConcreteValue(PyValue &orig)
1529 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1530
1531 /// Attempts to cast the original value to the derived type and throws on
1532 /// type mismatches.
castFrom(PyValue & orig)1533 static MlirValue castFrom(PyValue &orig) {
1534 if (!DerivedTy::isaFunction(orig.get())) {
1535 auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1536 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1537 DerivedTy::pyClassName +
1538 " (from " + origRepr + ")");
1539 }
1540 return orig.get();
1541 }
1542
1543 /// Binds the Python module objects to functions of this class.
bind(py::module & m)1544 static void bind(py::module &m) {
1545 auto cls = ClassTy(m, DerivedTy::pyClassName);
1546 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1547 DerivedTy::bindDerived(cls);
1548 }
1549
1550 /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)1551 static void bindDerived(ClassTy &m) {}
1552 };
1553
1554 /// Python wrapper for MlirBlockArgument.
1555 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1556 public:
1557 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1558 static constexpr const char *pyClassName = "BlockArgument";
1559 using PyConcreteValue::PyConcreteValue;
1560
bindDerived(ClassTy & c)1561 static void bindDerived(ClassTy &c) {
1562 c.def_property_readonly("owner", [](PyBlockArgument &self) {
1563 return PyBlock(self.getParentOperation(),
1564 mlirBlockArgumentGetOwner(self.get()));
1565 });
1566 c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1567 return mlirBlockArgumentGetArgNumber(self.get());
1568 });
1569 c.def("set_type", [](PyBlockArgument &self, PyType type) {
1570 return mlirBlockArgumentSetType(self.get(), type);
1571 });
1572 }
1573 };
1574
1575 /// Python wrapper for MlirOpResult.
1576 class PyOpResult : public PyConcreteValue<PyOpResult> {
1577 public:
1578 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1579 static constexpr const char *pyClassName = "OpResult";
1580 using PyConcreteValue::PyConcreteValue;
1581
bindDerived(ClassTy & c)1582 static void bindDerived(ClassTy &c) {
1583 c.def_property_readonly("owner", [](PyOpResult &self) {
1584 assert(
1585 mlirOperationEqual(self.getParentOperation()->get(),
1586 mlirOpResultGetOwner(self.get())) &&
1587 "expected the owner of the value in Python to match that in the IR");
1588 return self.getParentOperation().getObject();
1589 });
1590 c.def_property_readonly("result_number", [](PyOpResult &self) {
1591 return mlirOpResultGetResultNumber(self.get());
1592 });
1593 }
1594 };
1595
1596 /// A list of block arguments. Internally, these are stored as consecutive
1597 /// elements, random access is cheap. The argument list is associated with the
1598 /// operation that contains the block (detached blocks are not allowed in
1599 /// Python bindings) and extends its lifetime.
1600 class PyBlockArgumentList {
1601 public:
PyBlockArgumentList(PyOperationRef operation,MlirBlock block)1602 PyBlockArgumentList(PyOperationRef operation, MlirBlock block)
1603 : operation(std::move(operation)), block(block) {}
1604
1605 /// Returns the length of the block argument list.
dunderLen()1606 intptr_t dunderLen() {
1607 operation->checkValid();
1608 return mlirBlockGetNumArguments(block);
1609 }
1610
1611 /// Returns `index`-th element of the block argument list.
dunderGetItem(intptr_t index)1612 PyBlockArgument dunderGetItem(intptr_t index) {
1613 if (index < 0 || index >= dunderLen()) {
1614 throw SetPyError(PyExc_IndexError,
1615 "attempt to access out of bounds region");
1616 }
1617 PyValue value(operation, mlirBlockGetArgument(block, index));
1618 return PyBlockArgument(value);
1619 }
1620
1621 /// Defines a Python class in the bindings.
bind(py::module & m)1622 static void bind(py::module &m) {
1623 py::class_<PyBlockArgumentList>(m, "BlockArgumentList")
1624 .def("__len__", &PyBlockArgumentList::dunderLen)
1625 .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
1626 }
1627
1628 private:
1629 PyOperationRef operation;
1630 MlirBlock block;
1631 };
1632
1633 /// A list of operation operands. Internally, these are stored as consecutive
1634 /// elements, random access is cheap. The result list is associated with the
1635 /// operation whose results these are, and extends the lifetime of this
1636 /// operation.
1637 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1638 public:
1639 static constexpr const char *pyClassName = "OpOperandList";
1640
PyOpOperandList(PyOperationRef operation,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)1641 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1642 intptr_t length = -1, intptr_t step = 1)
1643 : Sliceable(startIndex,
1644 length == -1 ? mlirOperationGetNumOperands(operation->get())
1645 : length,
1646 step),
1647 operation(operation) {}
1648
getNumElements()1649 intptr_t getNumElements() {
1650 operation->checkValid();
1651 return mlirOperationGetNumOperands(operation->get());
1652 }
1653
getElement(intptr_t pos)1654 PyValue getElement(intptr_t pos) {
1655 MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
1656 MlirOperation owner;
1657 if (mlirValueIsAOpResult(operand))
1658 owner = mlirOpResultGetOwner(operand);
1659 else if (mlirValueIsABlockArgument(operand))
1660 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
1661 else
1662 assert(false && "Value must be an block arg or op result.");
1663 PyOperationRef pyOwner =
1664 PyOperation::forOperation(operation->getContext(), owner);
1665 return PyValue(pyOwner, operand);
1666 }
1667
slice(intptr_t startIndex,intptr_t length,intptr_t step)1668 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1669 return PyOpOperandList(operation, startIndex, length, step);
1670 }
1671
dunderSetItem(intptr_t index,PyValue value)1672 void dunderSetItem(intptr_t index, PyValue value) {
1673 index = wrapIndex(index);
1674 mlirOperationSetOperand(operation->get(), index, value.get());
1675 }
1676
bindDerived(ClassTy & c)1677 static void bindDerived(ClassTy &c) {
1678 c.def("__setitem__", &PyOpOperandList::dunderSetItem);
1679 }
1680
1681 private:
1682 PyOperationRef operation;
1683 };
1684
1685 /// A list of operation results. Internally, these are stored as consecutive
1686 /// elements, random access is cheap. The result list is associated with the
1687 /// operation whose results these are, and extends the lifetime of this
1688 /// operation.
1689 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1690 public:
1691 static constexpr const char *pyClassName = "OpResultList";
1692
PyOpResultList(PyOperationRef operation,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)1693 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1694 intptr_t length = -1, intptr_t step = 1)
1695 : Sliceable(startIndex,
1696 length == -1 ? mlirOperationGetNumResults(operation->get())
1697 : length,
1698 step),
1699 operation(operation) {}
1700
getNumElements()1701 intptr_t getNumElements() {
1702 operation->checkValid();
1703 return mlirOperationGetNumResults(operation->get());
1704 }
1705
getElement(intptr_t index)1706 PyOpResult getElement(intptr_t index) {
1707 PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1708 return PyOpResult(value);
1709 }
1710
slice(intptr_t startIndex,intptr_t length,intptr_t step)1711 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1712 return PyOpResultList(operation, startIndex, length, step);
1713 }
1714
1715 private:
1716 PyOperationRef operation;
1717 };
1718
1719 /// A list of operation attributes. Can be indexed by name, producing
1720 /// attributes, or by index, producing named attributes.
1721 class PyOpAttributeMap {
1722 public:
PyOpAttributeMap(PyOperationRef operation)1723 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1724
dunderGetItemNamed(const std::string & name)1725 PyAttribute dunderGetItemNamed(const std::string &name) {
1726 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1727 toMlirStringRef(name));
1728 if (mlirAttributeIsNull(attr)) {
1729 throw SetPyError(PyExc_KeyError,
1730 "attempt to access a non-existent attribute");
1731 }
1732 return PyAttribute(operation->getContext(), attr);
1733 }
1734
dunderGetItemIndexed(intptr_t index)1735 PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1736 if (index < 0 || index >= dunderLen()) {
1737 throw SetPyError(PyExc_IndexError,
1738 "attempt to access out of bounds attribute");
1739 }
1740 MlirNamedAttribute namedAttr =
1741 mlirOperationGetAttribute(operation->get(), index);
1742 return PyNamedAttribute(
1743 namedAttr.attribute,
1744 std::string(mlirIdentifierStr(namedAttr.name).data));
1745 }
1746
dunderSetItem(const std::string & name,PyAttribute attr)1747 void dunderSetItem(const std::string &name, PyAttribute attr) {
1748 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1749 attr);
1750 }
1751
dunderDelItem(const std::string & name)1752 void dunderDelItem(const std::string &name) {
1753 int removed = mlirOperationRemoveAttributeByName(operation->get(),
1754 toMlirStringRef(name));
1755 if (!removed)
1756 throw SetPyError(PyExc_KeyError,
1757 "attempt to delete a non-existent attribute");
1758 }
1759
dunderLen()1760 intptr_t dunderLen() {
1761 return mlirOperationGetNumAttributes(operation->get());
1762 }
1763
dunderContains(const std::string & name)1764 bool dunderContains(const std::string &name) {
1765 return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1766 operation->get(), toMlirStringRef(name)));
1767 }
1768
bind(py::module & m)1769 static void bind(py::module &m) {
1770 py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
1771 .def("__contains__", &PyOpAttributeMap::dunderContains)
1772 .def("__len__", &PyOpAttributeMap::dunderLen)
1773 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1774 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1775 .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1776 .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1777 }
1778
1779 private:
1780 PyOperationRef operation;
1781 };
1782
1783 } // end namespace
1784
1785 //------------------------------------------------------------------------------
1786 // Populates the core exports of the 'ir' submodule.
1787 //------------------------------------------------------------------------------
1788
populateIRCore(py::module & m)1789 void mlir::python::populateIRCore(py::module &m) {
1790 //----------------------------------------------------------------------------
1791 // Mapping of MlirContext.
1792 //----------------------------------------------------------------------------
1793 py::class_<PyMlirContext>(m, "Context")
1794 .def(py::init<>(&PyMlirContext::createNewContextForInit))
1795 .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1796 .def("_get_context_again",
1797 [](PyMlirContext &self) {
1798 PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1799 return ref.releaseObject();
1800 })
1801 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1802 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1803 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1804 &PyMlirContext::getCapsule)
1805 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1806 .def("__enter__", &PyMlirContext::contextEnter)
1807 .def("__exit__", &PyMlirContext::contextExit)
1808 .def_property_readonly_static(
1809 "current",
1810 [](py::object & /*class*/) {
1811 auto *context = PyThreadContextEntry::getDefaultContext();
1812 if (!context)
1813 throw SetPyError(PyExc_ValueError, "No current Context");
1814 return context;
1815 },
1816 "Gets the Context bound to the current thread or raises ValueError")
1817 .def_property_readonly(
1818 "dialects",
1819 [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1820 "Gets a container for accessing dialects by name")
1821 .def_property_readonly(
1822 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1823 "Alias for 'dialect'")
1824 .def(
1825 "get_dialect_descriptor",
1826 [=](PyMlirContext &self, std::string &name) {
1827 MlirDialect dialect = mlirContextGetOrLoadDialect(
1828 self.get(), {name.data(), name.size()});
1829 if (mlirDialectIsNull(dialect)) {
1830 throw SetPyError(PyExc_ValueError,
1831 Twine("Dialect '") + name + "' not found");
1832 }
1833 return PyDialectDescriptor(self.getRef(), dialect);
1834 },
1835 "Gets or loads a dialect by name, returning its descriptor object")
1836 .def_property(
1837 "allow_unregistered_dialects",
1838 [](PyMlirContext &self) -> bool {
1839 return mlirContextGetAllowUnregisteredDialects(self.get());
1840 },
1841 [](PyMlirContext &self, bool value) {
1842 mlirContextSetAllowUnregisteredDialects(self.get(), value);
1843 })
1844 .def("enable_multithreading",
1845 [](PyMlirContext &self, bool enable) {
1846 mlirContextEnableMultithreading(self.get(), enable);
1847 })
1848 .def("is_registered_operation",
1849 [](PyMlirContext &self, std::string &name) {
1850 return mlirContextIsRegisteredOperation(
1851 self.get(), MlirStringRef{name.data(), name.size()});
1852 });
1853
1854 //----------------------------------------------------------------------------
1855 // Mapping of PyDialectDescriptor
1856 //----------------------------------------------------------------------------
1857 py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
1858 .def_property_readonly("namespace",
1859 [](PyDialectDescriptor &self) {
1860 MlirStringRef ns =
1861 mlirDialectGetNamespace(self.get());
1862 return py::str(ns.data, ns.length);
1863 })
1864 .def("__repr__", [](PyDialectDescriptor &self) {
1865 MlirStringRef ns = mlirDialectGetNamespace(self.get());
1866 std::string repr("<DialectDescriptor ");
1867 repr.append(ns.data, ns.length);
1868 repr.append(">");
1869 return repr;
1870 });
1871
1872 //----------------------------------------------------------------------------
1873 // Mapping of PyDialects
1874 //----------------------------------------------------------------------------
1875 py::class_<PyDialects>(m, "Dialects")
1876 .def("__getitem__",
1877 [=](PyDialects &self, std::string keyName) {
1878 MlirDialect dialect =
1879 self.getDialectForKey(keyName, /*attrError=*/false);
1880 py::object descriptor =
1881 py::cast(PyDialectDescriptor{self.getContext(), dialect});
1882 return createCustomDialectWrapper(keyName, std::move(descriptor));
1883 })
1884 .def("__getattr__", [=](PyDialects &self, std::string attrName) {
1885 MlirDialect dialect =
1886 self.getDialectForKey(attrName, /*attrError=*/true);
1887 py::object descriptor =
1888 py::cast(PyDialectDescriptor{self.getContext(), dialect});
1889 return createCustomDialectWrapper(attrName, std::move(descriptor));
1890 });
1891
1892 //----------------------------------------------------------------------------
1893 // Mapping of PyDialect
1894 //----------------------------------------------------------------------------
1895 py::class_<PyDialect>(m, "Dialect")
1896 .def(py::init<py::object>(), "descriptor")
1897 .def_property_readonly(
1898 "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
1899 .def("__repr__", [](py::object self) {
1900 auto clazz = self.attr("__class__");
1901 return py::str("<Dialect ") +
1902 self.attr("descriptor").attr("namespace") + py::str(" (class ") +
1903 clazz.attr("__module__") + py::str(".") +
1904 clazz.attr("__name__") + py::str(")>");
1905 });
1906
1907 //----------------------------------------------------------------------------
1908 // Mapping of Location
1909 //----------------------------------------------------------------------------
1910 py::class_<PyLocation>(m, "Location")
1911 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
1912 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
1913 .def("__enter__", &PyLocation::contextEnter)
1914 .def("__exit__", &PyLocation::contextExit)
1915 .def("__eq__",
1916 [](PyLocation &self, PyLocation &other) -> bool {
1917 return mlirLocationEqual(self, other);
1918 })
1919 .def("__eq__", [](PyLocation &self, py::object other) { return false; })
1920 .def_property_readonly_static(
1921 "current",
1922 [](py::object & /*class*/) {
1923 auto *loc = PyThreadContextEntry::getDefaultLocation();
1924 if (!loc)
1925 throw SetPyError(PyExc_ValueError, "No current Location");
1926 return loc;
1927 },
1928 "Gets the Location bound to the current thread or raises ValueError")
1929 .def_static(
1930 "unknown",
1931 [](DefaultingPyMlirContext context) {
1932 return PyLocation(context->getRef(),
1933 mlirLocationUnknownGet(context->get()));
1934 },
1935 py::arg("context") = py::none(),
1936 "Gets a Location representing an unknown location")
1937 .def_static(
1938 "file",
1939 [](std::string filename, int line, int col,
1940 DefaultingPyMlirContext context) {
1941 return PyLocation(
1942 context->getRef(),
1943 mlirLocationFileLineColGet(
1944 context->get(), toMlirStringRef(filename), line, col));
1945 },
1946 py::arg("filename"), py::arg("line"), py::arg("col"),
1947 py::arg("context") = py::none(), kContextGetFileLocationDocstring)
1948 .def_property_readonly(
1949 "context",
1950 [](PyLocation &self) { return self.getContext().getObject(); },
1951 "Context that owns the Location")
1952 .def("__repr__", [](PyLocation &self) {
1953 PyPrintAccumulator printAccum;
1954 mlirLocationPrint(self, printAccum.getCallback(),
1955 printAccum.getUserData());
1956 return printAccum.join();
1957 });
1958
1959 //----------------------------------------------------------------------------
1960 // Mapping of Module
1961 //----------------------------------------------------------------------------
1962 py::class_<PyModule>(m, "Module")
1963 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
1964 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
1965 .def_static(
1966 "parse",
1967 [](const std::string moduleAsm, DefaultingPyMlirContext context) {
1968 MlirModule module = mlirModuleCreateParse(
1969 context->get(), toMlirStringRef(moduleAsm));
1970 // TODO: Rework error reporting once diagnostic engine is exposed
1971 // in C API.
1972 if (mlirModuleIsNull(module)) {
1973 throw SetPyError(
1974 PyExc_ValueError,
1975 "Unable to parse module assembly (see diagnostics)");
1976 }
1977 return PyModule::forModule(module).releaseObject();
1978 },
1979 py::arg("asm"), py::arg("context") = py::none(),
1980 kModuleParseDocstring)
1981 .def_static(
1982 "create",
1983 [](DefaultingPyLocation loc) {
1984 MlirModule module = mlirModuleCreateEmpty(loc);
1985 return PyModule::forModule(module).releaseObject();
1986 },
1987 py::arg("loc") = py::none(), "Creates an empty module")
1988 .def_property_readonly(
1989 "context",
1990 [](PyModule &self) { return self.getContext().getObject(); },
1991 "Context that created the Module")
1992 .def_property_readonly(
1993 "operation",
1994 [](PyModule &self) {
1995 return PyOperation::forOperation(self.getContext(),
1996 mlirModuleGetOperation(self.get()),
1997 self.getRef().releaseObject())
1998 .releaseObject();
1999 },
2000 "Accesses the module as an operation")
2001 .def_property_readonly(
2002 "body",
2003 [](PyModule &self) {
2004 PyOperationRef module_op = PyOperation::forOperation(
2005 self.getContext(), mlirModuleGetOperation(self.get()),
2006 self.getRef().releaseObject());
2007 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2008 return returnBlock;
2009 },
2010 "Return the block for this module")
2011 .def(
2012 "dump",
2013 [](PyModule &self) {
2014 mlirOperationDump(mlirModuleGetOperation(self.get()));
2015 },
2016 kDumpDocstring)
2017 .def(
2018 "__str__",
2019 [](PyModule &self) {
2020 MlirOperation operation = mlirModuleGetOperation(self.get());
2021 PyPrintAccumulator printAccum;
2022 mlirOperationPrint(operation, printAccum.getCallback(),
2023 printAccum.getUserData());
2024 return printAccum.join();
2025 },
2026 kOperationStrDunderDocstring);
2027
2028 //----------------------------------------------------------------------------
2029 // Mapping of Operation.
2030 //----------------------------------------------------------------------------
2031 py::class_<PyOperationBase>(m, "_OperationBase")
2032 .def("__eq__",
2033 [](PyOperationBase &self, PyOperationBase &other) {
2034 return &self.getOperation() == &other.getOperation();
2035 })
2036 .def("__eq__",
2037 [](PyOperationBase &self, py::object other) { return false; })
2038 .def_property_readonly("attributes",
2039 [](PyOperationBase &self) {
2040 return PyOpAttributeMap(
2041 self.getOperation().getRef());
2042 })
2043 .def_property_readonly("operands",
2044 [](PyOperationBase &self) {
2045 return PyOpOperandList(
2046 self.getOperation().getRef());
2047 })
2048 .def_property_readonly("regions",
2049 [](PyOperationBase &self) {
2050 return PyRegionList(
2051 self.getOperation().getRef());
2052 })
2053 .def_property_readonly(
2054 "results",
2055 [](PyOperationBase &self) {
2056 return PyOpResultList(self.getOperation().getRef());
2057 },
2058 "Returns the list of Operation results.")
2059 .def_property_readonly(
2060 "result",
2061 [](PyOperationBase &self) {
2062 auto &operation = self.getOperation();
2063 auto numResults = mlirOperationGetNumResults(operation);
2064 if (numResults != 1) {
2065 auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2066 throw SetPyError(
2067 PyExc_ValueError,
2068 Twine("Cannot call .result on operation ") +
2069 StringRef(name.data, name.length) + " which has " +
2070 Twine(numResults) +
2071 " results (it is only valid for operations with a "
2072 "single result)");
2073 }
2074 return PyOpResult(operation.getRef(),
2075 mlirOperationGetResult(operation, 0));
2076 },
2077 "Shortcut to get an op result if it has only one (throws an error "
2078 "otherwise).")
2079 .def("__iter__",
2080 [](PyOperationBase &self) {
2081 return PyRegionIterator(self.getOperation().getRef());
2082 })
2083 .def(
2084 "__str__",
2085 [](PyOperationBase &self) {
2086 return self.getAsm(/*binary=*/false,
2087 /*largeElementsLimit=*/llvm::None,
2088 /*enableDebugInfo=*/false,
2089 /*prettyDebugInfo=*/false,
2090 /*printGenericOpForm=*/false,
2091 /*useLocalScope=*/false);
2092 },
2093 "Returns the assembly form of the operation.")
2094 .def("print", &PyOperationBase::print,
2095 // Careful: Lots of arguments must match up with print method.
2096 py::arg("file") = py::none(), py::arg("binary") = false,
2097 py::arg("large_elements_limit") = py::none(),
2098 py::arg("enable_debug_info") = false,
2099 py::arg("pretty_debug_info") = false,
2100 py::arg("print_generic_op_form") = false,
2101 py::arg("use_local_scope") = false, kOperationPrintDocstring)
2102 .def("get_asm", &PyOperationBase::getAsm,
2103 // Careful: Lots of arguments must match up with get_asm method.
2104 py::arg("binary") = false,
2105 py::arg("large_elements_limit") = py::none(),
2106 py::arg("enable_debug_info") = false,
2107 py::arg("pretty_debug_info") = false,
2108 py::arg("print_generic_op_form") = false,
2109 py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
2110 .def(
2111 "verify",
2112 [](PyOperationBase &self) {
2113 return mlirOperationVerify(self.getOperation());
2114 },
2115 "Verify the operation and return true if it passes, false if it "
2116 "fails.");
2117
2118 py::class_<PyOperation, PyOperationBase>(m, "Operation")
2119 .def_static("create", &PyOperation::create, py::arg("name"),
2120 py::arg("results") = py::none(),
2121 py::arg("operands") = py::none(),
2122 py::arg("attributes") = py::none(),
2123 py::arg("successors") = py::none(), py::arg("regions") = 0,
2124 py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2125 kOperationCreateDocstring)
2126 .def_property_readonly("parent",
2127 [](PyOperation &self) {
2128 return self.getParentOperation().getObject();
2129 })
2130 .def("erase", &PyOperation::erase)
2131 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2132 &PyOperation::getCapsule)
2133 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2134 .def_property_readonly("name",
2135 [](PyOperation &self) {
2136 self.checkValid();
2137 MlirOperation operation = self.get();
2138 MlirStringRef name = mlirIdentifierStr(
2139 mlirOperationGetName(operation));
2140 return py::str(name.data, name.length);
2141 })
2142 .def_property_readonly(
2143 "context",
2144 [](PyOperation &self) {
2145 self.checkValid();
2146 return self.getContext().getObject();
2147 },
2148 "Context that owns the Operation")
2149 .def_property_readonly("opview", &PyOperation::createOpView);
2150
2151 auto opViewClass =
2152 py::class_<PyOpView, PyOperationBase>(m, "OpView")
2153 .def(py::init<py::object>())
2154 .def_property_readonly("operation", &PyOpView::getOperationObject)
2155 .def_property_readonly(
2156 "context",
2157 [](PyOpView &self) {
2158 return self.getOperation().getContext().getObject();
2159 },
2160 "Context that owns the Operation")
2161 .def("__str__", [](PyOpView &self) {
2162 return py::str(self.getOperationObject());
2163 });
2164 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2165 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2166 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2167 opViewClass.attr("build_generic") = classmethod(
2168 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2169 py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2170 py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2171 py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2172 "Builds a specific, generated OpView based on class level attributes.");
2173
2174 //----------------------------------------------------------------------------
2175 // Mapping of PyRegion.
2176 //----------------------------------------------------------------------------
2177 py::class_<PyRegion>(m, "Region")
2178 .def_property_readonly(
2179 "blocks",
2180 [](PyRegion &self) {
2181 return PyBlockList(self.getParentOperation(), self.get());
2182 },
2183 "Returns a forward-optimized sequence of blocks.")
2184 .def(
2185 "__iter__",
2186 [](PyRegion &self) {
2187 self.checkValid();
2188 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2189 return PyBlockIterator(self.getParentOperation(), firstBlock);
2190 },
2191 "Iterates over blocks in the region.")
2192 .def("__eq__",
2193 [](PyRegion &self, PyRegion &other) {
2194 return self.get().ptr == other.get().ptr;
2195 })
2196 .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2197
2198 //----------------------------------------------------------------------------
2199 // Mapping of PyBlock.
2200 //----------------------------------------------------------------------------
2201 py::class_<PyBlock>(m, "Block")
2202 .def_property_readonly(
2203 "arguments",
2204 [](PyBlock &self) {
2205 return PyBlockArgumentList(self.getParentOperation(), self.get());
2206 },
2207 "Returns a list of block arguments.")
2208 .def_property_readonly(
2209 "operations",
2210 [](PyBlock &self) {
2211 return PyOperationList(self.getParentOperation(), self.get());
2212 },
2213 "Returns a forward-optimized sequence of operations.")
2214 .def(
2215 "__iter__",
2216 [](PyBlock &self) {
2217 self.checkValid();
2218 MlirOperation firstOperation =
2219 mlirBlockGetFirstOperation(self.get());
2220 return PyOperationIterator(self.getParentOperation(),
2221 firstOperation);
2222 },
2223 "Iterates over operations in the block.")
2224 .def("__eq__",
2225 [](PyBlock &self, PyBlock &other) {
2226 return self.get().ptr == other.get().ptr;
2227 })
2228 .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2229 .def(
2230 "__str__",
2231 [](PyBlock &self) {
2232 self.checkValid();
2233 PyPrintAccumulator printAccum;
2234 mlirBlockPrint(self.get(), printAccum.getCallback(),
2235 printAccum.getUserData());
2236 return printAccum.join();
2237 },
2238 "Returns the assembly form of the block.");
2239
2240 //----------------------------------------------------------------------------
2241 // Mapping of PyInsertionPoint.
2242 //----------------------------------------------------------------------------
2243
2244 py::class_<PyInsertionPoint>(m, "InsertionPoint")
2245 .def(py::init<PyBlock &>(), py::arg("block"),
2246 "Inserts after the last operation but still inside the block.")
2247 .def("__enter__", &PyInsertionPoint::contextEnter)
2248 .def("__exit__", &PyInsertionPoint::contextExit)
2249 .def_property_readonly_static(
2250 "current",
2251 [](py::object & /*class*/) {
2252 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2253 if (!ip)
2254 throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2255 return ip;
2256 },
2257 "Gets the InsertionPoint bound to the current thread or raises "
2258 "ValueError if none has been set")
2259 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2260 "Inserts before a referenced operation.")
2261 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2262 py::arg("block"), "Inserts at the beginning of the block.")
2263 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2264 py::arg("block"), "Inserts before the block terminator.")
2265 .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2266 "Inserts an operation.");
2267
2268 //----------------------------------------------------------------------------
2269 // Mapping of PyAttribute.
2270 //----------------------------------------------------------------------------
2271 py::class_<PyAttribute>(m, "Attribute")
2272 // Delegate to the PyAttribute copy constructor, which will also lifetime
2273 // extend the backing context which owns the MlirAttribute.
2274 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2275 "Casts the passed attribute to the generic Attribute")
2276 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2277 &PyAttribute::getCapsule)
2278 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2279 .def_static(
2280 "parse",
2281 [](std::string attrSpec, DefaultingPyMlirContext context) {
2282 MlirAttribute type = mlirAttributeParseGet(
2283 context->get(), toMlirStringRef(attrSpec));
2284 // TODO: Rework error reporting once diagnostic engine is exposed
2285 // in C API.
2286 if (mlirAttributeIsNull(type)) {
2287 throw SetPyError(PyExc_ValueError,
2288 Twine("Unable to parse attribute: '") +
2289 attrSpec + "'");
2290 }
2291 return PyAttribute(context->getRef(), type);
2292 },
2293 py::arg("asm"), py::arg("context") = py::none(),
2294 "Parses an attribute from an assembly form")
2295 .def_property_readonly(
2296 "context",
2297 [](PyAttribute &self) { return self.getContext().getObject(); },
2298 "Context that owns the Attribute")
2299 .def_property_readonly("type",
2300 [](PyAttribute &self) {
2301 return PyType(self.getContext()->getRef(),
2302 mlirAttributeGetType(self));
2303 })
2304 .def(
2305 "get_named",
2306 [](PyAttribute &self, std::string name) {
2307 return PyNamedAttribute(self, std::move(name));
2308 },
2309 py::keep_alive<0, 1>(), "Binds a name to the attribute")
2310 .def("__eq__",
2311 [](PyAttribute &self, PyAttribute &other) { return self == other; })
2312 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2313 .def(
2314 "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2315 kDumpDocstring)
2316 .def(
2317 "__str__",
2318 [](PyAttribute &self) {
2319 PyPrintAccumulator printAccum;
2320 mlirAttributePrint(self, printAccum.getCallback(),
2321 printAccum.getUserData());
2322 return printAccum.join();
2323 },
2324 "Returns the assembly form of the Attribute.")
2325 .def("__repr__", [](PyAttribute &self) {
2326 // Generally, assembly formats are not printed for __repr__ because
2327 // this can cause exceptionally long debug output and exceptions.
2328 // However, attribute values are generally considered useful and are
2329 // printed. This may need to be re-evaluated if debug dumps end up
2330 // being excessive.
2331 PyPrintAccumulator printAccum;
2332 printAccum.parts.append("Attribute(");
2333 mlirAttributePrint(self, printAccum.getCallback(),
2334 printAccum.getUserData());
2335 printAccum.parts.append(")");
2336 return printAccum.join();
2337 });
2338
2339 //----------------------------------------------------------------------------
2340 // Mapping of PyNamedAttribute
2341 //----------------------------------------------------------------------------
2342 py::class_<PyNamedAttribute>(m, "NamedAttribute")
2343 .def("__repr__",
2344 [](PyNamedAttribute &self) {
2345 PyPrintAccumulator printAccum;
2346 printAccum.parts.append("NamedAttribute(");
2347 printAccum.parts.append(
2348 mlirIdentifierStr(self.namedAttr.name).data);
2349 printAccum.parts.append("=");
2350 mlirAttributePrint(self.namedAttr.attribute,
2351 printAccum.getCallback(),
2352 printAccum.getUserData());
2353 printAccum.parts.append(")");
2354 return printAccum.join();
2355 })
2356 .def_property_readonly(
2357 "name",
2358 [](PyNamedAttribute &self) {
2359 return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2360 mlirIdentifierStr(self.namedAttr.name).length);
2361 },
2362 "The name of the NamedAttribute binding")
2363 .def_property_readonly(
2364 "attr",
2365 [](PyNamedAttribute &self) {
2366 // TODO: When named attribute is removed/refactored, also remove
2367 // this constructor (it does an inefficient table lookup).
2368 auto contextRef = PyMlirContext::forContext(
2369 mlirAttributeGetContext(self.namedAttr.attribute));
2370 return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2371 },
2372 py::keep_alive<0, 1>(),
2373 "The underlying generic attribute of the NamedAttribute binding");
2374
2375 //----------------------------------------------------------------------------
2376 // Mapping of PyType.
2377 //----------------------------------------------------------------------------
2378 py::class_<PyType>(m, "Type")
2379 // Delegate to the PyType copy constructor, which will also lifetime
2380 // extend the backing context which owns the MlirType.
2381 .def(py::init<PyType &>(), py::arg("cast_from_type"),
2382 "Casts the passed type to the generic Type")
2383 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2384 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2385 .def_static(
2386 "parse",
2387 [](std::string typeSpec, DefaultingPyMlirContext context) {
2388 MlirType type =
2389 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2390 // TODO: Rework error reporting once diagnostic engine is exposed
2391 // in C API.
2392 if (mlirTypeIsNull(type)) {
2393 throw SetPyError(PyExc_ValueError,
2394 Twine("Unable to parse type: '") + typeSpec +
2395 "'");
2396 }
2397 return PyType(context->getRef(), type);
2398 },
2399 py::arg("asm"), py::arg("context") = py::none(),
2400 kContextParseTypeDocstring)
2401 .def_property_readonly(
2402 "context", [](PyType &self) { return self.getContext().getObject(); },
2403 "Context that owns the Type")
2404 .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2405 .def("__eq__", [](PyType &self, py::object &other) { return false; })
2406 .def(
2407 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2408 .def(
2409 "__str__",
2410 [](PyType &self) {
2411 PyPrintAccumulator printAccum;
2412 mlirTypePrint(self, printAccum.getCallback(),
2413 printAccum.getUserData());
2414 return printAccum.join();
2415 },
2416 "Returns the assembly form of the type.")
2417 .def("__repr__", [](PyType &self) {
2418 // Generally, assembly formats are not printed for __repr__ because
2419 // this can cause exceptionally long debug output and exceptions.
2420 // However, types are an exception as they typically have compact
2421 // assembly forms and printing them is useful.
2422 PyPrintAccumulator printAccum;
2423 printAccum.parts.append("Type(");
2424 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2425 printAccum.parts.append(")");
2426 return printAccum.join();
2427 });
2428
2429 //----------------------------------------------------------------------------
2430 // Mapping of Value.
2431 //----------------------------------------------------------------------------
2432 py::class_<PyValue>(m, "Value")
2433 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
2434 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2435 .def_property_readonly(
2436 "context",
2437 [](PyValue &self) { return self.getParentOperation()->getContext(); },
2438 "Context in which the value lives.")
2439 .def(
2440 "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2441 kDumpDocstring)
2442 .def_property_readonly(
2443 "owner",
2444 [](PyValue &self) {
2445 assert(mlirOperationEqual(self.getParentOperation()->get(),
2446 mlirOpResultGetOwner(self.get())) &&
2447 "expected the owner of the value in Python to match that in "
2448 "the IR");
2449 return self.getParentOperation().getObject();
2450 })
2451 .def("__eq__",
2452 [](PyValue &self, PyValue &other) {
2453 return self.get().ptr == other.get().ptr;
2454 })
2455 .def("__eq__", [](PyValue &self, py::object other) { return false; })
2456 .def(
2457 "__str__",
2458 [](PyValue &self) {
2459 PyPrintAccumulator printAccum;
2460 printAccum.parts.append("Value(");
2461 mlirValuePrint(self.get(), printAccum.getCallback(),
2462 printAccum.getUserData());
2463 printAccum.parts.append(")");
2464 return printAccum.join();
2465 },
2466 kValueDunderStrDocstring)
2467 .def_property_readonly("type", [](PyValue &self) {
2468 return PyType(self.getParentOperation()->getContext(),
2469 mlirValueGetType(self.get()));
2470 });
// Concrete Value kinds (block arguments and operation results, judging by the
// class names). NOTE(review): presumably these are exposed as subclasses of
// "Value" and therefore must be bound after it — confirm in IRModule.h.
2471 PyBlockArgument::bind(m);
2472 PyOpResult::bind(m);
2473
2474 // Container bindings.
// List/iterator/map wrappers over block arguments, blocks, operations,
// attributes, operands, results and regions (inferred from the class names).
// Keep the existing order unless the dependencies between them are verified.
2475 PyBlockArgumentList::bind(m);
2476 PyBlockIterator::bind(m);
2477 PyBlockList::bind(m);
2478 PyOperationIterator::bind(m);
2479 PyOperationList::bind(m);
2480 PyOpAttributeMap::bind(m);
2481 PyOpOperandList::bind(m);
2482 PyOpResultList::bind(m);
2483 PyRegionIterator::bind(m);
2484 PyRegionList::bind(m);
2485
2486 // Debug bindings.
// NOTE(review): presumably exposes the global MLIR debug flag declared in
// mlir-c/Debug.h (included at the top of this file) — confirm.
2487 PyGlobalDebugFlag::bind(m);
2488 }
2489